From: mike bayer Date: Wed, 15 Feb 2023 22:20:06 +0000 (+0100) Subject: Fix coercion issue for tuple bindparams X-Git-Tag: rel_2_0_4~5^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=693f7f7a84ac77eaacc9ff9c8035a249d7f1ce7e;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Fix coercion issue for tuple bindparams Fixed issue where element types of a tuple value would be hardcoded to take on the types from a compared-to tuple, when the comparison were using the :meth:`.ColumnOperators.in_` operator. This was inconsistent with the usual way that types are determined for a binary expression, which is that the actual element type on the right side is considered first before applying the left-hand-side type. Fixes: #9313 Change-Id: Ia8874c09682a6512fcf4084cf14481024959c461 --- diff --git a/doc/build/changelog/unreleased_20/9313.rst b/doc/build/changelog/unreleased_20/9313.rst new file mode 100644 index 0000000000..78f4f12e63 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9313.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 9313 + + Fixed issue where element types of a tuple value would be hardcoded to take + on the types from a compared-to tuple, when the comparison were using the + :meth:`.ColumnOperators.in_` operator. This was inconsistent with the usual + way that types are determined for a binary expression, which is that the + actual element type on the right side is considered first before applying + the left-hand-side type. diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 4c2c7de3c4..e51b755ddb 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1976,8 +1976,11 @@ class BindParameter(roles.InElementRole, KeyedColumnElement[_T]): self._is_crud = True if type_ is None: - if expanding and value: - check_value = value[0] + if expanding: + if value: + check_value = value[0] + else: + check_value = type_api._NO_VALUE_IN_LIST else: check_value = value if _compared_to_type is not None: @@ -3166,7 +3169,8 @@ class Tuple(ClauseList, ColumnElement[typing_Tuple[Any, ...]]): _compared_to_operator=operator, unique=True, expanding=True, - type_=self.type, + type_=type_, + _compared_to_type=self.type, ) else: return Tuple( diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index b2dcc9b8a2..3c6cb0cb55 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -3157,6 +3157,20 @@ class TupleType(TypeEngine[Tuple[Any, ...]]): for item_type in types ] + def coerce_compared_value( + self, op: Optional[OperatorType], value: Any + ) -> TypeEngine[Any]: + + if value is type_api._NO_VALUE_IN_LIST: + return super().coerce_compared_value(op, value) + else: + return TupleType( + *[ + typ.coerce_compared_value(op, elem) + for typ, elem in zip(self.types, value) + ] + ) + def _resolve_values_to_types(self, value: Any) -> TupleType: if self._fully_typed: return self diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index af7ed21c48..7167430f14 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -11,6 +11,7 @@ from __future__ import annotations +from enum import Enum from types import ModuleType import typing from typing import Any @@ -68,7 +69,14 @@ _CT = TypeVar("_CT", bound=Any) _MatchedOnType = Union["GenericProtocol[Any]", NewType, Type[Any]] -# replace with pep-673 when applicable + +class _NoValueInList(Enum): + NO_VALUE_IN_LIST = 0 + """indicates we are trying to determine the type of an expression + against an empty list.""" + + +_NO_VALUE_IN_LIST = _NoValueInList.NO_VALUE_IN_LIST class _LiteralProcessorType(Protocol[_T_co]): diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index e7e51aa635..8ed8c7d33c 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -4303,6 +4303,14 @@ class TupleTypingTest(fixtures.TestBase): expr = t1 == (3, "hi", b"there") self._assert_types(expr.right.type.types) + def test_tuple_type_left_type_ignored(self): + a, b = column("a", sqltypes.Date), column("b", sqltypes.DateTime) + c = column("c", sqltypes.Float) + + t1 = tuple_(a, b, c) + expr = t1.in_([(3, "hi", b"there")]) + self._assert_types(expr.right.type.types) + class InSelectableTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default"