From: Mike Bayer Date: Wed, 21 Dec 2022 21:10:34 +0000 (-0500) Subject: check for adapt to same class in AbstractRange X-Git-Tag: rel_2_0_0rc1~13 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b973cbd8939f2cc0e29c668fffd507958c3e455a;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git check for adapt to same class in AbstractRange Fixed regression where newly revised PostgreSQL range types such as :class:`_postgresql.INT4RANGE` could not be set up as the impl of a :class:`.TypeDecorator` custom type, instead raising a ``TypeError``. Fixes: #9020 Change-Id: Ib881c3c7f63d000f49a09185a8663659a9970aa9 --- diff --git a/doc/build/changelog/unreleased_20/9020.rst b/doc/build/changelog/unreleased_20/9020.rst new file mode 100644 index 0000000000..b09912b40c --- /dev/null +++ b/doc/build/changelog/unreleased_20/9020.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, postgresql + :tickets: 9020 + + Fixed regression where newly revised PostgreSQL range types such as + :class:`_postgresql.INT4RANGE` could not be set up as the impl of a + :class:`.TypeDecorator` custom type, instead raising a ``TypeError``. diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 609af5eb62..69fc1977df 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -691,7 +691,7 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): and also render as ``INT4RANGE`` in SQL and DDL. """ - if issubclass(cls, AbstractRangeImpl): + if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__: # two ways to do this are: 1. create a new type on the fly # or 2. have AbstractRangeImpl(visit_name) constructor and a # visit_abstract_range_impl() method in the PG compiler. diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index caa758b0d2..e9d5e561f8 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4486,6 +4486,25 @@ class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): cols = insp.get_columns("data_table") assert isinstance(cols[0]["type"], self._col_type) + def test_type_decorator_round_trip(self, connection, metadata): + """test #9020""" + + class MyRange(TypeDecorator): + cache_ok = True + impl = self._col_type + + table = Table( + "typedec_table", + metadata, + Column("range", MyRange, primary_key=True), + ) + table.create(connection) + connection.execute(table.insert(), {"range": self._data_obj()}) + data = connection.execute( + select(table.c.range).where(table.c.range == self._data_obj()) + ).fetchall() + eq_(data, [(self._data_obj(),)]) + def test_textual_round_trip_w_dialect_type(self, connection): """test #8690""" data_table = self.tables.data_table diff --git a/test/sql/test_types.py b/test/sql/test_types.py index 59519a5ecc..47fe68354a 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -85,6 +85,7 @@ from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_not +from sqlalchemy.testing import is_true from sqlalchemy.testing import mock from sqlalchemy.testing import pickleable from sqlalchemy.testing.assertions import expect_raises_message @@ -120,6 +121,15 @@ def _types_for_mod(mod): def _all_types(omit_special_types=False): + yield from ( + typ + for typ, _ in _all_types_w_their_dialect( + omit_special_types=omit_special_types + ) + ) + + +def _all_types_w_their_dialect(omit_special_types=False): seen = set() for typ in _types_for_mod(types): if omit_special_types and ( @@ -129,6 +139,7 @@ def _all_types(omit_special_types=False): type_api.TypeEngineMixin, types.Variant, types.TypeDecorator, + types.PickleType, ) or type_api.TypeEngineMixin in typ.__bases__ ): @@ -137,13 +148,13 @@ def _all_types(omit_special_types=False): if typ in seen: continue seen.add(typ) - yield typ + yield typ, default.DefaultDialect for dialect in _all_dialect_modules(): for typ in _types_for_mod(dialect): if typ in seen: continue seen.add(typ) - yield typ + yield typ, dialect.dialect def _get_instance(type_): @@ -350,6 +361,40 @@ class AdaptTest(fixtures.TestBase): t2 = t1.adapt(Text) eq_(t2.length, 50) + @testing.combinations( + *[ + (t, d) + for t, d in _all_types_w_their_dialect(omit_special_types=True) + ] + ) + def test_every_possible_type_can_be_decorated(self, typ, dialect_cls): + """test for #9020 + + Apparently the adapt() method is called with the same class as given + in the case of :class:`.TypeDecorator`, at least with the + PostgreSQL RANGE types, which is not usually expected. + + """ + my_type = type("MyType", (TypeDecorator,), {"impl": typ}) + + if issubclass(typ, ARRAY): + inst = my_type(Integer) + elif issubclass(typ, pg.ENUM): + inst = my_type(name="my_enum") + elif issubclass(typ, pg.DOMAIN): + inst = my_type(name="my_domain", data_type=Integer) + else: + inst = my_type() + impl = inst._unwrapped_dialect_impl(dialect_cls()) + + if dialect_cls is default.DefaultDialect: + is_true(isinstance(impl, typ)) + + if impl._type_affinity is Interval: + is_true(issubclass(typ, sqltypes._AbstractInterval)) + else: + is_true(issubclass(typ, impl._type_affinity)) + class TypeAffinityTest(fixtures.TestBase): @testing.combinations(