]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support tuples of heterogeneous types for empty expanding IN
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Oct 2018 14:40:38 +0000 (10:40 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 3 Oct 2018 14:40:38 +0000 (10:40 -0400)
Pass a list of all the types for the left side of an
IN expression to the visit_empty_set_expr() method, so that
the "empty expanding IN" can produce clauses for each element.

Fixes: #4271
Change-Id: I2738b9df2292ac01afda37f16d4fa56ae7bf9147

doc/build/changelog/migration_13.rst
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/testing/suite/test_select.py

index 5a8e3ce05bc75e34d3bd4189a87be3b2d3e0d430..40363d426889965200991d29ced1ff8b8013de73 100644 (file)
@@ -413,6 +413,24 @@ backend, such as "SELECT CAST(NULL AS INTEGER) WHERE 1!=1" for Postgresql,
     ...
     SELECT 1 WHERE 1 IN (SELECT CAST(NULL AS INTEGER) WHERE 1!=1)
 
+The feature also works for tuple-oriented IN statements, where the "empty IN"
+expression will be expanded to support the elements given inside the tuple,
+such as on Postgresql::
+
+    >>> from sqlalchemy import create_engine
+    >>> from sqlalchemy import select, literal_column, tuple_, bindparam
+    >>> e = create_engine("postgresql://scott:tiger@localhost/test", echo=True)
+    >>> with e.connect() as conn:
+    ...      conn.execute(
+    ...          select([literal_column('1')]).
+    ...          where(tuple_(50, "somestring").in_(bindparam('q', expanding=True))),
+    ...          q=[]
+    ...      )
+    ...
+    SELECT 1 WHERE (%(param_1)s, %(param_2)s)
+    IN (SELECT CAST(NULL AS INTEGER), CAST(NULL AS VARCHAR) WHERE 1!=1)
+
+
 :ticket:`4271`
 
 .. _change_3981:
index 45e8c251044800f08b8c2d5679e1791824b4c4f5..43966d1dc880a62a630104c82c085f652bd01fd1 100644 (file)
@@ -1179,8 +1179,18 @@ class MySQLCompiler(compiler.SQLCompiler):
                                  fromhints=from_hints, **kw)
             for t in [from_table] + extra_froms)
 
-    def visit_empty_set_expr(self, type_):
-        return 'SELECT 1 FROM (SELECT 1) as _empty_set WHERE 1!=1'
+    def visit_empty_set_expr(self, element_types):
+        return (
+            "SELECT %(outer)s FROM (SELECT %(inner)s) "
+            "as _empty_set WHERE 1!=1" % {
+                "inner": ", ".join(
+                    "1 AS _in_%s" % idx
+                    for idx, type_ in enumerate(element_types)),
+                "outer": ", ".join(
+                    "_in_%s" % idx
+                    for idx, type_ in enumerate(element_types))
+            }
+        )
 
 
 class MySQLDDLCompiler(compiler.DDLCompiler):
index 11fcc41d5988a11ef8fef8ac2e6b8c7b76a1f722..5251a000daaab963a95aca11daf30f958313b85f 100644 (file)
@@ -1485,14 +1485,17 @@ class PGCompiler(compiler.SQLCompiler):
                 if escape else ''
             )
 
-    def visit_empty_set_expr(self, type_, **kw):
+    def visit_empty_set_expr(self, element_types):
         # cast the empty set to the type we are comparing against.  if
         # we are comparing against the null type, pick an arbitrary
         # datatype for the empty set
-        if type_._isnull:
-            type_ = INTEGER()
-        return 'SELECT CAST(NULL AS %s) WHERE 1!=1' % \
-               self.dialect.type_compiler.process(type_, **kw)
+        return 'SELECT %s WHERE 1!=1' % (
+            ", ".join(
+                "CAST(NULL AS %s)" % self.dialect.type_compiler.process(
+                    INTEGER() if type_._isnull else type_,
+                ) for type_ in element_types or [INTEGER()]
+            ),
+        )
 
     def render_literal_value(self, value, type_):
         value = super(PGCompiler, self).render_literal_value(value, type_)
index f48217a4e9cbbda27b4e4a894f4d920a58b5517c..5c96e4240e1ead003db6bf7d344449598e8c3720 100644 (file)
@@ -737,7 +737,10 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
                     to_update = []
                     replacement_expressions[name] = (
                         self.compiled.visit_empty_set_expr(
-                            type_=parameter.type)
+                            parameter._expanding_in_types
+                            if parameter._expanding_in_types
+                            else [parameter.type]
+                        )
                     )
 
                 elif isinstance(values[0], (tuple, list)):
index 2f68b7e2e46c65fbc4ce193fc47d97ff5d055def..27ee4afc68977d1449875a4e8e5e0a7e74308b44 100644 (file)
@@ -1056,7 +1056,7 @@ class SQLCompiler(Compiled):
                 self._emit_empty_in_warning()
             return self.process(binary.left == binary.left)
 
-    def visit_empty_set_expr(self, type_):
+    def visit_empty_set_expr(self, element_types):
         raise NotImplementedError(
             "Dialect '%s' does not support empty set expression." %
             self.dialect.name
index 5d02f65a15a35f3f5f2205ff5270eaa60e785026..8149f9731df99339d9460566e59f732face067bd 100644 (file)
@@ -15,7 +15,8 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \
     Null, _const_expr, _clause_element_as_expr, \
     ClauseList, ColumnElement, TextClause, UnaryExpression, \
     collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
-    Slice, Visitable, _literal_as_binds, CollectionAggregate
+    Slice, Visitable, _literal_as_binds, CollectionAggregate, \
+    Tuple
 from .selectable import SelectBase, Alias, Selectable, ScalarSelect
 
 
@@ -145,6 +146,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
     elif isinstance(seq_or_selectable, ClauseElement):
         if isinstance(seq_or_selectable, BindParameter) and \
                 seq_or_selectable.expanding:
+
+            if isinstance(expr, Tuple):
+                seq_or_selectable = (
+                    seq_or_selectable._with_expanding_in_types(
+                        [elem.type for elem in expr]
+                    )
+                )
+
             return _boolean_compare(
                 expr, op,
                 seq_or_selectable,
index dd16b68628dde3e533798a40dce74544ec76ebad..de3b7992af9c1084e0d82e18320ddc518bde46e5 100644 (file)
@@ -865,6 +865,7 @@ class BindParameter(ColumnElement):
     __visit_name__ = 'bindparam'
 
     _is_crud = False
+    _expanding_in_types = ()
 
     def __init__(self, key, value=NO_ARG, type_=None,
                  unique=False, required=NO_ARG,
@@ -1134,6 +1135,15 @@ class BindParameter(ColumnElement):
         else:
             self.type = type_
 
+    def _with_expanding_in_types(self, types):
+        """Return a copy of this :class:`.BindParameter` in
+        the context of an expanding IN against a tuple.
+
+        """
+        cloned = self._clone()
+        cloned._expanding_in_types = types
+        return cloned
+
     def _with_value(self, value):
         """Return a copy of this :class:`.BindParameter` with the given value
         set.
index 78b34f496278f6428a9c4f60931f96adb3dc0bc8..73ce02492f5d8a8cadd6323b50f6790bc6b7c542 100644 (file)
@@ -402,6 +402,34 @@ class ExpandingBoundInTest(fixtures.TablesTest):
             params={"q": [], "p": []},
         )
 
+    @testing.requires.tuple_in
+    def test_empty_heterogeneous_tuples(self):
+        table = self.tables.some_table
+
+        stmt = select([table.c.id]).where(
+            tuple_(table.c.x, table.c.z).in_(
+                bindparam('q', expanding=True))).order_by(table.c.id)
+
+        self._assert_result(
+            stmt,
+            [],
+            params={"q": []},
+        )
+
+    @testing.requires.tuple_in
+    def test_empty_homogeneous_tuples(self):
+        table = self.tables.some_table
+
+        stmt = select([table.c.id]).where(
+            tuple_(table.c.x, table.c.y).in_(
+                bindparam('q', expanding=True))).order_by(table.c.id)
+
+        self._assert_result(
+            stmt,
+            [],
+            params={"q": []},
+        )
+
     def test_bound_in_scalar(self):
         table = self.tables.some_table
 
@@ -428,6 +456,20 @@ class ExpandingBoundInTest(fixtures.TablesTest):
             params={"q": [(2, 3), (3, 4), (4, 5)]},
         )
 
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple(self):
+        table = self.tables.some_table
+
+        stmt = select([table.c.id]).where(
+            tuple_(table.c.x, table.c.z).in_(
+                bindparam('q', expanding=True))).order_by(table.c.id)
+
+        self._assert_result(
+            stmt,
+            [(2, ), (3, ), (4, )],
+            params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
+        )
+
     def test_empty_set_against_integer(self):
         table = self.tables.some_table