]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use tuple expansion if type._is_tuple, test for Sequence if no type
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Nov 2021 14:18:42 +0000 (10:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Nov 2021 15:55:14 +0000 (11:55 -0400)
Fixed regression where the row objects returned for ORM queries, which are
now the normal :class:`_sql.Row` objects, would not be interpreted by the
:meth:`_sql.ColumnOperators.in_` operator as tuple values to be broken out
into individual bound parameters, and would instead pass them as single
values to the driver leading to failures. The change to the "expanding IN"
system now accommodates for the expression already being of type
:class:`.TupleType` and treats values accordingly if so. In the uncommon
case of using "tuple-in" with an untyped statement such as a textual
statement with no typing information, a tuple value is detected for values
that implement ``collections.abc.Sequence``, but that are not ``str`` or
``bytes``, as always when testing for ``Sequence``.

Added :class:`.TupleType` to the top level ``sqlalchemy`` import namespace.

Fixes: #7292
Change-Id: I8286387e3b3c3752b3bd4ae3560d4f31172acc22

doc/build/changelog/unreleased_14/7292.rst [new file with mode: 0644]
lib/sqlalchemy/__init__.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/suite/test_select.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util/__init__.py
test/sql/test_lambdas.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_14/7292.rst b/doc/build/changelog/unreleased_14/7292.rst
new file mode 100644 (file)
index 0000000..e75d11e
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 7292
+
+    Fixed regression where the row objects returned for ORM queries, which are
+    now the normal :class:`_sql.Row` objects, would not be interpreted by the
+    :meth:`_sql.ColumnOperators.in_` operator as tuple values to be broken out
+    into individual bound parameters, and would instead pass them as single
+    values to the driver leading to failures. The change to the "expanding IN"
+    system now accommodates for the expression already being of type
+    :class:`.TupleType` and treats values accordingly if so. In the uncommon
+    case of using "tuple-in" with an untyped statement such as a textual
+    statement with no typing information, a tuple value is detected for values
+    that implement ``collections.abc.Sequence``, but that are not ``str`` or
+    ``bytes``, as always when testing for ``Sequence``.
+
+.. change::
+    :tags: usecase, sql
+
+    Added :class:`.TupleType` to the top level ``sqlalchemy`` import namespace.
\ No newline at end of file
index dc49690a503444f62ce42691ad2d6223c72e32b7..3580dae5982cf2f4aba64b95e11118780370d344 100644 (file)
@@ -123,6 +123,7 @@ from .types import Text
 from .types import TIME
 from .types import Time
 from .types import TIMESTAMP
+from .types import TupleType
 from .types import TypeDecorator
 from .types import Unicode
 from .types import UnicodeText
index bcede5d7676295af560420c84ec7fb664214d62b..96349578cb4c6dd68b2d2f43d5e1d54aad7da98c 100644 (file)
@@ -2024,8 +2024,14 @@ class SQLCompiler(Compiled):
                     [parameter.type], parameter.expand_op
                 )
 
-        elif isinstance(values[0], (tuple, list)):
-            assert typ_dialect_impl._is_tuple_type
+        elif typ_dialect_impl._is_tuple_type or (
+            typ_dialect_impl._isnull
+            and isinstance(values[0], util.collections_abc.Sequence)
+            and not isinstance(
+                values[0], util.string_types + util.binary_types
+            )
+        ):
+
             replacement_expression = (
                 "VALUES " if self.dialect.tuple_in_values else ""
             ) + ", ".join(
@@ -2041,7 +2047,6 @@ class SQLCompiler(Compiled):
                 for i, tuple_element in enumerate(values)
             )
         else:
-            assert not typ_dialect_impl._is_tuple_type
             replacement_expression = ", ".join(
                 self.render_literal_value(value, parameter.type)
                 for value in values
@@ -2070,10 +2075,14 @@ class SQLCompiler(Compiled):
                     [parameter.type], parameter.expand_op
                 )
 
-        elif (
-            isinstance(values[0], (tuple, list))
-            and not typ_dialect_impl._is_array
+        elif typ_dialect_impl._is_tuple_type or (
+            typ_dialect_impl._isnull
+            and isinstance(values[0], util.collections_abc.Sequence)
+            and not isinstance(
+                values[0], util.string_types + util.binary_types
+            )
         ):
+            assert not typ_dialect_impl._is_array
             to_update = [
                 ("%s_%s_%s" % (name, i, j), value)
                 for i, tuple_element in enumerate(values, 1)
index 77af76d0b8d5242dafd0d070d863a9e3882fd51d..0d7a06e313f8fbd8e6f770fa3efe51d3628a951c 100644 (file)
@@ -2949,7 +2949,10 @@ class TupleType(TypeEngine):
 
     def __init__(self, *types):
         self._fully_typed = NULLTYPE not in types
-        self.types = types
+        self.types = [
+            item_type() if isinstance(item_type, type) else item_type
+            for item_type in types
+        ]
 
     def _resolve_values_to_types(self, value):
         if self._fully_typed:
index 63502b077fa0c9dfe38e7eb55eaa14e038d17d47..bea8a60751aa384b9c39f0e0bc375c799d9b1cce 100644 (file)
@@ -30,11 +30,13 @@ from ... import testing
 from ... import text
 from ... import true
 from ... import tuple_
+from ... import TupleType
 from ... import union
 from ... import util
 from ... import values
 from ...exc import DatabaseError
 from ...exc import ProgrammingError
+from ...util import collections_abc
 
 
 class CollateTest(fixtures.TablesTest):
@@ -1131,6 +1133,41 @@ class ExpandingBoundInTest(fixtures.TablesTest):
         )
         self._assert_result(stmt, [])
 
+    def test_typed_str_in(self):
+        """test related to #7292.
+
+        as a type is given to the bound param, there is no ambiguity
+        to the type of element.
+
+        """
+
+        stmt = text(
+            "select id FROM some_table WHERE z IN :q ORDER BY id"
+        ).bindparams(bindparam("q", type_=String, expanding=True))
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={"q": ["z2", "z3", "z4"]},
+        )
+
+    def test_untyped_str_in(self):
+        """test related to #7292.
+
+        for untyped expression, we look at the types of elements.
+        Test for Sequence to detect tuple in.  but not strings or bytes!
+        as always....
+
+        """
+
+        stmt = text(
+            "select id FROM some_table WHERE z IN :q ORDER BY id"
+        ).bindparams(bindparam("q", expanding=True))
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={"q": ["z2", "z3", "z4"]},
+        )
+
     @testing.requires.tuple_in
     def test_bound_in_two_tuple_bindparam(self):
         table = self.tables.some_table
@@ -1197,6 +1234,73 @@ class ExpandingBoundInTest(fixtures.TablesTest):
             params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
         )
 
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+        class LikeATuple(collections_abc.Sequence):
+            def __init__(self, *data):
+                self._data = data
+
+            def __iter__(self):
+                return iter(self._data)
+
+            def __getitem__(self, idx):
+                return self._data[idx]
+
+            def __len__(self):
+                return len(self._data)
+
+        stmt = text(
+            "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+        ).bindparams(
+            bindparam(
+                "q", type_=TupleType(Integer(), String()), expanding=True
+            )
+        )
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={
+                "q": [
+                    LikeATuple(2, "z2"),
+                    LikeATuple(3, "z3"),
+                    LikeATuple(4, "z4"),
+                ]
+            },
+        )
+
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
+        # note this becomes ARRAY if we dont use expanding
+        # explicitly right now
+
+        class LikeATuple(collections_abc.Sequence):
+            def __init__(self, *data):
+                self._data = data
+
+            def __iter__(self):
+                return iter(self._data)
+
+            def __getitem__(self, idx):
+                return self._data[idx]
+
+            def __len__(self):
+                return len(self._data)
+
+        stmt = text(
+            "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+        ).bindparams(bindparam("q", expanding=True))
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={
+                "q": [
+                    LikeATuple(2, "z2"),
+                    LikeATuple(3, "z3"),
+                    LikeATuple(4, "z4"),
+                ]
+            },
+        )
+
     def test_empty_set_against_integer_bindparam(self):
         table = self.tables.some_table
         stmt = (
index ecc351fc948c1f3cf09c007f8c88a91357a77daf..df8abdc6944f6710e75d725e21a197e4c38fc9d3 100644 (file)
@@ -36,6 +36,7 @@ __all__ = [
     "INTEGER",
     "DATE",
     "TIME",
+    "TupleType",
     "String",
     "Integer",
     "SmallInteger",
@@ -103,6 +104,7 @@ from .sql.sqltypes import Text
 from .sql.sqltypes import TIME
 from .sql.sqltypes import Time
 from .sql.sqltypes import TIMESTAMP
+from .sql.sqltypes import TupleType
 from .sql.sqltypes import Unicode
 from .sql.sqltypes import UnicodeText
 from .sql.sqltypes import VARBINARY
index 327f767159eafd658895fdc1df74d440699cb737..8a18a584a292332623744245dff1fc177fa78195 100644 (file)
@@ -53,6 +53,7 @@ from .compat import b
 from .compat import b64decode
 from .compat import b64encode
 from .compat import binary_type
+from .compat import binary_types
 from .compat import byte_buffer
 from .compat import callable
 from .compat import cmp
index 2e794d7bcf9640545906f86690e641f5657b84fe..a53401a4f1062428dc8049f6a3c37e0b5060abf9 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.types import ARRAY
 from sqlalchemy.types import Boolean
 from sqlalchemy.types import Integer
 from sqlalchemy.types import String
@@ -1180,9 +1181,9 @@ class LambdaElementTest(
     def test_in_parameters_five(self):
         def go(n1, n2):
             stmt = lambdas.lambda_stmt(
-                lambda: select(1).where(column("q").in_(n1))
+                lambda: select(1).where(column("q", ARRAY(String)).in_(n1))
             )
-            stmt += lambda s: s.where(column("y").in_(n2))
+            stmt += lambda s: s.where(column("y", ARRAY(String)).in_(n2))
             return stmt
 
         expr = go(["a", "b", "c"], ["d", "e", "f"])
index 346cb3d58914bc17560a34a19da4bf7826d37420..3909fe60dfec1e07be0fa4d1e149123f7d5f9bf9 100644 (file)
@@ -21,6 +21,7 @@ from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
+from sqlalchemy import tuple_
 from sqlalchemy import type_coerce
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
@@ -869,6 +870,37 @@ class CursorResultTest(fixtures.TablesTest):
         connection.execute(users.insert(), r._mapping)
         eq_(connection.execute(users.select()).fetchall(), [(1, "john")])
 
+    @testing.requires.tuple_in
+    def test_row_tuple_interpretation(self, connection):
+        """test #7292"""
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                dict(user_id=1, user_name="u1"),
+                dict(user_id=2, user_name="u2"),
+                dict(user_id=3, user_name="u3"),
+            ],
+        )
+        rows = connection.execute(
+            select(users.c.user_id, users.c.user_name)
+        ).all()
+
+        # was previously needed
+        # rows = [(x, y) for x, y in rows]
+
+        new_stmt = (
+            select(users)
+            .where(tuple_(users.c.user_id, users.c.user_name).in_(rows))
+            .order_by(users.c.user_id)
+        )
+
+        eq_(
+            connection.execute(new_stmt).all(),
+            [(1, "u1"), (2, "u2"), (3, "u3")],
+        )
+
     def test_result_as_args(self, connection):
         users = self.tables.users
         users2 = self.tables.users2