From: Mike Bayer Date: Wed, 28 Apr 2021 22:31:51 +0000 (-0400) Subject: Use non-subquery form for empty IN X-Git-Tag: rel_1_4_12~5^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=aba308868544b21bafa0b3435701ddc908654b0a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Use non-subquery form for empty IN Revised the "EMPTY IN" expression to no longer rely upon using a subquery, as this was causing some compatibility and performance problems. The new approach for selected databases takes advantage of using a NULL-returning IN expression combined with the usual "1 != 1" or "1 = 1" expression appended by AND or OR. The expression is now the default for all backends other than SQLite, which still had some compatibility issues regarding tuple "IN" for older SQLite versions. Third party dialects can still override how the "empty set" expression renders by implementing a new compiler method ``def visit_empty_set_op_expr(self, type_, expand_op)``, which takes precedence over the existing ``def visit_empty_set_expr(self, element_types)`` which remains in place. Fixes: #6258 Fixes: #6397 Change-Id: I2df09eb00d2ad3b57039ae48128fdf94641b5e59 --- diff --git a/doc/build/changelog/unreleased_14/6258.rst b/doc/build/changelog/unreleased_14/6258.rst new file mode 100644 index 0000000000..83c36689da --- /dev/null +++ b/doc/build/changelog/unreleased_14/6258.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, sql + :tickets: 6258 6397 + + Revised the "EMPTY IN" expression to no longer rely upon using a subquery, + as this was causing some compatibility and performance problems. The new + approach for selected databases takes advantage of using a NULL-returning + IN expression combined with the usual "1 != 1" or "1 = 1" expression + appended by AND or OR. The expression is now the default for all backends + other than SQLite, which still had some compatibility issues regarding + tuple "IN" for older SQLite versions. + + Third party dialects can still override how the "empty set" expression + renders by implementing a new compiler method + ``def visit_empty_set_op_expr(self, type_, expand_op)``, which takes + precedence over the existing + ``def visit_empty_set_expr(self, element_types)`` which remains in place. + diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index 59d40fef0d..66a556ae0d 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -1299,6 +1299,11 @@ class SQLiteCompiler(compiler.SQLCompiler): self.process(binary.right, **kw), ) + def visit_empty_set_op_expr(self, type_, expand_op): + # slightly old SQLite versions don't seem to be able to handle + # the empty set impl + return self.visit_empty_set_expr(type_) + def visit_empty_set_expr(self, element_types): return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % ( ", ".join("1" for type_ in element_types or [INTEGER()]), diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index b7aba9d747..820fc1bf19 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -561,14 +561,9 @@ class InElementImpl(RoleImpl): return element.self_group(against=operator) elif isinstance(element, elements.BindParameter): - if not element.expanding: - # coercing to expanding at the moment to work with the - # lambda system. not sure if this is the right approach. - # is there a valid use case to send a single non-expanding - # param to IN? check for ARRAY type? - element = element._clone(maintain_key=True) - element.expanding = True - + # previously we were adding expanding flags here but + # we now do this in the compiler where we have more context + # see compiler.py -> _render_in_expr_w_bindparam return element else: return element diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 6168248ff7..e9e05b7e9a 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1903,6 +1903,45 @@ class SQLCompiler(Compiled): binary, override_operator=operators.match_op ) + def visit_in_op_binary(self, binary, operator, **kw): + return self._render_in_expr_w_bindparam(binary, operator, **kw) + + def visit_not_in_op_binary(self, binary, operator, **kw): + return self._render_in_expr_w_bindparam(binary, operator, **kw) + + def _render_in_expr_w_bindparam(self, binary, operator, **kw): + opstring = OPERATORS[operator] + + if isinstance(binary.right, elements.BindParameter): + if not binary.right.expanding or not binary.right.expand_op: + # note that by cloning here, we rely upon the + # _cache_key_bind_match dictionary to resolve + # clones of bindparam() objects to the ones that are + # present in our cache key. + binary.right = binary.right._clone(maintain_key=True) + binary.right.expanding = True + binary.right.expand_op = operator + + return self._generate_generic_binary(binary, opstring, **kw) + + def visit_empty_set_op_expr(self, type_, expand_op): + if expand_op is operators.not_in_op: + if len(type_) > 1: + return "(%s)) OR (1 = 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) OR (1 = 1" + elif expand_op is operators.in_op: + if len(type_) > 1: + return "(%s)) AND (1 != 1" % ( + ", ".join("NULL" for element in type_) + ) + else: + return "NULL) AND (1 != 1" + else: + return self.visit_empty_set_expr(type_) + def visit_empty_set_expr(self, element_types): raise NotImplementedError( "Dialect '%s' does not support empty set expression." @@ -1959,12 +1998,12 @@ class SQLCompiler(Compiled): to_update = [] if parameter.type._is_tuple_type: - replacement_expression = self.visit_empty_set_expr( - parameter.type.types + replacement_expression = self.visit_empty_set_op_expr( + parameter.type.types, parameter.expand_op ) else: - replacement_expression = self.visit_empty_set_expr( - [parameter.type] + replacement_expression = self.visit_empty_set_op_expr( + [parameter.type], parameter.expand_op ) elif isinstance(values[0], (tuple, list)): @@ -3900,6 +3939,9 @@ class StrSQLCompiler(SQLCompiler): for t in extra_froms ) + def visit_empty_set_op_expr(self, type_, expand_op): + return self.visit_empty_set_expr(type_) + def visit_empty_set_expr(self, type_): return "SELECT 1 WHERE 1!=1" diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 696f3b2492..e27b978021 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1411,7 +1411,17 @@ class BindParameter(roles.InElementRole, ColumnElement): self.callable = callable_ self.isoutparam = isoutparam self.required = required + + # indicate an "expanding" parameter; the compiler sets this + # automatically in the compiler _render_in_expr_w_bindparam method + # for an IN expression self.expanding = expanding + + # this is another hint to help w/ expanding and is typically + # set in the compiler _render_in_expr_w_bindparam method for an + # IN expression + self.expand_op = None + self.literal_execute = literal_execute if _is_crud: self._is_crud = True diff --git a/lib/sqlalchemy/testing/suite/test_select.py b/lib/sqlalchemy/testing/suite/test_select.py index 7b35dc3fa3..1614acd3da 100644 --- a/lib/sqlalchemy/testing/suite/test_select.py +++ b/lib/sqlalchemy/testing/suite/test_select.py @@ -1016,163 +1016,246 @@ class ExpandingBoundInTest(fixtures.TablesTest): with config.db.connect() as conn: eq_(conn.execute(select, params).fetchall(), result) - def test_multiple_empty_sets(self): + def test_multiple_empty_sets_bindparam(self): # test that any anonymous aliasing used by the dialect # is fine with duplicates table = self.tables.some_table - stmt = ( select(table.c.id) - .where(table.c.x.in_(bindparam("q", expanding=True))) - .where(table.c.y.in_(bindparam("p", expanding=True))) + .where(table.c.x.in_(bindparam("q"))) + .where(table.c.y.in_(bindparam("p"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": [], "p": []}) - @testing.requires.tuple_in_w_empty - def test_empty_heterogeneous_tuples(self): + def test_multiple_empty_sets_direct(self): + # test that any anonymous aliasing used by the dialect + # is fine with duplicates table = self.tables.some_table - stmt = ( select(table.c.id) - .where( - tuple_(table.c.x, table.c.z).in_( - bindparam("q", expanding=True) - ) - ) + .where(table.c.x.in_([])) + .where(table.c.y.in_([])) .order_by(table.c.id) ) + self._assert_result(stmt, []) + @testing.requires.tuple_in_w_empty + def test_empty_heterogeneous_tuples_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) self._assert_result(stmt, [], params={"q": []}) @testing.requires.tuple_in_w_empty - def test_empty_homogeneous_tuples(self): + def test_empty_heterogeneous_tuples_direct(self): table = self.tables.some_table + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)]) + go([], []) + + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where( - tuple_(table.c.x, table.c.y).in_( - bindparam("q", expanding=True) - ) - ) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": []}) - def test_bound_in_scalar(self): + @testing.requires.tuple_in_w_empty + def test_empty_homogeneous_tuples_direct(self): table = self.tables.some_table + def go(val, expected): + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(val)) + .order_by(table.c.id) + ) + self._assert_result(stmt, expected) + + go([], []) + go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)]) + go([], []) + + def test_bound_in_scalar_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.x.in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]}) - @testing.requires.tuple_in - def test_bound_in_two_tuple(self): + def test_bound_in_scalar_direct(self): table = self.tables.some_table - stmt = ( select(table.c.id) - .where( - tuple_(table.c.x, table.c.y).in_( - bindparam("q", expanding=True) - ) - ) + .where(table.c.x.in_([2, 3, 4])) .order_by(table.c.id) ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + @testing.requires.tuple_in + def test_bound_in_two_tuple_bindparam(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_(bindparam("q"))) + .order_by(table.c.id) + ) self._assert_result( stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]} ) @testing.requires.tuple_in - def test_bound_in_heterogeneous_two_tuple(self): + def test_bound_in_two_tuple_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)])) + .order_by(table.c.id) + ) + self._assert_result(stmt, [(2,), (3,), (4,)]) + + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_bindparam(self): table = self.tables.some_table + stmt = ( + select(table.c.id) + .where(tuple_(table.c.x, table.c.z).in_(bindparam("q"))) + .order_by(table.c.id) + ) + self._assert_result( + stmt, + [(2,), (3,), (4,)], + params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, + ) + @testing.requires.tuple_in + def test_bound_in_heterogeneous_two_tuple_direct(self): + table = self.tables.some_table stmt = ( select(table.c.id) .where( tuple_(table.c.x, table.c.z).in_( - bindparam("q", expanding=True) + [(2, "z2"), (3, "z3"), (4, "z4")] ) ) .order_by(table.c.id) ) - self._assert_result( stmt, [(2,), (3,), (4,)], - params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) @testing.requires.tuple_in - def test_bound_in_heterogeneous_two_tuple_text(self): + def test_bound_in_heterogeneous_two_tuple_text_bindparam(self): + # note this becomes ARRAY if we dont use expanding + # explicitly right now stmt = text( "select id FROM some_table WHERE (x, z) IN :q ORDER BY id" ).bindparams(bindparam("q", expanding=True)) - self._assert_result( stmt, [(2,), (3,), (4,)], params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]}, ) - def test_empty_set_against_integer(self): + def test_empty_set_against_integer_bindparam(self): table = self.tables.some_table - stmt = ( select(table.c.id) - .where(table.c.x.in_(bindparam("q", expanding=True))) + .where(table.c.x.in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": []}) - def test_empty_set_against_integer_negation(self): + def test_empty_set_against_integer_direct(self): table = self.tables.some_table + stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + def test_empty_set_against_integer_negation_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.x.not_in(bindparam("q", expanding=True))) + .where(table.c.x.not_in(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) - def test_empty_set_against_string(self): + def test_empty_set_against_integer_negation_direct(self): table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + def test_empty_set_against_string_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.z.in_(bindparam("q", expanding=True))) + .where(table.c.z.in_(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [], params={"q": []}) - def test_empty_set_against_string_negation(self): + def test_empty_set_against_string_direct(self): table = self.tables.some_table + stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id) + self._assert_result(stmt, []) + def test_empty_set_against_string_negation_bindparam(self): + table = self.tables.some_table stmt = ( select(table.c.id) - .where(table.c.z.not_in(bindparam("q", expanding=True))) + .where(table.c.z.not_in(bindparam("q"))) .order_by(table.c.id) ) - self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []}) - def test_null_in_empty_set_is_false(self, connection): + def test_empty_set_against_string_negation_direct(self): + table = self.tables.some_table + stmt = ( + select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id) + ) + self._assert_result(stmt, [(1,), (2,), (3,), (4,)]) + + def test_null_in_empty_set_is_false_bindparam(self, connection): + stmt = select( + case( + [ + ( + null().in_(bindparam("foo", value=())), + true(), + ) + ], + else_=false(), + ) + ) + in_(connection.execute(stmt).fetchone()[0], (False, 0)) + + def test_null_in_empty_set_is_false_direct(self, connection): stmt = select( case( [ ( - null().in_(bindparam("foo", value=(), expanding=True)), + null().in_([]), true(), ) ], diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index 24a83c9ee3..897c60f003 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -116,6 +116,46 @@ class LambdaElementTest( result = go() eq_(result.all(), [(2,)]) + def test_in_expressions(self, user_address_fixture, connection): + """test #6397. we initially were going to use two different + forms for "empty in" vs. regular "in", but instead we have an + improved substitution for "empty in". regardless, as there's more + going on with these, make sure lambdas work with them including + caching. + + """ + users, _ = user_address_fixture + data = [ + {"id": 1, "name": "u1"}, + {"id": 2, "name": "u2"}, + {"id": 3, "name": "u3"}, + ] + connection.execute(users.insert(), data) + + def go(val): + stmt = lambdas.lambda_stmt(lambda: select(users.c.id)) + stmt += lambda s: s.where(users.c.name.in_(val)) + stmt += lambda s: s.order_by(users.c.id) + return connection.execute(stmt) + + for case in [ + [], + ["u1", "u2"], + ["u3"], + [], + ["u1", "u2"], + ]: + with testing.assertsql.assert_engine(testing.db) as asserter_: + result = go(case) + asserter_.assert_( + CompiledSQL( + "SELECT users.id FROM users WHERE users.name " + "IN ([POSTCOMPILE_val_1]) ORDER BY users.id", + params={"val_1": case}, + ) + ) + eq_(result.all(), [(e["id"],) for e in data if e["name"] in case]) + def test_stale_checker_embedded(self): def go(x):