]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
use tuple expansion if type._is_tuple, test for Sequence if no type
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Nov 2021 14:18:42 +0000 (10:18 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 5 Nov 2021 15:55:24 +0000 (11:55 -0400)
Fixed regression where the row objects returned for ORM queries, which are
now the normal :class:`_sql.Row` objects, would not be interpreted by the
:meth:`_sql.ColumnOperators.in_` operator as tuple values to be broken out
into individual bound parameters, and would instead pass them as single
values to the driver leading to failures. The change to the "expanding IN"
system now accommodates for the expression already being of type
:class:`.TupleType` and treats values accordingly if so. In the uncommon
case of using "tuple-in" with an untyped statement such as a textual
statement with no typing information, a tuple value is detected for values
that implement ``collections.abc.Sequence``, but that are not ``str`` or
``bytes``, as always when testing for ``Sequence``.

Added :class:`.TupleType` to the top level ``sqlalchemy`` import namespace.

Fixes: #7292
Change-Id: I8286387e3b3c3752b3bd4ae3560d4f31172acc22
(cherry picked from commit 0c44a1e77cfde0f841a4a64140314c6b833efdab)

doc/build/changelog/unreleased_14/7292.rst [new file with mode: 0644]
lib/sqlalchemy/__init__.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/testing/suite/test_select.py
lib/sqlalchemy/types.py
lib/sqlalchemy/util/__init__.py
test/sql/test_lambdas.py
test/sql/test_resultset.py

diff --git a/doc/build/changelog/unreleased_14/7292.rst b/doc/build/changelog/unreleased_14/7292.rst
new file mode 100644 (file)
index 0000000..e75d11e
--- /dev/null
@@ -0,0 +1,20 @@
+.. change::
+    :tags: bug, sql, regression
+    :tickets: 7292
+
+    Fixed regression where the row objects returned for ORM queries, which are
+    now the normal :class:`_sql.Row` objects, would not be interpreted by the
+    :meth:`_sql.ColumnOperators.in_` operator as tuple values to be broken out
+    into individual bound parameters, and would instead pass them as single
+    values to the driver leading to failures. The change to the "expanding IN"
+    system now accommodates for the expression already being of type
+    :class:`.TupleType` and treats values accordingly if so. In the uncommon
+    case of using "tuple-in" with an untyped statement such as a textual
+    statement with no typing information, a tuple value is detected for values
+    that implement ``collections.abc.Sequence``, but that are not ``str`` or
+    ``bytes``, as always when testing for ``Sequence``.
+
+.. change::
+    :tags: usecase, sql
+
+    Added :class:`.TupleType` to the top level ``sqlalchemy`` import namespace.
\ No newline at end of file
index ad6d96fdd3e4f5d89584ec79883a0c13da7ea5df..d5cc233243b7a49ef7fb4bbec06ab264f9520934 100644 (file)
@@ -123,6 +123,7 @@ from .types import Text
 from .types import TIME
 from .types import Time
 from .types import TIMESTAMP
+from .types import TupleType
 from .types import TypeDecorator
 from .types import Unicode
 from .types import UnicodeText
index 266452851bc78e46dbbee933a110892a1da92ba9..7db8d6b5d6a25d061dd8459099eeed8b26c3c0ae 100644 (file)
@@ -2025,8 +2025,14 @@ class SQLCompiler(Compiled):
                     [parameter.type], parameter.expand_op
                 )
 
-        elif isinstance(values[0], (tuple, list)):
-            assert typ_dialect_impl._is_tuple_type
+        elif typ_dialect_impl._is_tuple_type or (
+            typ_dialect_impl._isnull
+            and isinstance(values[0], util.collections_abc.Sequence)
+            and not isinstance(
+                values[0], util.string_types + util.binary_types
+            )
+        ):
+
             replacement_expression = (
                 "VALUES " if self.dialect.tuple_in_values else ""
             ) + ", ".join(
@@ -2042,7 +2048,6 @@ class SQLCompiler(Compiled):
                 for i, tuple_element in enumerate(values)
             )
         else:
-            assert not typ_dialect_impl._is_tuple_type
             replacement_expression = ", ".join(
                 self.render_literal_value(value, parameter.type)
                 for value in values
@@ -2071,10 +2076,14 @@ class SQLCompiler(Compiled):
                     [parameter.type], parameter.expand_op
                 )
 
-        elif (
-            isinstance(values[0], (tuple, list))
-            and not typ_dialect_impl._is_array
+        elif typ_dialect_impl._is_tuple_type or (
+            typ_dialect_impl._isnull
+            and isinstance(values[0], util.collections_abc.Sequence)
+            and not isinstance(
+                values[0], util.string_types + util.binary_types
+            )
         ):
+            assert not typ_dialect_impl._is_array
             to_update = [
                 ("%s_%s_%s" % (name, i, j), value)
                 for i, tuple_element in enumerate(values, 1)
index ae589d648a8a8b7c742206284069da590266682f..3f3801ab009d159eee99741d50bc8bd2b931ddbd 100644 (file)
@@ -2966,7 +2966,10 @@ class TupleType(TypeEngine):
 
     def __init__(self, *types):
         self._fully_typed = NULLTYPE not in types
-        self.types = types
+        self.types = [
+            item_type() if isinstance(item_type, type) else item_type
+            for item_type in types
+        ]
 
     def _resolve_values_to_types(self, value):
         if self._fully_typed:
index a3475f651b46b681d3fa776a9e89269a44bc815e..3e3ad04a7825d930bfb1913ff2db2dc4b3b53918 100644 (file)
@@ -30,11 +30,13 @@ from ... import testing
 from ... import text
 from ... import true
 from ... import tuple_
+from ... import TupleType
 from ... import union
 from ... import util
 from ... import values
 from ...exc import DatabaseError
 from ...exc import ProgrammingError
+from ...util import collections_abc
 
 
 class CollateTest(fixtures.TablesTest):
@@ -1131,6 +1133,41 @@ class ExpandingBoundInTest(fixtures.TablesTest):
         )
         self._assert_result(stmt, [])
 
+    def test_typed_str_in(self):
+        """test related to #7292.
+
+        as a type is given to the bound param, there is no ambiguity
+        to the type of element.
+
+        """
+
+        stmt = text(
+            "select id FROM some_table WHERE z IN :q ORDER BY id"
+        ).bindparams(bindparam("q", type_=String, expanding=True))
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={"q": ["z2", "z3", "z4"]},
+        )
+
+    def test_untyped_str_in(self):
+        """test related to #7292.
+
+        for untyped expression, we look at the types of elements.
+        Test for Sequence to detect tuple in.  but not strings or bytes!
+        as always....
+
+        """
+
+        stmt = text(
+            "select id FROM some_table WHERE z IN :q ORDER BY id"
+        ).bindparams(bindparam("q", expanding=True))
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={"q": ["z2", "z3", "z4"]},
+        )
+
     @testing.requires.tuple_in
     def test_bound_in_two_tuple_bindparam(self):
         table = self.tables.some_table
@@ -1197,6 +1234,73 @@ class ExpandingBoundInTest(fixtures.TablesTest):
             params={"q": [(2, "z2"), (3, "z3"), (4, "z4")]},
         )
 
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple_typed_bindparam_non_tuple(self):
+        class LikeATuple(collections_abc.Sequence):
+            def __init__(self, *data):
+                self._data = data
+
+            def __iter__(self):
+                return iter(self._data)
+
+            def __getitem__(self, idx):
+                return self._data[idx]
+
+            def __len__(self):
+                return len(self._data)
+
+        stmt = text(
+            "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+        ).bindparams(
+            bindparam(
+                "q", type_=TupleType(Integer(), String()), expanding=True
+            )
+        )
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={
+                "q": [
+                    LikeATuple(2, "z2"),
+                    LikeATuple(3, "z3"),
+                    LikeATuple(4, "z4"),
+                ]
+            },
+        )
+
+    @testing.requires.tuple_in
+    def test_bound_in_heterogeneous_two_tuple_text_bindparam_non_tuple(self):
+        # note this becomes ARRAY if we dont use expanding
+        # explicitly right now
+
+        class LikeATuple(collections_abc.Sequence):
+            def __init__(self, *data):
+                self._data = data
+
+            def __iter__(self):
+                return iter(self._data)
+
+            def __getitem__(self, idx):
+                return self._data[idx]
+
+            def __len__(self):
+                return len(self._data)
+
+        stmt = text(
+            "select id FROM some_table WHERE (x, z) IN :q ORDER BY id"
+        ).bindparams(bindparam("q", expanding=True))
+        self._assert_result(
+            stmt,
+            [(2,), (3,), (4,)],
+            params={
+                "q": [
+                    LikeATuple(2, "z2"),
+                    LikeATuple(3, "z3"),
+                    LikeATuple(4, "z4"),
+                ]
+            },
+        )
+
     def test_empty_set_against_integer_bindparam(self):
         table = self.tables.some_table
         stmt = (
index ecc351fc948c1f3cf09c007f8c88a91357a77daf..df8abdc6944f6710e75d725e21a197e4c38fc9d3 100644 (file)
@@ -36,6 +36,7 @@ __all__ = [
     "INTEGER",
     "DATE",
     "TIME",
+    "TupleType",
     "String",
     "Integer",
     "SmallInteger",
@@ -103,6 +104,7 @@ from .sql.sqltypes import Text
 from .sql.sqltypes import TIME
 from .sql.sqltypes import Time
 from .sql.sqltypes import TIMESTAMP
+from .sql.sqltypes import TupleType
 from .sql.sqltypes import Unicode
 from .sql.sqltypes import UnicodeText
 from .sql.sqltypes import VARBINARY
index bdd69431e0f81c3f731b416266139a51a481fc9d..e4e79294f205ea1c195b65c235bd98f83881d504 100644 (file)
@@ -53,6 +53,7 @@ from .compat import b
 from .compat import b64decode
 from .compat import b64encode
 from .compat import binary_type
+from .compat import binary_types
 from .compat import byte_buffer
 from .compat import callable
 from .compat import cmp
index 2e794d7bcf9640545906f86690e641f5657b84fe..a53401a4f1062428dc8049f6a3c37e0b5060abf9 100644 (file)
@@ -26,6 +26,7 @@ from sqlalchemy.testing import is_
 from sqlalchemy.testing import ne_
 from sqlalchemy.testing.assertions import expect_raises_message
 from sqlalchemy.testing.assertsql import CompiledSQL
+from sqlalchemy.types import ARRAY
 from sqlalchemy.types import Boolean
 from sqlalchemy.types import Integer
 from sqlalchemy.types import String
@@ -1180,9 +1181,9 @@ class LambdaElementTest(
     def test_in_parameters_five(self):
         def go(n1, n2):
             stmt = lambdas.lambda_stmt(
-                lambda: select(1).where(column("q").in_(n1))
+                lambda: select(1).where(column("q", ARRAY(String)).in_(n1))
             )
-            stmt += lambda s: s.where(column("y").in_(n2))
+            stmt += lambda s: s.where(column("y", ARRAY(String)).in_(n2))
             return stmt
 
         expr = go(["a", "b", "c"], ["d", "e", "f"])
index c02b3cbc1b6987cd161680e8fc7e3575f19f73ce..d07f81facee01199b52f154a3600182ed06e83c9 100644 (file)
@@ -21,6 +21,7 @@ from sqlalchemy import table
 from sqlalchemy import testing
 from sqlalchemy import text
 from sqlalchemy import true
+from sqlalchemy import tuple_
 from sqlalchemy import type_coerce
 from sqlalchemy import TypeDecorator
 from sqlalchemy import util
@@ -775,6 +776,37 @@ class CursorResultTest(fixtures.TablesTest):
         connection.execute(users.insert(), r._mapping)
         eq_(connection.execute(users.select()).fetchall(), [(1, "john")])
 
+    @testing.requires.tuple_in
+    def test_row_tuple_interpretation(self, connection):
+        """test #7292"""
+        users = self.tables.users
+
+        connection.execute(
+            users.insert(),
+            [
+                dict(user_id=1, user_name="u1"),
+                dict(user_id=2, user_name="u2"),
+                dict(user_id=3, user_name="u3"),
+            ],
+        )
+        rows = connection.execute(
+            select(users.c.user_id, users.c.user_name)
+        ).all()
+
+        # was previously needed
+        # rows = [(x, y) for x, y in rows]
+
+        new_stmt = (
+            select(users)
+            .where(tuple_(users.c.user_id, users.c.user_name).in_(rows))
+            .order_by(users.c.user_id)
+        )
+
+        eq_(
+            connection.execute(new_stmt).all(),
+            [(1, "u1"), (2, "u2"), (3, "u3")],
+        )
+
     def test_result_as_args(self, connection):
         users = self.tables.users
         users2 = self.tables.users2