From: Mike Bayer Date: Wed, 30 Dec 2020 18:56:20 +0000 (-0500) Subject: Support TypeDecorator.get_dbapi_type() for setinpusizes X-Git-Tag: rel_1_4_0b2~75 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=102b91d8950926f1215dd7c59c5b7f200b5c0f8b;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Support TypeDecorator.get_dbapi_type() for setinpusizes Adjusted the "setinputsizes" logic relied upon by the cx_Oracle, asyncpg and pg8000 dialects to support a :class:`.TypeDecorator` that includes an override the :meth:`.TypeDecorator.get_dbapi_type()` method. Change-Id: I5aa70abf0d9a9e2ca43309f2dd80b3fcd83881b9 --- diff --git a/doc/build/changelog/unreleased_14/setinputsize.rst b/doc/build/changelog/unreleased_14/setinputsize.rst new file mode 100644 index 0000000000..f8694cd5fb --- /dev/null +++ b/doc/build/changelog/unreleased_14/setinputsize.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, engine, postgresql, oracle + + Adjusted the "setinputsizes" logic relied upon by the cx_Oracle, asyncpg + and pg8000 dialects to support a :class:`.TypeDecorator` that includes + an override the :meth:`.TypeDecorator.get_dbapi_type()` method. + diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index a734bb5825..8ee575cca5 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1047,14 +1047,19 @@ class SQLCompiler(Compiled): if include_types is None and exclude_types is None: def _lookup_type(typ): - dialect_impl = typ._unwrapped_dialect_impl(dialect) - return dialect_impl.get_dbapi_type(dbapi) + dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi) + return dbtype else: def _lookup_type(typ): + # note we get dbtype from the possibly TypeDecorator-wrapped + # dialect_impl, but the dialect_impl itself that we use for + # include/exclude is the unwrapped version. + dialect_impl = typ._unwrapped_dialect_impl(dialect) - dbtype = dialect_impl.get_dbapi_type(dbapi) + + dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi) if ( dbtype is not None diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 749e83de43..43777239c6 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -462,6 +462,45 @@ class IntegerTest(_LiteralRoundTripFixture, fixtures.TestBase): assert isinstance(row[0], (long, int)) # noqa +class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase): + __backend__ = True + + @testing.fixture + def string_as_int(self): + class StringAsInt(TypeDecorator): + impl = String(50) + + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + + def column_expression(self, col): + return cast(col, Integer) + + def bind_expression(self, col): + return cast(col, String(50)) + + return StringAsInt() + + @testing.provide_metadata + def test_special_type(self, connection, string_as_int): + + type_ = string_as_int + + metadata = self.metadata + t = Table("t", metadata, Column("x", type_)) + t.create(connection) + + connection.execute(t.insert(), [{"x": x} for x in [1, 2, 3]]) + + result = {row[0] for row in connection.execute(t.select())} + eq_(result, {1, 2, 3}) + + result = { + row[0] for row in connection.execute(t.select().where(t.c.x == 2)) + } + eq_(result, {2}) + + class NumericTest(_LiteralRoundTripFixture, fixtures.TestBase): __backend__ = True @@ -1302,6 +1341,7 @@ __all__ = ( "TextTest", "NumericTest", "IntegerTest", + "CastTypeDecoratorTest", "DateTimeHistoricTest", "DateTimeCoercedToDateTimeTest", "TimeMicrosecondsTest", diff --git a/test/orm/test_lazy_relations.py b/test/orm/test_lazy_relations.py index e8da84841e..c81de142c7 100644 --- a/test/orm/test_lazy_relations.py +++ b/test/orm/test_lazy_relations.py @@ -1453,17 +1453,19 @@ class RefersToSelfLazyLoadInterferenceTest(fixtures.MappedTest): class TypeCoerceTest(fixtures.MappedTest, testing.AssertsExecutionResults): """ORM-level test for [ticket:3531]""" - # mysql is having a recursion issue in the bind_expression - __only_on__ = ("sqlite", "postgresql") + __backend__ = True class StringAsInt(TypeDecorator): impl = String(50) + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER + def column_expression(self, col): return sa.cast(col, Integer) def bind_expression(self, col): - return sa.cast(col, String) + return sa.cast(col, String(50)) @classmethod def define_tables(cls, metadata):