]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Support TypeDecorator.get_dbapi_type() for setinpusizes
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Dec 2020 18:56:20 +0000 (13:56 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 30 Dec 2020 20:06:34 +0000 (15:06 -0500)
Adjusted the "setinputsizes" logic relied upon by the cx_Oracle, asyncpg
and pg8000 dialects to support a :class:`.TypeDecorator` that includes
an override the :meth:`.TypeDecorator.get_dbapi_type()` method.

Change-Id: I5aa70abf0d9a9e2ca43309f2dd80b3fcd83881b9

doc/build/changelog/unreleased_14/setinputsize.rst [new file with mode: 0644]
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/testing/suite/test_types.py
test/orm/test_lazy_relations.py

diff --git a/doc/build/changelog/unreleased_14/setinputsize.rst b/doc/build/changelog/unreleased_14/setinputsize.rst
new file mode 100644 (file)
index 0000000..f8694cd
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, engine, postgresql, oracle
+
+    Adjusted the "setinputsizes" logic relied upon by the cx_Oracle, asyncpg
+    and pg8000 dialects to support a :class:`.TypeDecorator` that includes
+    an override the :meth:`.TypeDecorator.get_dbapi_type()` method.
+
index a734bb5825db3e0dfbe7d49f35b21c802a434161..8ee575cca5b75f5a88f20f73df17bcc58207a95c 100644 (file)
@@ -1047,14 +1047,19 @@ class SQLCompiler(Compiled):
         if include_types is None and exclude_types is None:
 
             def _lookup_type(typ):
-                dialect_impl = typ._unwrapped_dialect_impl(dialect)
-                return dialect_impl.get_dbapi_type(dbapi)
+                dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
+                return dbtype
 
         else:
 
             def _lookup_type(typ):
+                # note we get dbtype from the possibly TypeDecorator-wrapped
+                # dialect_impl, but the dialect_impl itself that we use for
+                # include/exclude is the unwrapped version.
+
                 dialect_impl = typ._unwrapped_dialect_impl(dialect)
-                dbtype = dialect_impl.get_dbapi_type(dbapi)
+
+                dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
 
                 if (
                     dbtype is not None
index 749e83de436086178b63915a233bca943f8187cd..43777239c606e5c5a4cab463fe8246c80fd8df2d 100644 (file)
@@ -462,6 +462,45 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase):
             assert isinstance(row[0], (long, int))  # noqa
 
 
+class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
+    __backend__ = True
+
+    @testing.fixture
+    def string_as_int(self):
+        class StringAsInt(TypeDecorator):
+            impl = String(50)
+
+            def get_dbapi_type(self, dbapi):
+                return dbapi.NUMBER
+
+            def column_expression(self, col):
+                return cast(col, Integer)
+
+            def bind_expression(self, col):
+                return cast(col, String(50))
+
+        return StringAsInt()
+
+    @testing.provide_metadata
+    def test_special_type(self, connection, string_as_int):
+
+        type_ = string_as_int
+
+        metadata = self.metadata
+        t = Table("t", metadata, Column("x", type_))
+        t.create(connection)
+
+        connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]])
+
+        result = {row[0] for row in connection.execute(t.select())}
+        eq_(result, {1, 2, 3})
+
+        result = {
+            row[0] for row in connection.execute(t.select().where(t.c.x == 2))
+        }
+        eq_(result, {2})
+
+
 class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase):
     __backend__ = True
 
@@ -1302,6 +1341,7 @@ __all__ = (
     "TextTest",
     "NumericTest",
     "IntegerTest",
+    "CastTypeDecoratorTest",
     "DateTimeHistoricTest",
     "DateTimeCoercedToDateTimeTest",
     "TimeMicrosecondsTest",
index e8da84841e21d6a07a5609b8bedf2c4022cb6002..c81de142c790f3ebfbda9ff574c16d4000f6baab 100644 (file)
@@ -1453,17 +1453,19 @@ class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest):
 class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults):
     """ORM-level test for [ticket:3531]"""
 
-    # mysql is having a recursion issue in the bind_expression
-    __only_on__ = ("sqlite", "postgresql")
+    __backend__ = True
 
     class StringAsInt(TypeDecorator):
         impl = String(50)
 
+        def get_dbapi_type(self, dbapi):
+            return dbapi.NUMBER
+
         def column_expression(self, col):
             return sa.cast(col, Integer)
 
         def bind_expression(self, col):
-            return sa.cast(col, String)
+            return sa.cast(col, String(50))
 
     @classmethod
     def define_tables(cls, metadata):