]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Remove special rule for TypeDecorator of TypeDecorator
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 5 Jan 2021 13:48:36 +0000 (08:48 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 6 Jan 2021 02:57:06 +0000 (21:57 -0500)
Removing this check for "TypeDecorator" in impl seems to not
break anything and allows TypeDecorator.with_variant() to
work correctly.   The line has been traced back to 2007 and
does not appear to have relevance today.

Fixed bug where making use of the :meth:`.TypeEngine.with_variant` method
on a :class:`.TypeDecorator` type would fail to take into account the
dialect-specific mappings in use, due to a rule in :class:`.TypeDecorator`
that was instead attempting to check for chains of :class:`.TypeDecorator`
instances.

Fixes: #5816
Change-Id: Ic86d9d985810e3050f15972b4841108acca2fa3e
(cherry picked from commit 458f83c6d213a80c2f0353b96421de36aee705f3)

doc/build/changelog/unreleased_13/5816.rst [new file with mode: 0644]
lib/sqlalchemy/sql/type_api.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_13/5816.rst b/doc/build/changelog/unreleased_13/5816.rst
new file mode 100644 (file)
index 0000000..5049622
--- /dev/null
@@ -0,0 +1,10 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 5816
+
+    Fixed bug where making use of the :meth:`.TypeEngine.with_variant` method
+    on a :class:`.TypeDecorator` type would fail to take into account the
+    dialect-specific mappings in use, due to a rule in :class:`.TypeDecorator`
+    that was instead attempting to check for chains of :class:`.TypeDecorator`
+    instances.
+
index 0bf7aa1674ea24e7af49f99c4bc29179c1eea458..9f279a3e9433d7acd99714cc38c594649f4ddac8 100644 (file)
@@ -1013,8 +1013,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
 
         In most cases this returns a dialect-adapted form of
         the :class:`.TypeEngine` type represented by ``self.impl``.
-        Makes usage of :meth:`dialect_impl` but also traverses
-        into wrapped :class:`.TypeDecorator` instances.
+        Makes usage of :meth:`dialect_impl`.
         Behavior can be customized here by overriding
         :meth:`load_dialect_impl`.
 
@@ -1022,8 +1021,6 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
         adapted = dialect.type_descriptor(self)
         if not isinstance(adapted, type(self)):
             return adapted
-        elif isinstance(self.impl, TypeDecorator):
-            return self.impl.type_engine(dialect)
         else:
             return self.load_dialect_impl(dialect)
 
index 2c028508f1852b60b86df0c43d95bd438008583d..55e9b0f9a51c3e09bda74f8a660fe8187aa2515b 100644 (file)
@@ -489,6 +489,9 @@ class _UserDefinedTypeFixture(object):
             def copy(self):
                 return MyUnicodeType(self.impl.length)
 
+        class MyDecOfDec(types.TypeDecorator):
+            impl = MyNewIntType
+
         Table(
             "users",
             metadata,
@@ -501,6 +504,7 @@ class _UserDefinedTypeFixture(object):
             Column("goofy7", MyNewUnicodeType(50), nullable=False),
             Column("goofy8", MyNewIntType, nullable=False),
             Column("goofy9", MyNewIntSubClass, nullable=False),
+            Column("goofy10", MyDecOfDec, nullable=False),
         )
 
 
@@ -520,6 +524,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
                     goofy7=util.u("jack"),
                     goofy8=12,
                     goofy9=12,
+                    goofy10=12,
                 ),
             )
             conn.execute(
@@ -532,6 +537,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
                     goofy7=util.u("lala"),
                     goofy8=15,
                     goofy9=15,
+                    goofy10=15,
                 ),
             )
             conn.execute(
@@ -544,6 +550,7 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
                     goofy7=util.u("fred"),
                     goofy8=9,
                     goofy9=9,
+                    goofy10=9,
                 ),
             )
 
@@ -581,7 +588,19 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         result = testing.db.execute(stmt, {"goofy": [15, 9]})
         eq_(result.fetchall(), [(3, 1500), (4, 900)])
 
-    def test_expanding_in(self):
+    def test_plain_in_typedec_of_typedec(self, connection):
+        users = self.tables.users
+        self._data_fixture()
+
+        stmt = (
+            select([users.c.user_id, users.c.goofy10])
+            .where(users.c.goofy10.in_([15, 9]))
+            .order_by(users.c.user_id)
+        )
+        result = connection.execute(stmt, {"goofy": [15, 9]})
+        eq_(result.fetchall(), [(3, 1500), (4, 900)])
+
+    def test_expanding_in_typedec(self, connection):
         users = self.tables.users
         self._data_fixture()
 
@@ -593,6 +612,18 @@ class UserDefinedRoundTripTest(_UserDefinedTypeFixture, fixtures.TablesTest):
         result = testing.db.execute(stmt, {"goofy": [15, 9]})
         eq_(result.fetchall(), [(3, 1500), (4, 900)])
 
+    def test_expanding_in_typedec_of_typedec(self, connection):
+        users = self.tables.users
+        self._data_fixture()
+
+        stmt = (
+            select([users.c.user_id, users.c.goofy10])
+            .where(users.c.goofy10.in_(bindparam("goofy", expanding=True)))
+            .order_by(users.c.user_id)
+        )
+        result = connection.execute(stmt, {"goofy": [15, 9]})
+        eq_(result.fetchall(), [(3, 1500), (4, 900)])
+
 
 class UserDefinedTest(
     _UserDefinedTypeFixture, fixtures.TablesTest, AssertsCompiledSQL
@@ -1035,6 +1066,178 @@ class TypeCoerceCastTest(fixtures.TablesTest):
         )
 
 
+class VariantBackendTest(fixtures.TestBase, AssertsCompiledSQL):
+    __backend__ = True
+
+    @testing.fixture
+    def variant_roundtrip(self, connection):
+
+        metadata = MetaData()
+
+        def run(datatype, data, assert_data):
+            t = Table(
+                "t",
+                metadata,
+                Column("data", datatype),
+            )
+            t.create(connection)
+
+            connection.execute(t.insert(), [{"data": elem} for elem in data])
+            eq_(
+                connection.execute(select([t]).order_by(t.c.data)).fetchall(),
+                [(elem,) for elem in assert_data],
+            )
+
+            eq_(
+                # test an IN, which in 1.4 is an expanding
+                connection.execute(
+                    select([t]).where(t.c.data.in_(data)).order_by(t.c.data)
+                ).fetchall(),
+                [(elem,) for elem in assert_data],
+            )
+
+        try:
+            yield run
+        finally:
+            metadata.drop_all(connection)
+
+    def test_type_decorator_variant_one_roundtrip(self, variant_roundtrip):
+        class Foo(TypeDecorator):
+            impl = String(50)
+
+        if testing.against("postgresql"):
+            data = [5, 6, 10]
+        else:
+            data = ["five", "six", "ten"]
+        variant_roundtrip(
+            Foo().with_variant(Integer, "postgresql"), data, data
+        )
+
+    def test_type_decorator_variant_two(self, variant_roundtrip):
+        class UTypeOne(types.UserDefinedType):
+            def get_col_spec(self):
+                return "VARCHAR(50)"
+
+            def bind_processor(self, dialect):
+                def process(value):
+                    return value + "UONE"
+
+                return process
+
+        class UTypeTwo(types.UserDefinedType):
+            def get_col_spec(self):
+                return "VARCHAR(50)"
+
+            def bind_processor(self, dialect):
+                def process(value):
+                    return value + "UTWO"
+
+                return process
+
+        variant = UTypeOne()
+        for db in ["postgresql", "mysql", "mariadb"]:
+            variant = variant.with_variant(UTypeTwo(), db)
+
+        class Foo(TypeDecorator):
+            impl = variant
+
+        if testing.against("postgresql"):
+            data = assert_data = [5, 6, 10]
+        elif testing.against("mysql") or testing.against("mariadb"):
+            data = ["five", "six", "ten"]
+            assert_data = ["fiveUTWO", "sixUTWO", "tenUTWO"]
+        else:
+            data = ["five", "six", "ten"]
+            assert_data = ["fiveUONE", "sixUONE", "tenUONE"]
+
+        variant_roundtrip(
+            Foo().with_variant(Integer, "postgresql"), data, assert_data
+        )
+
+    def test_type_decorator_variant_three(self, variant_roundtrip):
+        class Foo(TypeDecorator):
+            impl = String
+
+        if testing.against("postgresql"):
+            data = ["five", "six", "ten"]
+        else:
+            data = [5, 6, 10]
+
+        variant_roundtrip(
+            Integer().with_variant(Foo(), "postgresql"), data, data
+        )
+
+    def test_type_decorator_compile_variant_one(self):
+        class Foo(TypeDecorator):
+            impl = String
+
+        self.assert_compile(
+            Foo().with_variant(Integer, "sqlite"),
+            "INTEGER",
+            dialect=dialects.sqlite.dialect(),
+        )
+
+        self.assert_compile(
+            Foo().with_variant(Integer, "sqlite"),
+            "VARCHAR",
+            dialect=dialects.postgresql.dialect(),
+        )
+
+    def test_type_decorator_compile_variant_two(self):
+        class UTypeOne(types.UserDefinedType):
+            def get_col_spec(self):
+                return "UTYPEONE"
+
+            def bind_processor(self, dialect):
+                def process(value):
+                    return value + "UONE"
+
+                return process
+
+        class UTypeTwo(types.UserDefinedType):
+            def get_col_spec(self):
+                return "UTYPETWO"
+
+            def bind_processor(self, dialect):
+                def process(value):
+                    return value + "UTWO"
+
+                return process
+
+        variant = UTypeOne().with_variant(UTypeTwo(), "postgresql")
+
+        class Foo(TypeDecorator):
+            impl = variant
+
+        self.assert_compile(
+            Foo().with_variant(Integer, "sqlite"),
+            "INTEGER",
+            dialect=dialects.sqlite.dialect(),
+        )
+
+        self.assert_compile(
+            Foo().with_variant(Integer, "sqlite"),
+            "UTYPETWO",
+            dialect=dialects.postgresql.dialect(),
+        )
+
+    def test_type_decorator_compile_variant_three(self):
+        class Foo(TypeDecorator):
+            impl = String
+
+        self.assert_compile(
+            Integer().with_variant(Foo(), "postgresql"),
+            "INTEGER",
+            dialect=dialects.sqlite.dialect(),
+        )
+
+        self.assert_compile(
+            Integer().with_variant(Foo(), "postgresql"),
+            "VARCHAR",
+            dialect=dialects.postgresql.dialect(),
+        )
+
+
 class VariantTest(fixtures.TestBase, AssertsCompiledSQL):
     def setup(self):
         class UTypeOne(types.UserDefinedType):
@@ -2322,6 +2525,9 @@ class ExpressionTest(
             def process_result_value(self, value, dialect):
                 return value + "BIND_OUT"
 
+        class MyDecOfDec(types.TypeDecorator):
+            impl = MyTypeDec
+
         meta = MetaData(testing.db)
         test_table = Table(
             "test",
@@ -2331,6 +2537,7 @@ class ExpressionTest(
             Column("atimestamp", Date),
             Column("avalue", MyCustomType),
             Column("bvalue", MyTypeDec(50)),
+            Column("cvalue", MyDecOfDec(50)),
         )
 
         meta.create_all()
@@ -2342,7 +2549,8 @@ class ExpressionTest(
                 "atimestamp": datetime.date(2007, 10, 15),
                 "avalue": 25,
                 "bvalue": "foo",
-            }
+                "cvalue": "foo",
+            },
         )
 
     @classmethod
@@ -2361,6 +2569,7 @@ class ExpressionTest(
                     datetime.date(2007, 10, 15),
                     25,
                     "BIND_INfooBIND_OUT",
+                    "BIND_INfooBIND_OUT",
                 )
             ],
         )
@@ -2398,6 +2607,7 @@ class ExpressionTest(
                     datetime.date(2007, 10, 15),
                     25,
                     "BIND_INfooBIND_OUT",
+                    "BIND_INfooBIND_OUT",
                 )
             ],
         )
@@ -2416,6 +2626,7 @@ class ExpressionTest(
                     datetime.date(2007, 10, 15),
                     25,
                     "BIND_INfooBIND_OUT",
+                    "BIND_INfooBIND_OUT",
                 )
             ],
         )