From: Mike Bayer Date: Thu, 8 Apr 2021 01:43:17 +0000 (-0400) Subject: Infer types in BindParameter when expanding=True X-Git-Tag: rel_1_4_7~11^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=df078a6fb010e28cb14afa1f0947add1f60e0e52;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Infer types in BindParameter when expanding=True Enhanced the "expanding" feature used for :meth:`_sql.ColumnOperators.in_` operations to infer the type of expression from the right hand list of elements, if the left hand side does not have any explicit type set up. This allows the expression to support stringification among other things. In 1.3, "expanding" was not automatically used for :meth:`_sql.ColumnOperators.in_` expressions, so in that sense this change fixes a behavioral regression. Fixes: #6222 Change-Id: Icdfda1e2c226a21896cafd6d8f251547794451c2 --- diff --git a/doc/build/changelog/unreleased_14/6222.rst b/doc/build/changelog/unreleased_14/6222.rst new file mode 100644 index 0000000000..7464b09c61 --- /dev/null +++ b/doc/build/changelog/unreleased_14/6222.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, sql, regression + :tickets: 6222 + + Enhanced the "expanding" feature used for :meth:`_sql.ColumnOperators.in_` + operations to infer the type of expression from the right hand list of + elements, if the left hand side does not have any explicit type set up. + This allows the expression to support stringification among other things. + In 1.3, "expanding" was not automatically used for + :meth:`_sql.ColumnOperators.in_` expressions, so in that sense this change + fixes a behavioral regression. + diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index dfa6c0f8f4..9e1b690888 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1392,15 +1392,26 @@ class BindParameter(roles.InElementRole, ColumnElement): self.literal_execute = literal_execute if _is_crud: self._is_crud = True + if type_ is None: + if expanding and value: + check_value = value[0] + else: + check_value = value if _compared_to_type is not None: self.type = _compared_to_type.coerce_compared_value( - _compared_to_operator, value + _compared_to_operator, check_value ) else: - self.type = type_api._resolve_value_to_type(value) + self.type = type_api._resolve_value_to_type(check_value) elif isinstance(type_, type): self.type = type_() + elif type_._is_tuple_type: + if expanding and value: + check_value = value[0] + else: + check_value = value + self.type = type_._resolve_values_to_types(check_value) else: self.type = type_ diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index e57a14681d..bdfb7d8332 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2912,8 +2912,20 @@ class TupleType(TypeEngine): _is_tuple_type = True def __init__(self, *types): + self._fully_typed = NULLTYPE not in types self.types = types + def _resolve_values_to_types(self, value): + if self._fully_typed: + return self + else: + return TupleType( + *[ + _resolve_value_to_type(elem) if typ is NULLTYPE else typ + for typ, elem in zip(self.types, value) + ] + ) + def result_processor(self, dialect, coltype): raise NotImplementedError( "The tuple type does not support being fetched " diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 02dd0661a5..21a349d76f 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -692,12 +692,32 @@ class CoreFixtures(object): column("y").in_( bindparam( "q", + # note that a different cache key is created if + # the value given to the bindparam is [], as the type + # cannot be inferred for the empty list but can + # for the non-empty list as of #6222 + random_choices(range(10), k=random.randint(1, 7)), + expanding=True, + ) + ), + column("y2").in_( + bindparam( + "q", + # for typed param, empty and not empty param will have + # the same type random_choices(range(10), k=random.randint(0, 7)), + type_=Integer, expanding=True, ) ), - column("z").in_(random_choices(range(10), k=random.randint(0, 7))), - column("x") == random.randint(1, 10), + # don't include empty for untyped, will create different cache + # key + column("z").in_(random_choices(range(10), k=random.randint(1, 7))), + # empty is fine for typed, will create the same cache key + column("z2", Integer).in_( + random_choices(range(10), k=random.randint(0, 7)) + ), + column("x") == random.randint(0, 10), ) ] diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index 0d7f331e0e..878360b9d8 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -1989,6 +1989,23 @@ class InTest(fixtures.TestBase, testing.AssertsCompiledSQL): "SELECT t.a FROM t WHERE t.a IN (SELECT scan)", ) + def test_type_inference_one(self): + expr = column("q").in_([1, 2, 3]) + is_(expr.right.type._type_affinity, Integer) + + self.assert_compile(expr, "q IN (1, 2, 3)", literal_binds=True) + + def test_type_inference_two(self): + expr = column("q").in_([]) + is_(expr.right.type, sqltypes.NULLTYPE) + + self.assert_compile( + expr, + "q IN (SELECT 1 WHERE 1!=1)", + literal_binds=True, + dialect="default_enhanced", + ) + class MathOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" @@ -3075,6 +3092,26 @@ class TupleTypingTest(fixtures.TestBase): self._assert_types(expr.right.type.types) + # since we want to infer "binary" + @testing.requires.python3 + def test_tuple_type_expanding_inference(self): + a, b, c = column("a"), column("b"), column("c") + + t1 = tuple_(a, b, c) + expr = t1.in_([(3, "hi", b"there"), (4, "Q", b"P")]) + + eq_(len(expr.right.value), 2) + + self._assert_types(expr.right.type.types) + + @testing.requires.python3 + def test_tuple_type_plain_inference(self): + a, b, c = column("a"), column("b"), column("c") + + t1 = tuple_(a, b, c) + expr = t1 == (3, "hi", b"there") + self._assert_types(expr.right.type.types) + class InSelectableTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ = "default" diff --git a/test/sql/test_selectable.py b/test/sql/test_selectable.py index b98487933c..291c3f36e4 100644 --- a/test/sql/test_selectable.py +++ b/test/sql/test_selectable.py @@ -3013,7 +3013,7 @@ class AnnotationsTest(fixtures.TestBase): class ReprTest(fixtures.TestBase): def test_ensure_repr_elements(self): for obj in [ - elements.Cast(1, 2), + elements.Cast(1, Integer()), elements.TypeClause(String()), elements.ColumnClause("x"), elements.BindParameter("q"),