]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Fix coercion issue for tuple bindparams
authormike bayer <mike_mp@zzzcomputing.com>
Wed, 15 Feb 2023 22:20:06 +0000 (23:20 +0100)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 16 Feb 2023 02:47:45 +0000 (21:47 -0500)
Fixed issue where element types of a tuple value would be hardcoded to take
on the types from a compared-to tuple, when the comparison were using the
:meth:`.ColumnOperators.in_` operator. This was inconsistent with the usual
way that types are determined for a binary expression, which is that the
actual element type on the right side is considered first before applying
the left-hand-side type.

Fixes: #9313
Change-Id: Ia8874c09682a6512fcf4084cf14481024959c461

doc/build/changelog/unreleased_20/9313.rst [new file with mode: 0644]
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_20/9313.rst b/doc/build/changelog/unreleased_20/9313.rst
new file mode 100644 (file)
index 0000000..78f4f12
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 9313
+
+    Fixed issue where element types of a tuple value would be hardcoded to take
+    on the types from a compared-to tuple, when the comparison were using the
+    :meth:`.ColumnOperators.in_` operator. This was inconsistent with the usual
+    way that types are determined for a binary expression, which is that the
+    actual element type on the right side is considered first before applying
+    the left-hand-side type.
index 4c2c7de3c4955f3c549755b6414e0eee868ebdf1..e51b755ddbb82bc67f8e6aaaa136d4206edfcd92 100644 (file)
@@ -1976,8 +1976,11 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]):
             self._is_crud = True
 
         if type_ is None:
-            if expanding and value:
-                check_value = value[0]
+            if expanding:
+                if value:
+                    check_value = value[0]
+                else:
+                    check_value = type_api._NO_VALUE_IN_LIST
             else:
                 check_value = value
             if _compared_to_type is not None:
@@ -3166,7 +3169,8 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]):
                 _compared_to_operator=operator,
                 unique=True,
                 expanding=True,
-                type_=self.type,
+                type_=type_,
+                _compared_to_type=self.type,
             )
         else:
             return Tuple(
index b2dcc9b8a2570c20161d855f4194c83285fc367f..3c6cb0cb558740c6308d01f044d7f074f9352fa4 100644 (file)
@@ -3157,6 +3157,20 @@ class TupleType(TypeEngine[Tuple[Any, ...]]):
             for item_type in types
         ]
 
+    def coerce_compared_value(
+        self, op: Optional[OperatorType], value: Any
+    ) -> TypeEngine[Any]:
+
+        if value is type_api._NO_VALUE_IN_LIST:
+            return super().coerce_compared_value(op, value)
+        else:
+            return TupleType(
+                *[
+                    typ.coerce_compared_value(op, elem)
+                    for typ, elem in zip(self.types, value)
+                ]
+            )
+
     def _resolve_values_to_types(self, value: Any) -> TupleType:
         if self._fully_typed:
             return self
index af7ed21c48c6b1089fbc34584d4b89428feb031e..7167430f14d6b81aec11576689be56c6adf12e08 100644 (file)
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+from enum import Enum
 from types import ModuleType
 import typing
 from typing import Any
@@ -68,7 +69,14 @@ _CT = TypeVar("_CT", bound=Any)
 
 _MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]]
 
-# replace with pep-673 when applicable
+
+class _NoValueInList(Enum):
+    NO_VALUE_IN_LIST = 0
+    """indicates we are trying to determine the type of an expression
+    against an empty list."""
+
+
+_NO_VALUE_IN_LIST = _NoValueInList.NO_VALUE_IN_LIST
 
 
 class _LiteralProcessorType(Protocol[_T_co]):
index e7e51aa635dec4bcf5220d71ab80366325264550..8ed8c7d33cad0408e3f414799eef29c976e1db6b 100644 (file)
@@ -4303,6 +4303,14 @@ class TupleTypingTest(fixtures.TestBase):
         expr = t1 == (3, "hi", b"there")
         self._assert_types(expr.right.type.types)
 
+    def test_tuple_type_left_type_ignored(self):
+        a, b = column("a", sqltypes.Date), column("b", sqltypes.DateTime)
+        c = column("c", sqltypes.Float)
+
+        t1 = tuple_(a, b, c)
+        expr = t1.in_([(3, "hi", b"there")])
+        self._assert_types(expr.right.type.types)
+
 
 class InSelectableTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"