]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Correct for Variant + ARRAY cases in psycopg2
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Apr 2021 16:26:06 +0000 (12:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Apr 2021 17:23:50 +0000 (13:23 -0400)
Fixed regression caused by :ticket:`6023` where the PostgreSQL cast
operator applied to elements within an :class:`_types.ARRAY` when using
psycopg2 would fail to use the correct type in the case that the datatype
were also embedded within an instance of the :class:`_types.Variant`
adapter.

Additionally, repairs support for the correct CREATE TYPE to be emitted
when using a ``Variant(ARRAY(some_schema_type))``.

Fixes: #6182
Change-Id: I1b9ba7c876980d4650715a0b0801b46bdc72860d
(cherry picked from commit 43e98c75bde96ef27daeaf7fbfbea30b7eb7c295)

doc/build/changelog/unreleased_13/6182.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/dialect/postgresql/test_types.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_13/6182.rst b/doc/build/changelog/unreleased_13/6182.rst
new file mode 100644 (file)
index 0000000..f38213c
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, postgresql, regression
+    :tickets: 6182
+
+    Fixed regression caused by :ticket:`6023` where the PostgreSQL cast
+    operator applied to elements within an :class:`_types.ARRAY` when using
+    psycopg2 would fail to use the correct type in the case that the datatype
+    were also embedded within an instance of the :class:`_types.Variant`
+    adapter.
+
+    Additionally, repairs support for the correct CREATE TYPE to be emitted
+    when using a ``Variant(ARRAY(some_schema_type))``.
\ No newline at end of file
index 93ee1fa1ab089402dfebd4fc546dc2979fe1ee5c..6c3f71a01484452a95029d17e4f456c4e4a1d3d5 100644 (file)
@@ -1642,7 +1642,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             return False
 
     def _on_table_create(self, target, bind, checkfirst=False, **kw):
-
         if (
             checkfirst
             or (
index 9b75e25dcc23a0c1a2d0b589688638aa7a13e49d..1744f5b044ef33233357e4dacecbd54b44f92a89 100644 (file)
@@ -652,12 +652,17 @@ class PGCompiler_psycopg2(PGCompiler):
         )
         # note that if the type has a bind_expression(), we will get a
         # double compile here
-        if not skip_bind_expression and bindparam.type._is_array:
-            text += "::%s" % (
-                elements.TypeClause(bindparam.type)._compiler_dispatch(
-                    self, skip_bind_expression=skip_bind_expression, **kw
-                ),
-            )
+        if not skip_bind_expression and (
+            bindparam.type._is_array or bindparam.type._is_type_decorator
+        ):
+            typ = bindparam.type._unwrapped_dialect_impl(self.dialect)
+
+            if typ._is_array:
+                text += "::%s" % (
+                    elements.TypeClause(typ)._compiler_dispatch(
+                        self, skip_bind_expression=skip_bind_expression, **kw
+                    ),
+                )
         return text
 
 
index fe2ca9a09cd99c7dcba445fa859362160c645798..44608c9ca329342c733e33a85fce34d22f7fe321 100644 (file)
@@ -1179,13 +1179,19 @@ class SchemaType(SchemaEventTarget):
         if variant_mapping is None:
             return True
 
-        if (
-            dialect.name in variant_mapping
-            and variant_mapping[dialect.name] is self
+        # since PostgreSQL is the only DB that has ARRAY this can only
+        # be integration tested by PG-specific tests
+        def _we_are_the_impl(typ):
+            return (
+                typ is self or isinstance(typ, ARRAY) and typ.item_type is self
+            )
+
+        if dialect.name in variant_mapping and _we_are_the_impl(
+            variant_mapping[dialect.name]
         ):
             return True
         elif dialect.name not in variant_mapping:
-            return variant_mapping["_default"] is self
+            return _we_are_the_impl(variant_mapping["_default"])
 
 
 class Enum(Emulated, String, SchemaType):
@@ -2745,16 +2751,16 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
     def compare_values(self, x, y):
         return x == y
 
-    def _set_parent(self, column, **kw):
+    def _set_parent(self, column, outer=False, **kw):
         """Support SchemaEventTarget"""
 
-        if isinstance(self.item_type, SchemaEventTarget):
+        if not outer and isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent(column, **kw)
 
     def _set_parent_with_dispatch(self, parent, **kw):
         """Support SchemaEventTarget"""
 
-        super(ARRAY, self)._set_parent_with_dispatch(parent)
+        super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True)
 
         if isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent_with_dispatch(parent)
index 7cf6971cde85972d8c872f86eaaa27cd85bc0cdf..241bba0ed72ee61aec33ce5e1cfa2598424c0b23 100644 (file)
@@ -46,6 +46,7 @@ class TypeEngine(Visitable):
     _sqla_type = True
     _isnull = False
     _is_array = False
+    _is_type_decorator = False
 
     class Comparator(operators.ColumnOperators):
         """Base class for custom comparison operations defined at the
@@ -884,6 +885,8 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
 
     __visit_name__ = "type_decorator"
 
+    _is_type_decorator = True
+
     def __init__(self, *args, **kwargs):
         """Construct a :class:`.TypeDecorator`.
 
@@ -1412,7 +1415,7 @@ class Variant(TypeDecorator):
         else:
             return self.impl
 
-    def _set_parent(self, column, **kw):
+    def _set_parent(self, column, outer=False, **kw):
         """Support SchemaEventTarget"""
 
         if isinstance(self.impl, SchemaEventTarget):
index 559cac8af135f622f11fb133564f51e1e677bc0e..5d0f5933f156bcd27256b0e4f39e20188ed6a540 100644 (file)
@@ -50,6 +50,8 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql import type_coerce
+from sqlalchemy.sql.type_api import Variant
 from sqlalchemy.testing import engines
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
@@ -1740,12 +1742,18 @@ class ArrayRoundTripTest(object):
                 "data_2",
                 self.ARRAY(types.Enum("a", "b", "c", name="my_enum_2")),
             ),
+            Column(
+                "data_3",
+                self.ARRAY(
+                    types.Enum("a", "b", "c", name="my_enum_3")
+                ).with_variant(String(), "other"),
+            ),
         )
 
         t.create(testing.db)
         eq_(
             set(e["name"] for e in inspect(testing.db).get_enums()),
-            set(["my_enum_1", "my_enum_2"]),
+            set(["my_enum_1", "my_enum_2", "my_enum_3"]),
         )
         t.drop(testing.db)
         eq_(inspect(testing.db).get_enums(), [])
@@ -1888,6 +1896,7 @@ class ArrayRoundTripTest(object):
                 ],
                 testing.requires.hstore,
             ),
+            (postgresql.ENUM(AnEnum), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
             (
                 sqltypes.Enum(
@@ -1918,14 +1927,21 @@ class ArrayRoundTripTest(object):
 
         m.drop_all(testing.db)
 
-    @testing.fixture
-    def type_specific_fixture(self, metadata, connection, type_):
+    @testing.fixture(params=[True, False])
+    def type_specific_fixture(self, request, metadata, connection, type_):
+        use_variant = request.param
         meta = MetaData()
+
+        if use_variant:
+            typ = self.ARRAY(type_).with_variant(String(), "other")
+        else:
+            typ = self.ARRAY(type_)
+
         table = Table(
             "foo",
             meta,
             Column("id", Integer),
-            Column("bar", self.ARRAY(type_)),
+            Column("bar", typ),
         )
 
         meta.create_all(connection)
@@ -1975,10 +1991,14 @@ class ArrayRoundTripTest(object):
 
         new_gen = gen(3)
 
+        if isinstance(table.c.bar.type, Variant):
+            # this is not likely to occur to users but we need to just
+            # exercise this as far as we can
+            expr = type_coerce(table.c.bar, ARRAY(type_))[1:3]
+        else:
+            expr = table.c.bar[1:3]
         connection.execute(
-            table.update()
-            .where(table.c.id == 2)
-            .values({table.c.bar[1:3]: new_gen[1:4]})
+            table.update().where(table.c.id == 2).values({expr: new_gen[1:4]})
         )
 
         rows = connection.execute(
index 65c5def205001b1b11fe1fc132b93168fd1657d1..5cd97c5ef8abefe8414718a412cfc9748e0dfd68 100644 (file)
@@ -2047,6 +2047,12 @@ class SchemaTypeTest(fixtures.TestBase):
         typ = MyType()
         self._test_before_parent_attach(typ)
 
+    def test_before_parent_attach_variant_array_schematype(self):
+
+        target = Enum("one", "two", "three")
+        typ = ARRAY(target).with_variant(String(), "other")
+        self._test_before_parent_attach(typ, evt_target=target)
+
     def _test_before_parent_attach(self, typ, evt_target=None):
         canary = mock.Mock()