]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
set bindparam.expanding in coercion again
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 May 2021 02:52:49 +0000 (22:52 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 11 May 2021 03:42:41 +0000 (23:42 -0400)
Adjusted the logic added as part of :ticket:`6397` in 1.4.12 so that
internal mutation of the :class:`.BindParameter` object occurs within the
clause construction phase as it did before, rather than in the compilation
phase. In the latter case, the mutation still produced side effects against
the incoming construct and additionally could potentially interfere with
other internal mutation routines.

In order to solve the issue of the correct operator being present
on the BindParameter.expand_op, we necessarily have to expand the
BinaryExpression._negate() routine to flip the operator on the
BindParameter also.

Fixes: #6460
Change-Id: I1e53a9aeee4de4fc11af51d7593431532731561b

doc/build/changelog/changelog_14.rst
doc/build/changelog/unreleased_14/6460.rst [new file with mode: 0644]
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/lambdas.py
test/sql/test_external_traversal.py
test/sql/test_lambdas.py
test/sql/test_operators.py

index 0bf29f2864b5678025f3a3db03639c14ba3fbb40..a236a3551bb6c3bcd69828ca316efc6bec5868b5 100644 (file)
@@ -204,7 +204,7 @@ This document details individual issue-level changes made throughout
 
     .. change::
         :tags: bug, sql
-        :tickets: 6258 6397
+        :tickets: 6258, 6397
 
         Revised the "EMPTY IN" expression to no longer rely upon using a subquery,
         as this was causing some compatibility and performance problems. The new
diff --git a/doc/build/changelog/unreleased_14/6460.rst b/doc/build/changelog/unreleased_14/6460.rst
new file mode 100644 (file)
index 0000000..faeecd4
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6460
+
+    Adjusted the logic added as part of :ticket:`6397` in 1.4.12 so that
+    internal mutation of the :class:`.BindParameter` object occurs within the
+    clause construction phase as it did before, rather than in the compilation
+    phase. In the latter case, the mutation still produced side effects against
+    the incoming construct and additionally could potentially interfere with
+    other internal mutation routines.
index 820fc1bf19f21ebe433dc58fe4c108b5d2650253..517bfd57dd3191303f8f67f985cb74ab244aa474 100644 (file)
@@ -561,9 +561,10 @@ class InElementImpl(RoleImpl):
             return element.self_group(against=operator)
 
         elif isinstance(element, elements.BindParameter):
-            # previously we were adding expanding flags here but
-            # we now do this in the compiler where we have more context
-            # see compiler.py -> _render_in_expr_w_bindparam
+            element = element._clone(maintain_key=True)
+            element.expanding = True
+            element.expand_op = operator
+
             return element
         else:
             return element
index dedd75f5cbe33c3969a17c1c6d8fb69d0a61cf5c..734e6549219e5fc207aaed39cff0f6a9c1614d14 100644 (file)
@@ -1904,32 +1904,14 @@ class SQLCompiler(Compiled):
             binary, override_operator=operators.match_op
         )
 
-    def visit_in_op_binary(self, binary, operator, **kw):
-        return self._render_in_expr_w_bindparam(binary, operator, **kw)
-
     def visit_not_in_op_binary(self, binary, operator, **kw):
         # The brackets are required in the NOT IN operation because the empty
         # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
         # The presence of the OR makes the brackets required.
-        return "(%s)" % self._render_in_expr_w_bindparam(
-            binary, operator, **kw
+        return "(%s)" % self._generate_generic_binary(
+            binary, OPERATORS[operator], **kw
         )
 
-    def _render_in_expr_w_bindparam(self, binary, operator, **kw):
-        opstring = OPERATORS[operator]
-
-        if isinstance(binary.right, elements.BindParameter):
-            if not binary.right.expanding or not binary.right.expand_op:
-                # note that by cloning here, we rely upon the
-                # _cache_key_bind_match dictionary to resolve
-                # clones of bindparam() objects to the ones that are
-                # present in our cache key.
-                binary.right = binary.right._clone(maintain_key=True)
-                binary.right.expanding = True
-                binary.right.expand_op = operator
-
-        return self._generate_generic_binary(binary, opstring, **kw)
-
     def visit_empty_set_op_expr(self, type_, expand_op):
         if expand_op is operators.not_in_op:
             if len(type_) > 1:
index 416a4e82ea24fbd9d45ad7e2850a1cc24de2415c..cdb1dbca8a5aba2ef252ba36245c82e3d0ff834a 100644 (file)
@@ -256,6 +256,15 @@ class ClauseElement(
 
         return c
 
+    def _negate_in_binary(self, negated_op, original_op):
+        """a hook to allow the right side of a binary expression to respond
+        to a negation of the binary expression.
+
+        Used for the special case of expanding bind parameter with IN.
+
+        """
+        return self
+
     def _with_binary_element_type(self, type_):
         """in the context of binary expression, convert the type of this
         object to the one given.
@@ -1510,6 +1519,14 @@ class BindParameter(roles.InElementRole, ColumnElement):
             literal_execute=True,
         )
 
+    def _negate_in_binary(self, negated_op, original_op):
+        if self.expand_op is original_op:
+            bind = self._clone()
+            bind.expand_op = negated_op
+            return bind
+        else:
+            return self
+
     def _with_binary_element_type(self, type_):
         c = ClauseElement._clone(self)
         c.type = type_
@@ -3729,7 +3746,7 @@ class BinaryExpression(ColumnElement):
         if self.negate is not None:
             return BinaryExpression(
                 self.left,
-                self.right,
+                self.right._negate_in_binary(self.negate, self.operator),
                 self.negate,
                 negate=self.operator,
                 type_=self.type,
index ddc4774db857dff0abdafa566af2c0bd0373f6a7..b3f47252abf8b4e01fb0ecc473e64640a1c29dca 100644 (file)
@@ -270,6 +270,8 @@ class LambdaElement(elements.ClauseElement):
                     bind = bindparam_lookup[thing.key]
                     if thing.expanding:
                         bind.expanding = True
+                        bind.expand_op = thing.expand_op
+                        bind.type = thing.type
                     return bind
 
         if self._rec.is_sequence:
index 1c76f37fa7b9841635722bcb01834d6b597180d0..9e829baeabf353214d5bdca4352a5987eb1ae5b5 100644 (file)
@@ -225,6 +225,22 @@ class TraversalTest(
             dialect="default",
         )
 
+    def test_expanding_in_bindparam_safe_to_clone(self):
+        expr = column("x").in_([1, 2, 3])
+
+        expr2 = expr._clone()
+
+        # shallow copy, bind is used twice
+        is_(expr.right, expr2.right)
+
+        stmt = and_(expr, expr2)
+        self.assert_compile(
+            stmt, "x IN ([POSTCOMPILE_x_1]) AND x IN ([POSTCOMPILE_x_1])"
+        )
+        self.assert_compile(
+            stmt, "x IN (1, 2, 3) AND x IN (1, 2, 3)", literal_binds=True
+        )
+
     def test_traversal_size(self):
         """Test :ticket:`6304`.
 
index cdfd92ece579b0d152ec923e9bfd0553bcd87e1f..2de969521e41ab918759ff8c9bd776d9334fbc04 100644 (file)
@@ -156,6 +156,43 @@ class LambdaElementTest(
             )
             eq_(result.all(), [(e["id"],) for e in data if e["name"] in case])
 
+    def test_in_expr_compile(self, user_address_fixture):
+        users, _ = user_address_fixture
+
+        def go(val):
+            stmt = lambdas.lambda_stmt(lambda: select(users.c.id))
+            stmt += lambda s: s.where(users.c.name.in_(val))
+            stmt += lambda s: s.order_by(users.c.id)
+            return stmt
+
+        # note this also requires the type of the bind is copied
+        self.assert_compile(
+            go([]),
+            "SELECT users.id FROM users "
+            "WHERE users.name IN (NULL) AND (1 != 1) ORDER BY users.id",
+            literal_binds=True,
+        )
+        self.assert_compile(
+            go(["u1", "u2"]),
+            "SELECT users.id FROM users "
+            "WHERE users.name IN ('u1', 'u2') ORDER BY users.id",
+            literal_binds=True,
+        )
+
+    def test_bind_type(self, user_address_fixture):
+        users, _ = user_address_fixture
+
+        def go(val):
+            stmt = lambdas.lambda_stmt(lambda: select(users.c.id))
+            stmt += lambda s: s.where(users.c.name == val)
+            return stmt
+
+        self.assert_compile(
+            go("u1"),
+            "SELECT users.id FROM users " "WHERE users.name = 'u1'",
+            literal_binds=True,
+        )
+
     def test_stale_checker_embedded(self):
         def go(x):
 
index f3e5282fd1beb9ba13d1654a4f7ecf4563263175..984379c6b72ee583b4110a3d7f63d4d409a44fa7 100644 (file)
@@ -1975,15 +1975,20 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 literal_binds=True,
             )
 
-    @testing.combinations(True, False)
-    def test_in_empty_tuple(self, is_in):
+    @testing.combinations(True, False, argnames="is_in")
+    @testing.combinations(True, False, argnames="negate")
+    def test_in_empty_tuple(self, is_in, negate):
         a, b, c = (
             column("a", Integer),
             column("b", String),
             column("c", LargeBinary),
         )
         t1 = tuple_(a, b, c)
-        expr = t1.in_([]) if is_in else t1.not_in([])
+
+        if negate:
+            expr = ~t1.not_in([]) if is_in else ~t1.in_([])
+        else:
+            expr = t1.in_([]) if is_in else t1.not_in([])
 
         if is_in:
             self.assert_compile(
@@ -2010,10 +2015,15 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 dialect="default_enhanced",
             )
 
-    @testing.combinations(True, False)
-    def test_in_empty_single(self, is_in):
+    @testing.combinations(True, False, argnames="is_in")
+    @testing.combinations(True, False, argnames="negate")
+    def test_in_empty_single(self, is_in, negate):
         a = column("a", Integer)
-        expr = a.in_([]) if is_in else a.not_in([])
+
+        if negate:
+            expr = ~a.not_in([]) if is_in else ~a.in_([])
+        else:
+            expr = a.in_([]) if is_in else a.not_in([])
 
         if is_in:
             self.assert_compile(
@@ -2040,6 +2050,36 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
                 dialect="default_enhanced",
             )
 
+    def test_in_self_plus_negated(self):
+        a = column("a", Integer)
+
+        expr1 = a.in_([5])
+        expr2 = ~expr1
+
+        stmt = and_(expr1, expr2)
+        self.assert_compile(
+            stmt, "a IN ([POSTCOMPILE_a_1]) AND (a NOT IN ([POSTCOMPILE_a_2]))"
+        )
+        self.assert_compile(
+            stmt, "a IN (5) AND (a NOT IN (5))", literal_binds=True
+        )
+
+    def test_in_self_plus_negated_empty(self):
+        a = column("a", Integer)
+
+        expr1 = a.in_([])
+        expr2 = ~expr1
+
+        stmt = and_(expr1, expr2)
+        self.assert_compile(
+            stmt, "a IN ([POSTCOMPILE_a_1]) AND (a NOT IN ([POSTCOMPILE_a_2]))"
+        )
+        self.assert_compile(
+            stmt,
+            "a IN (NULL) AND (1 != 1) AND (a NOT IN (NULL) OR (1 = 1))",
+            literal_binds=True,
+        )
+
     def test_in_set(self):
         s = {1, 2, 3}
         self.assert_compile(