]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Infer types in BindParameter when expanding=True
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Apr 2021 01:43:17 +0000 (21:43 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 8 Apr 2021 14:29:08 +0000 (10:29 -0400)
Enhanced the "expanding" feature used for :meth:`_sql.ColumnOperators.in_`
operations to infer the type of expression from the right hand list of
elements, if the left hand side does not have any explicit type set up.
This allows the expression to support stringification among other things.
In 1.3, "expanding" was not automatically used for
:meth:`_sql.ColumnOperators.in_` expressions, so in that sense this change
fixes a behavioral regression.

Fixes: #6222
Change-Id: Icdfda1e2c226a21896cafd6d8f251547794451c2

doc/build/changelog/unreleased_14/6222.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py
test/sql/test_compare.py
test/sql/test_operators.py
test/sql/test_selectable.py

diff --git a/doc/build/changelog/unreleased_14/6222.rst b/doc/build/changelog/unreleased_14/6222.rst
new file mode 100644 (file)
index 0000000..7464b09
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 6222
+
+    Enhanced the "expanding" feature used for :meth:`_sql.ColumnOperators.in_`
+    operations to infer the type of expression from the right hand list of
+    elements, if the left hand side does not have any explicit type set up.
+    This allows the expression to support stringification among other things.
+    In 1.3, "expanding" was not automatically used for
+    :meth:`_sql.ColumnOperators.in_` expressions, so in that sense this change
+    fixes a behavioral regression.
+
index dfa6c0f8f41cd9922b2720869a176777d6853572..9e1b690888abd744de55edd45ddf734793eb10c1 100644 (file)
@@ -1392,15 +1392,26 @@ class BindParameter(roles.InElementRole, ColumnElement):
         self.literal_execute = literal_execute
         if _is_crud:
             self._is_crud = True
+
         if type_ is None:
+            if expanding and value:
+                check_value = value[0]
+            else:
+                check_value = value
             if _compared_to_type is not None:
                 self.type = _compared_to_type.coerce_compared_value(
-                    _compared_to_operator, value
+                    _compared_to_operator, check_value
                 )
             else:
-                self.type = type_api._resolve_value_to_type(value)
+                self.type = type_api._resolve_value_to_type(check_value)
         elif isinstance(type_, type):
             self.type = type_()
+        elif type_._is_tuple_type:
+            if expanding and value:
+                check_value = value[0]
+            else:
+                check_value = value
+            self.type = type_._resolve_values_to_types(check_value)
         else:
             self.type = type_
 
index e57a14681d77d4915ae4577fd3fb40df4cf0a0b4..bdfb7d8332f211cbd7cc378a8c130c38aa9cd99d 100644 (file)
@@ -2912,8 +2912,20 @@ class TupleType(TypeEngine):
     _is_tuple_type = True
 
     def __init__(self, *types):
+        self._fully_typed = NULLTYPE not in types
         self.types = types
 
+    def _resolve_values_to_types(self, value):
+        if self._fully_typed:
+            return self
+        else:
+            return TupleType(
+                *[
+                    _resolve_value_to_type(elem) if typ is NULLTYPE else typ
+                    for typ, elem in zip(self.types, value)
+                ]
+            )
+
     def result_processor(self, dialect, coltype):
         raise NotImplementedError(
             "The tuple type does not support being fetched "
index 02dd0661a5a3737a7db147c5685defcd990e5449..21a349d76faa2416616261557e2b8c632cdf921f 100644 (file)
@@ -692,12 +692,32 @@ class CoreFixtures(object):
             column("y").in_(
                 bindparam(
                     "q",
+                    # note that a different cache key is created if
+                    # the value given to the bindparam is [], as the type
+                    # cannot be inferred for the empty list but can
+                    # for the non-empty list as of #6222
+                    random_choices(range(10), k=random.randint(1, 7)),
+                    expanding=True,
+                )
+            ),
+            column("y2").in_(
+                bindparam(
+                    "q",
+                    # for typed param, empty and not empty param will have
+                    # the same type
                     random_choices(range(10), k=random.randint(0, 7)),
+                    type_=Integer,
                     expanding=True,
                 )
             ),
-            column("z").in_(random_choices(range(10), k=random.randint(0, 7))),
-            column("x") == random.randint(1, 10),
+            # don't include empty for untyped, will create different cache
+            # key
+            column("z").in_(random_choices(range(10), k=random.randint(1, 7))),
+            # empty is fine for typed, will create the same cache key
+            column("z2", Integer).in_(
+                random_choices(range(10), k=random.randint(0, 7))
+            ),
+            column("x") == random.randint(0, 10),
         )
     ]
 
index 0d7f331e0e2df9caa6100919128533ce455222a9..878360b9d8da98f85289745dd2f93b9f1f422f14 100644 (file)
@@ -1989,6 +1989,23 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL):
             "SELECT t.a FROM t WHERE t.a IN (SELECT scan)",
         )
 
+    def test_type_inference_one(self):
+        expr = column("q").in_([1, 2, 3])
+        is_(expr.right.type._type_affinity, Integer)
+
+        self.assert_compile(expr, "q IN (1, 2, 3)", literal_binds=True)
+
+    def test_type_inference_two(self):
+        expr = column("q").in_([])
+        is_(expr.right.type, sqltypes.NULLTYPE)
+
+        self.assert_compile(
+            expr,
+            "q IN (SELECT 1 WHERE 1!=1)",
+            literal_binds=True,
+            dialect="default_enhanced",
+        )
+
 
 class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
@@ -3075,6 +3092,26 @@ class TupleTypingTest(fixtures.TestBase):
 
         self._assert_types(expr.right.type.types)
 
+    # since we want to infer "binary"
+    @testing.requires.python3
+    def test_tuple_type_expanding_inference(self):
+        a, b, c = column("a"), column("b"), column("c")
+
+        t1 = tuple_(a, b, c)
+        expr = t1.in_([(3, "hi", b"there"), (4, "Q", b"P")])
+
+        eq_(len(expr.right.value), 2)
+
+        self._assert_types(expr.right.type.types)
+
+    @testing.requires.python3
+    def test_tuple_type_plain_inference(self):
+        a, b, c = column("a"), column("b"), column("c")
+
+        t1 = tuple_(a, b, c)
+        expr = t1 == (3, "hi", b"there")
+        self._assert_types(expr.right.type.types)
+
 
 class InSelectableTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
index b98487933cd346bdef1cf10075611a6f0556df29..291c3f36e4699702b8d3d7d821898933adff9c14 100644 (file)
@@ -3013,7 +3013,7 @@ class AnnotationsTest(fixtures.TestBase):
 class ReprTest(fixtures.TestBase):
     def test_ensure_repr_elements(self):
         for obj in [
-            elements.Cast(1, 2),
+            elements.Cast(1, Integer()),
             elements.TypeClause(String()),
             elements.ColumnClause("x"),
             elements.BindParameter("q"),