From b146a0c64144639bf02bafda239238e3a8f5c84d Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Mon, 10 May 2021 22:52:49 -0400 Subject: [PATCH] set bindparam.expanding in coercion again Adjusted the logic added as part of :ticket:`6397` in 1.4.12 so that internal mutation of the :class:`.BindParameter` object occurs within the clause construction phase as it did before, rather than in the compilation phase. In the latter case, the mutation still produced side effects against the incoming construct and additionally could potentially interfere with other internal mutation routines. In order to solve the issue of the correct operator being present on the BindParameter.expand_op, we necessarily have to expand the BinaryExpression._negate() routine to flip the operator on the BindParameter also. Fixes: #6460 Change-Id: I1e53a9aeee4de4fc11af51d7593431532731561b --- doc/build/changelog/changelog_14.rst | 2 +- doc/build/changelog/unreleased_14/6460.rst | 10 +++++ lib/sqlalchemy/sql/coercions.py | 7 +-- lib/sqlalchemy/sql/compiler.py | 22 +-------- lib/sqlalchemy/sql/elements.py | 19 +++++++- lib/sqlalchemy/sql/lambdas.py | 2 + test/sql/test_external_traversal.py | 16 +++++++ test/sql/test_lambdas.py | 37 +++++++++++++++ test/sql/test_operators.py | 52 +++++++++++++++++++--- 9 files changed, 136 insertions(+), 31 deletions(-) create mode 100644 doc/build/changelog/unreleased_14/6460.rst diff --git a/doc/build/changelog/changelog_14.rst b/doc/build/changelog/changelog_14.rst index 0bf29f2864..a236a3551b 100644 --- a/doc/build/changelog/changelog_14.rst +++ b/doc/build/changelog/changelog_14.rst @@ -204,7 +204,7 @@ This document details individual issue-level changes made throughout .. change:: :tags: bug, sql - :tickets: 6258 6397 + :tickets: 6258, 6397 Revised the "EMPTY IN" expression to no longer rely upon using a subquery, as this was causing some compatibility and performance problems. The new diff --git a/doc/build/changelog/unreleased_14/6460.rst b/doc/build/changelog/unreleased_14/6460.rst new file mode 100644 index 0000000000..faeecd4382 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6460.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 6460 + + Adjusted the logic added as part of :ticket:`6397` in 1.4.12 so that + internal mutation of the :class:`.BindParameter` object occurs within the + clause construction phase as it did before, rather than in the compilation + phase. In the latter case, the mutation still produced side effects against + the incoming construct and additionally could potentially interfere with + other internal mutation routines. diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 820fc1bf19..517bfd57dd 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -561,9 +561,10 @@ class InElementImpl(RoleImpl): return element.self_group(against=operator) elif isinstance(element, elements.BindParameter): - # previously we were adding expanding flags here but - # we now do this in the compiler where we have more context - # see compiler.py -> _render_in_expr_w_bindparam + element = element._clone(maintain_key=True) + element.expanding = True + element.expand_op = operator + return element else: return element diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index dedd75f5cb..734e654921 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1904,32 +1904,14 @@ class SQLCompiler(Compiled): binary, override_operator=operators.match_op ) - def visit_in_op_binary(self, binary, operator, **kw): - return self._render_in_expr_w_bindparam(binary, operator, **kw) - def visit_not_in_op_binary(self, binary, operator, **kw): # The brackets are required in the NOT IN operation because the empty # case is handled using the form "(col NOT IN (null) OR 1 = 1)". # The presence of the OR makes the brackets required. - return "(%s)" % self._render_in_expr_w_bindparam( - binary, operator, **kw + return "(%s)" % self._generate_generic_binary( + binary, OPERATORS[operator], **kw ) - def _render_in_expr_w_bindparam(self, binary, operator, **kw): - opstring = OPERATORS[operator] - - if isinstance(binary.right, elements.BindParameter): - if not binary.right.expanding or not binary.right.expand_op: - # note that by cloning here, we rely upon the - # _cache_key_bind_match dictionary to resolve - # clones of bindparam() objects to the ones that are - # present in our cache key. - binary.right = binary.right._clone(maintain_key=True) - binary.right.expanding = True - binary.right.expand_op = operator - - return self._generate_generic_binary(binary, opstring, **kw) - def visit_empty_set_op_expr(self, type_, expand_op): if expand_op is operators.not_in_op: if len(type_) > 1: diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 416a4e82ea..cdb1dbca8a 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -256,6 +256,15 @@ class ClauseElement( return c + def _negate_in_binary(self, negated_op, original_op): + """a hook to allow the right side of a binary expression to respond + to a negation of the binary expression. + + Used for the special case of expanding bind parameter with IN. + + """ + return self + def _with_binary_element_type(self, type_): """in the context of binary expression, convert the type of this object to the one given. @@ -1510,6 +1519,14 @@ class BindParameter(roles.InElementRole, ColumnElement): literal_execute=True, ) + def _negate_in_binary(self, negated_op, original_op): + if self.expand_op is original_op: + bind = self._clone() + bind.expand_op = negated_op + return bind + else: + return self + def _with_binary_element_type(self, type_): c = ClauseElement._clone(self) c.type = type_ @@ -3729,7 +3746,7 @@ class BinaryExpression(ColumnElement): if self.negate is not None: return BinaryExpression( self.left, - self.right, + self.right._negate_in_binary(self.negate, self.operator), self.negate, negate=self.operator, type_=self.type, diff --git a/lib/sqlalchemy/sql/lambdas.py b/lib/sqlalchemy/sql/lambdas.py index ddc4774db8..b3f47252ab 100644 --- a/lib/sqlalchemy/sql/lambdas.py +++ b/lib/sqlalchemy/sql/lambdas.py @@ -270,6 +270,8 @@ class LambdaElement(elements.ClauseElement): bind = bindparam_lookup[thing.key] if thing.expanding: bind.expanding = True + bind.expand_op = thing.expand_op + bind.type = thing.type return bind if self._rec.is_sequence: diff --git a/test/sql/test_external_traversal.py b/test/sql/test_external_traversal.py index 1c76f37fa7..9e829baeab 100644 --- a/test/sql/test_external_traversal.py +++ b/test/sql/test_external_traversal.py @@ -225,6 +225,22 @@ class TraversalTest( dialect="default", ) + def test_expanding_in_bindparam_safe_to_clone(self): + expr = column("x").in_([1, 2, 3]) + + expr2 = expr._clone() + + # shallow copy, bind is used twice + is_(expr.right, expr2.right) + + stmt = and_(expr, expr2) + self.assert_compile( + stmt, "x IN ([POSTCOMPILE_x_1]) AND x IN ([POSTCOMPILE_x_1])" + ) + self.assert_compile( + stmt, "x IN (1, 2, 3) AND x IN (1, 2, 3)", literal_binds=True + ) + def test_traversal_size(self): """Test :ticket:`6304`. diff --git a/test/sql/test_lambdas.py b/test/sql/test_lambdas.py index cdfd92ece5..2de969521e 100644 --- a/test/sql/test_lambdas.py +++ b/test/sql/test_lambdas.py @@ -156,6 +156,43 @@ class LambdaElementTest( ) eq_(result.all(), [(e["id"],) for e in data if e["name"] in case]) + def test_in_expr_compile(self, user_address_fixture): + users, _ = user_address_fixture + + def go(val): + stmt = lambdas.lambda_stmt(lambda: select(users.c.id)) + stmt += lambda s: s.where(users.c.name.in_(val)) + stmt += lambda s: s.order_by(users.c.id) + return stmt + + # note this also requires the type of the bind is copied + self.assert_compile( + go([]), + "SELECT users.id FROM users " + "WHERE users.name IN (NULL) AND (1 != 1) ORDER BY users.id", + literal_binds=True, + ) + self.assert_compile( + go(["u1", "u2"]), + "SELECT users.id FROM users " + "WHERE users.name IN ('u1', 'u2') ORDER BY users.id", + literal_binds=True, + ) + + def test_bind_type(self, user_address_fixture): + users, _ = user_address_fixture + + def go(val): + stmt = lambdas.lambda_stmt(lambda: select(users.c.id)) + stmt += lambda s: s.where(users.c.name == val) + return stmt + + self.assert_compile( + go("u1"), + "SELECT users.id FROM users " "WHERE users.name = 'u1'", + literal_binds=True, + ) + def test_stale_checker_embedded(self): def go(x): diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index f3e5282fd1..984379c6b7 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -1975,15 +1975,20 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): literal_binds=True, ) - @testing.combinations(True, False) - def test_in_empty_tuple(self, is_in): + @testing.combinations(True, False, argnames="is_in") + @testing.combinations(True, False, argnames="negate") + def test_in_empty_tuple(self, is_in, negate): a, b, c = ( column("a", Integer), column("b", String), column("c", LargeBinary), ) t1 = tuple_(a, b, c) - expr = t1.in_([]) if is_in else t1.not_in([]) + + if negate: + expr = ~t1.not_in([]) if is_in else ~t1.in_([]) + else: + expr = t1.in_([]) if is_in else t1.not_in([]) if is_in: self.assert_compile( @@ -2010,10 +2015,15 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="default_enhanced", ) - @testing.combinations(True, False) - def test_in_empty_single(self, is_in): + @testing.combinations(True, False, argnames="is_in") + @testing.combinations(True, False, argnames="negate") + def test_in_empty_single(self, is_in, negate): a = column("a", Integer) - expr = a.in_([]) if is_in else a.not_in([]) + + if negate: + expr = ~a.not_in([]) if is_in else ~a.in_([]) + else: + expr = a.in_([]) if is_in else a.not_in([]) if is_in: self.assert_compile( @@ -2040,6 +2050,36 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="default_enhanced", ) + def test_in_self_plus_negated(self): + a = column("a", Integer) + + expr1 = a.in_([5]) + expr2 = ~expr1 + + stmt = and_(expr1, expr2) + self.assert_compile( + stmt, "a IN ([POSTCOMPILE_a_1]) AND (a NOT IN ([POSTCOMPILE_a_2]))" + ) + self.assert_compile( + stmt, "a IN (5) AND (a NOT IN (5))", literal_binds=True + ) + + def test_in_self_plus_negated_empty(self): + a = column("a", Integer) + + expr1 = a.in_([]) + expr2 = ~expr1 + + stmt = and_(expr1, expr2) + self.assert_compile( + stmt, "a IN ([POSTCOMPILE_a_1]) AND (a NOT IN ([POSTCOMPILE_a_2]))" + ) + self.assert_compile( + stmt, + "a IN (NULL) AND (1 != 1) AND (a NOT IN (NULL) OR (1 = 1))", + literal_binds=True, + ) + def test_in_set(self): s = {1, 2, 3} self.assert_compile( -- 2.47.3