]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Use non-subquery form for empty IN
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 28 Apr 2021 22:31:51 +0000 (18:31 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 29 Apr 2021 18:43:09 +0000 (14:43 -0400)
Revised the "EMPTY IN" expression to no longer rely upon using a subquery,
as this was causing some compatibility and performance problems. The new
approach for selected databases takes advantage of using a NULL-returning
IN expression combined with the usual "1 != 1" or "1 = 1" expression
appended by AND or OR. The expression is now the default for all backends
other than SQLite, which still had some compatibility issues regarding
tuple "IN" for older SQLite versions.

Third party dialects can still override how the "empty set" expression
renders by implementing a new compiler method
``def visit_empty_set_op_expr(self, type_, expand_op)``, which takes
precedence over the existing
``def visit_empty_set_expr(self, element_types)`` which remains in place.

Fixes: #6258
Fixes: #6397
Change-Id: I2df09eb00d2ad3b57039ae48128fdf94641b5e59

doc/build/changelog/unreleased_14/6258.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/sqlite/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/testing/suite/test_select.py
test/sql/test_lambdas.py

diff --git a/doc/build/changelog/unreleased_14/6258.rst b/doc/build/changelog/unreleased_14/6258.rst
new file mode 100644 (file)
index 0000000..83c3668
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 6258 6397
+
+    Revised the "EMPTY IN" expression to no longer rely upon using a subquery,
+    as this was causing some compatibility and performance problems. The new
+    approach for selected databases takes advantage of using a NULL-returning
+    IN expression combined with the usual "1 != 1" or "1 = 1" expression
+    appended by AND or OR. The expression is now the default for all backends
+    other than SQLite, which still had some compatibility issues regarding
+    tuple "IN" for older SQLite versions.
+
+    Third party dialects can still override how the "empty set" expression
+    renders by implementing a new compiler method
+    ``def visit_empty_set_op_expr(self, type_, expand_op)``, which takes
+    precedence over the existing
+    ``def visit_empty_set_expr(self, element_types)`` which remains in place.
+
index 59d40fef0defca8e3c1af75665f2c5f8ffb23d68..66a556ae0d5cb3497275159bd993fff37e599663 100644 (file)
@@ -1299,6 +1299,11 @@ class SQLiteCompiler(compiler.SQLCompiler):
             self.process(binary.right, **kw),
         )
 
+    def visit_empty_set_op_expr(self, type_, expand_op):
+        # slightly old SQLite versions don't seem to be able to handle
+        # the empty set impl
+        return self.visit_empty_set_expr(type_)
+
     def visit_empty_set_expr(self, element_types):
         return "SELECT %s FROM (SELECT %s) WHERE 1!=1" % (
             ", ".join("1" for type_ in element_types or [INTEGER()]),
index b7aba9d7476eda04fb1841f8496812cecb65505a..820fc1bf19f21ebe433dc58fe4c108b5d2650253 100644 (file)
@@ -561,14 +561,9 @@ class InElementImpl(RoleImpl):
             return element.self_group(against=operator)
 
         elif isinstance(element, elements.BindParameter):
-            if not element.expanding:
-                # coercing to expanding at the moment to work with the
-                # lambda system.  not sure if this is the right approach.
-                # is there a valid use case to send a single non-expanding
-                # param to IN? check for ARRAY type?
-                element = element._clone(maintain_key=True)
-                element.expanding = True
-
+            # previously we were adding expanding flags here but
+            # we now do this in the compiler where we have more context
+            # see compiler.py -> _render_in_expr_w_bindparam
             return element
         else:
             return element
index 6168248ff7918fc3f1575d05a0cdaf483991ce26..e9e05b7e9a5439fc7c15153f4172bcc1321fe75a 100644 (file)
@@ -1903,6 +1903,45 @@ class SQLCompiler(Compiled):
             binary, override_operator=operators.match_op
         )
 
+    def visit_in_op_binary(self, binary, operator, **kw):
+        return self._render_in_expr_w_bindparam(binary, operator, **kw)
+
+    def visit_not_in_op_binary(self, binary, operator, **kw):
+        return self._render_in_expr_w_bindparam(binary, operator, **kw)
+
+    def _render_in_expr_w_bindparam(self, binary, operator, **kw):
+        opstring = OPERATORS[operator]
+
+        if isinstance(binary.right, elements.BindParameter):
+            if not binary.right.expanding or not binary.right.expand_op:
+                # note that by cloning here, we rely upon the
+                # _cache_key_bind_match dictionary to resolve
+                # clones of bindparam() objects to the ones that are
+                # present in our cache key.
+                binary.right = binary.right._clone(maintain_key=True)
+                binary.right.expanding = True
+                binary.right.expand_op = operator
+
+        return self._generate_generic_binary(binary, opstring, **kw)
+
+    def visit_empty_set_op_expr(self, type_, expand_op):
+        if expand_op is operators.not_in_op:
+            if len(type_) > 1:
+                return "(%s)) OR (1 = 1" % (
+                    ", ".join("NULL" for element in type_)
+                )
+            else:
+                return "NULL) OR (1 = 1"
+        elif expand_op is operators.in_op:
+            if len(type_) > 1:
+                return "(%s)) AND (1 != 1" % (
+                    ", ".join("NULL" for element in type_)
+                )
+            else:
+                return "NULL) AND (1 != 1"
+        else:
+            return self.visit_empty_set_expr(type_)
+
     def visit_empty_set_expr(self, element_types):
         raise NotImplementedError(
             "Dialect '%s' does not support empty set expression."
@@ -1959,12 +1998,12 @@ class SQLCompiler(Compiled):
             to_update = []
             if parameter.type._is_tuple_type:
 
-                replacement_expression = self.visit_empty_set_expr(
-                    parameter.type.types
+                replacement_expression = self.visit_empty_set_op_expr(
+                    parameter.type.types, parameter.expand_op
                 )
             else:
-                replacement_expression = self.visit_empty_set_expr(
-                    [parameter.type]
+                replacement_expression = self.visit_empty_set_op_expr(
+                    [parameter.type], parameter.expand_op
                 )
 
         elif isinstance(values[0], (tuple, list)):
@@ -3900,6 +3939,9 @@ class StrSQLCompiler(SQLCompiler):
             for t in extra_froms
         )
 
+    def visit_empty_set_op_expr(self, type_, expand_op):
+        return self.visit_empty_set_expr(type_)
+
     def visit_empty_set_expr(self, type_):
         return "SELECT 1 WHERE 1!=1"
 
index 696f3b249246000af25e4c81d55c316c5c68cd84..e27b978021557bc67d3a4ed46317a2a18c9632cd 100644 (file)
@@ -1411,7 +1411,17 @@ class BindParameter(roles.InElementRole, ColumnElement):
         self.callable = callable_
         self.isoutparam = isoutparam
         self.required = required
+
+        # indicate an "expanding" parameter; the compiler sets this
+        # automatically in the compiler _render_in_expr_w_bindparam method
+        # for an IN expression
         self.expanding = expanding
+
+        # this is another hint to help w/ expanding and is typically
+        # set in the compiler _render_in_expr_w_bindparam method for an
+        # IN expression
+        self.expand_op = None
+
         self.literal_execute = literal_execute
         if _is_crud:
             self._is_crud = True
index 7b35dc3fa360a9b82df7415af8a0e05f7a8db7c9..1614acd3da3f6bf5407e0587b9994e5bdbc5c349 100644 (file)
@@ -1016,163 +1016,246 @@ class ExpandingBoundInTest(fixtures.TablesTest):
         with config.db.connect() as conn:
             eq_(conn.execute(select, params).fetchall(), result)
 
-    def test_multiple_empty_sets(self):
+    def test_multiple_empty_sets_bindparam(self):
         # test that any anonymous aliasing used by the dialect
         # is fine with duplicates
         table = self.tables.some_table
-
         stmt = (
             select(table.c.id)
-            .where(table.c.x.in_(bindparam("q", expanding=True)))
-            .where(table.c.y.in_(bindparam("p", expanding=True)))
+            .where(table.c.x.in_(bindparam("q")))
+            .where(table.c.y.in_(bindparam("p")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [], params={"q": [], "p": []})
 
-    @testing.requires.tuple_in_w_empty
-    def test_empty_heterogeneous_tuples(self):
+    def test_multiple_empty_sets_direct(self):
+        # test that any anonymous aliasing used by the dialect
+        # is fine with duplicates
         table = self.tables.some_table
-
         stmt = (
             select(table.c.id)
-            .where(
-                tuple_(table.c.x, table.c.z).in_(
-                    bindparam("q", expanding=True)
-                )
-            )
+            .where(table.c.x.in_([]))
+            .where(table.c.y.in_([]))
             .order_by(table.c.id)
         )
+        self._assert_result(stmt, [])
 
+    @testing.requires.tuple_in_w_empty
+    def test_empty_heterogeneous_tuples_bindparam(self):
+        table = self.tables.some_table
+        stmt = (
+            select(table.c.id)
+            .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+            .order_by(table.c.id)
+        )
         self._assert_result(stmt, [], params={"q": []})
 
     @testing.requires.tuple_in_w_empty
-    def test_empty_homogeneous_tuples(self):
+    def test_empty_heterogeneous_tuples_direct(self):
         table = self.tables.some_table
 
+        def go(val, expected):
+            stmt = (
+                select(table.c.id)
+                .where(tuple_(table.c.x, table.c.z).in_(val))
+                .order_by(table.c.id)
+            )
+            self._assert_result(stmt, expected)
+
+        go([], [])
+        go([(2, "z2"), (3, "z3"), (4, "z4")], [(2,), (3,), (4,)])
+        go([], [])
+
+    @testing.requires.tuple_in_w_empty
+    def test_empty_homogeneous_tuples_bindparam(self):
+        table = self.tables.some_table
         stmt = (
             select(table.c.id)
-            .where(
-                tuple_(table.c.x, table.c.y).in_(
-                    bindparam("q", expanding=True)
-                )
-            )
+            .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [], params={"q": []})
 
-    def test_bound_in_scalar(self):
+    @testing.requires.tuple_in_w_empty
+    def test_empty_homogeneous_tuples_direct(self):
         table = self.tables.some_table
 
+        def go(val, expected):
+            stmt = (
+                select(table.c.id)
+                .where(tuple_(table.c.x, table.c.y).in_(val))
+                .order_by(table.c.id)
+            )
+            self._assert_result(stmt, expected)
+
+        go([], [])
+        go([(1, 2), (2, 3), (3, 4)], [(1,), (2,), (3,)])
+        go([], [])
+
+    def test_bound_in_scalar_bindparam(self):
+        table = self.tables.some_table
         stmt = (
             select(table.c.id)
-            .where(table.c.x.in_(bindparam("q", expanding=True)))
+            .where(table.c.x.in_(bindparam("q")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [(2,), (3,), (4,)], params={"q": [2, 3, 4]})
 
-    @testing.requires.tuple_in
-    def test_bound_in_two_tuple(self):
+    def test_bound_in_scalar_direct(self):
         table = self.tables.some_table
-
         stmt = (
             select(table.c.id)
-            .where(
-                tuple_(table.c.x, table.c.y).in_(
-                    bindparam("q", expanding=True)
-                )
-            )
+            .where(table.c.x.in_([2, 3, 4]))
             .order_by(table.c.id)
         )
+        self._assert_result(stmt, [(2,), (3,), (4,)])
 
+    @testing.requires.tuple_in
+    def test_bound_in_two_tuple_bindparam(self):
+        table = self.tables.some_table
+        stmt = (
+            select(table.c.id)
+            .where(tuple_(table.c.x, table.c.y).in_(bindparam("q")))
+            .order_by(table.c.id)
+        )
         self._assert_result(
             stmt, [(2,), (3,), (4,)], params={"q": [(2, 3), (3, 4), (4, 5)]}
         )
 
     @testing.requires.tuple_in
-    def test_bound_in_heterogeneous_two_tuple(self):
+    def test_bound_in_two_tuple_direct(self):
+        table = self.tables.some_table
+        stmt = (
+            select(table.c.id)
+            .where(tuple_(table.c.x, table.c.y).in_([(2, 3), (3, 4), (4, 5)]))
+            .order_by(table.c.id)
+        )
+        self._assert_result(stmt, [(2,), (3,), (4,)])
+
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple_bindparam(self):
         table = self.tables.some_table
+        stmt = (
+            select(table.c.id)
+            .where(tuple_(table.c.x, table.c.z).in_(bindparam("q")))
+            .order_by(table.c.id)
+        )
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+        )
 
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple_direct(self):
+        table = self.tables.some_table
         stmt = (
             select(table.c.id)
             .where(
                 tuple_(table.c.x, table.c.z).in_(
-                    bindparam("q", expanding=True)
+                    [(2, "z2"), (3, "z3"), (4, "z4")]
                 )
             )
             .order_by(table.c.id)
         )
-
         self._assert_result(
             stmt,
             [(2,), (3,), (4,)],
-            params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
         )
 
     @testing.requires.tuple_in
-    def test_bound_in_heterogeneous_two_tuple_text(self):
+    def test_bound_in_heterogeneous_two_tuple_text_bindparam(self):
+        # note this becomes ARRAY if we dont use expanding
+        # explicitly right now
         stmt = text(
             "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
         ).bindparams(bindparam("q", expanding=True))
-
         self._assert_result(
             stmt,
             [(2,), (3,), (4,)],
             params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
         )
 
-    def test_empty_set_against_integer(self):
+    def test_empty_set_against_integer_bindparam(self):
         table = self.tables.some_table
-
         stmt = (
             select(table.c.id)
-            .where(table.c.x.in_(bindparam("q", expanding=True)))
+            .where(table.c.x.in_(bindparam("q")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [], params={"q": []})
 
-    def test_empty_set_against_integer_negation(self):
+    def test_empty_set_against_integer_direct(self):
         table = self.tables.some_table
+        stmt = select(table.c.id).where(table.c.x.in_([])).order_by(table.c.id)
+        self._assert_result(stmt, [])
 
+    def test_empty_set_against_integer_negation_bindparam(self):
+        table = self.tables.some_table
         stmt = (
             select(table.c.id)
-            .where(table.c.x.not_in(bindparam("q", expanding=True)))
+            .where(table.c.x.not_in(bindparam("q")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
 
-    def test_empty_set_against_string(self):
+    def test_empty_set_against_integer_negation_direct(self):
         table = self.tables.some_table
+        stmt = (
+            select(table.c.id).where(table.c.x.not_in([])).order_by(table.c.id)
+        )
+        self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
 
+    def test_empty_set_against_string_bindparam(self):
+        table = self.tables.some_table
         stmt = (
             select(table.c.id)
-            .where(table.c.z.in_(bindparam("q", expanding=True)))
+            .where(table.c.z.in_(bindparam("q")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [], params={"q": []})
 
-    def test_empty_set_against_string_negation(self):
+    def test_empty_set_against_string_direct(self):
         table = self.tables.some_table
+        stmt = select(table.c.id).where(table.c.z.in_([])).order_by(table.c.id)
+        self._assert_result(stmt, [])
 
+    def test_empty_set_against_string_negation_bindparam(self):
+        table = self.tables.some_table
         stmt = (
             select(table.c.id)
-            .where(table.c.z.not_in(bindparam("q", expanding=True)))
+            .where(table.c.z.not_in(bindparam("q")))
             .order_by(table.c.id)
         )
-
         self._assert_result(stmt, [(1,), (2,), (3,), (4,)], params={"q": []})
 
-    def test_null_in_empty_set_is_false(self, connection):
+    def test_empty_set_against_string_negation_direct(self):
+        table = self.tables.some_table
+        stmt = (
+            select(table.c.id).where(table.c.z.not_in([])).order_by(table.c.id)
+        )
+        self._assert_result(stmt, [(1,), (2,), (3,), (4,)])
+
+    def test_null_in_empty_set_is_false_bindparam(self, connection):
+        stmt = select(
+            case(
+                [
+                    (
+                        null().in_(bindparam("foo", value=())),
+                        true(),
+                    )
+                ],
+                else_=false(),
+            )
+        )
+        in_(connection.execute(stmt).fetchone()[0], (False, 0))
+
+    def test_null_in_empty_set_is_false_direct(self, connection):
         stmt = select(
             case(
                 [
                     (
-                        null().in_(bindparam("foo", value=(), expanding=True)),
+                        null().in_([]),
                         true(),
                     )
                 ],
index 24a83c9ee3d2a920a71ba6868496a7aa51e9fb30..897c60f0039a91afa6c3dc14585177e292ce192e 100644 (file)
@@ -116,6 +116,46 @@ class LambdaElementTest(
         result = go()
         eq_(result.all(), [(2,)])
 
+    def test_in_expressions(self, user_address_fixture, connection):
+        """test #6397.   we initially were going to use two different
+        forms for "empty in" vs. regular "in", but instead we have an
+        improved substitution for "empty in".  regardless, as there's more
+        going on with these, make sure lambdas work with them including
+        caching.
+
+        """
+        users, _ = user_address_fixture
+        data = [
+            {"id": 1, "name": "u1"},
+            {"id": 2, "name": "u2"},
+            {"id": 3, "name": "u3"},
+        ]
+        connection.execute(users.insert(), data)
+
+        def go(val):
+            stmt = lambdas.lambda_stmt(lambda: select(users.c.id))
+            stmt += lambda s: s.where(users.c.name.in_(val))
+            stmt += lambda s: s.order_by(users.c.id)
+            return connection.execute(stmt)
+
+        for case in [
+            [],
+            ["u1", "u2"],
+            ["u3"],
+            [],
+            ["u1", "u2"],
+        ]:
+            with testing.assertsql.assert_engine(testing.db) as asserter_:
+                result = go(case)
+            asserter_.assert_(
+                CompiledSQL(
+                    "SELECT users.id FROM users WHERE users.name "
+                    "IN ([POSTCOMPILE_val_1]) ORDER BY users.id",
+                    params={"val_1": case},
+                )
+            )
+            eq_(result.all(), [(e["id"],) for e in data if e["name"] in case])
+
     def test_stale_checker_embedded(self):
         def go(x):