]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
check for adapt to same class in AbstractRange
authorMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Dec 2022 21:10:34 +0000 (16:10 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 21 Dec 2022 21:12:53 +0000 (16:12 -0500)
Fixed regression where newly revised PostgreSQL range types such as
:class:`_postgresql.INT4RANGE` could not be set up as the impl of a
:class:`.TypeDecorator` custom type, instead raising a ``TypeError``.

Fixes: #9020
Change-Id: Ib881c3c7f63d000f49a09185a8663659a9970aa9

doc/build/changelog/unreleased_20/9020.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py
test/sql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/9020.rst b/doc/build/changelog/unreleased_20/9020.rst
new file mode 100644 (file)
index 0000000..b09912b
--- /dev/null
@@ -0,0 +1,7 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 9020
+
+    Fixed regression where newly revised PostgreSQL range types such as
+    :class:`_postgresql.INT4RANGE` could not be set up as the impl of a
+    :class:`.TypeDecorator` custom type, instead raising a ``TypeError``.
index 609af5eb6200084245d54125cdebbea2ee4798f0..69fc1977dfcca7d1a6709a691ed1bfcfc5d3b068 100644 (file)
@@ -691,7 +691,7 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
         and also render as ``INT4RANGE`` in SQL and DDL.
 
         """
-        if issubclass(cls, AbstractRangeImpl):
+        if issubclass(cls, AbstractRangeImpl) and cls is not self.__class__:
             # two ways to do this are:  1. create a new type on the fly
             # or 2. have AbstractRangeImpl(visit_name) constructor and a
             # visit_abstract_range_impl() method in the PG compiler.
index caa758b0d2a2fd0c6dd0837bb32cdcb64973c04b..e9d5e561f80df28eb0d78eaef8b622e6ad4a2512 100644 (file)
@@ -4486,6 +4486,25 @@ class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
         cols = insp.get_columns("data_table")
         assert isinstance(cols[0]["type"], self._col_type)
 
+    def test_type_decorator_round_trip(self, connection, metadata):
+        """test #9020"""
+
+        class MyRange(TypeDecorator):
+            cache_ok = True
+            impl = self._col_type
+
+        table = Table(
+            "typedec_table",
+            metadata,
+            Column("range", MyRange, primary_key=True),
+        )
+        table.create(connection)
+        connection.execute(table.insert(), {"range": self._data_obj()})
+        data = connection.execute(
+            select(table.c.range).where(table.c.range == self._data_obj())
+        ).fetchall()
+        eq_(data, [(self._data_obj(),)])
+
     def test_textual_round_trip_w_dialect_type(self, connection):
         """test #8690"""
         data_table = self.tables.data_table
index 59519a5eccc048c6ab99597268253cd9aefd2d3e..47fe68354a5261ce18c64005f91522d22957778e 100644 (file)
@@ -85,6 +85,7 @@ from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import is_not
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing import mock
 from sqlalchemy.testing import pickleable
 from sqlalchemy.testing.assertions import expect_raises_message
@@ -120,6 +121,15 @@ def _types_for_mod(mod):
 
 
 def _all_types(omit_special_types=False):
+    yield from (
+        typ
+        for typ, _ in _all_types_w_their_dialect(
+            omit_special_types=omit_special_types
+        )
+    )
+
+
+def _all_types_w_their_dialect(omit_special_types=False):
     seen = set()
     for typ in _types_for_mod(types):
         if omit_special_types and (
@@ -129,6 +139,7 @@ def _all_types(omit_special_types=False):
                 type_api.TypeEngineMixin,
                 types.Variant,
                 types.TypeDecorator,
+                types.PickleType,
             )
             or type_api.TypeEngineMixin in typ.__bases__
         ):
@@ -137,13 +148,13 @@ def _all_types(omit_special_types=False):
         if typ in seen:
             continue
         seen.add(typ)
-        yield typ
+        yield typ, default.DefaultDialect
     for dialect in _all_dialect_modules():
         for typ in _types_for_mod(dialect):
             if typ in seen:
                 continue
             seen.add(typ)
-            yield typ
+            yield typ, dialect.dialect
 
 
 def _get_instance(type_):
@@ -350,6 +361,40 @@ class AdaptTest(fixtures.TestBase):
         t2 = t1.adapt(Text)
         eq_(t2.length, 50)
 
+    @testing.combinations(
+        *[
+            (t, d)
+            for t, d in _all_types_w_their_dialect(omit_special_types=True)
+        ]
+    )
+    def test_every_possible_type_can_be_decorated(self, typ, dialect_cls):
+        """test for #9020
+
+        Apparently the adapt() method is called with the same class as given
+        in the case of :class:`.TypeDecorator`, at least with the
+        PostgreSQL RANGE types, which is not usually expected.
+
+        """
+        my_type = type("MyType", (TypeDecorator,), {"impl": typ})
+
+        if issubclass(typ, ARRAY):
+            inst = my_type(Integer)
+        elif issubclass(typ, pg.ENUM):
+            inst = my_type(name="my_enum")
+        elif issubclass(typ, pg.DOMAIN):
+            inst = my_type(name="my_domain", data_type=Integer)
+        else:
+            inst = my_type()
+        impl = inst._unwrapped_dialect_impl(dialect_cls())
+
+        if dialect_cls is default.DefaultDialect:
+            is_true(isinstance(impl, typ))
+
+        if impl._type_affinity is Interval:
+            is_true(issubclass(typ, sqltypes._AbstractInterval))
+        else:
+            is_true(issubclass(typ, impl._type_affinity))
+
 
 class TypeAffinityTest(fixtures.TestBase):
     @testing.combinations(