]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Correct for Variant + ARRAY cases in psycopg2
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Apr 2021 16:26:06 +0000 (12:26 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Apr 2021 17:22:26 +0000 (13:22 -0400)
Fixed regression caused by :ticket:`6023` where the PostgreSQL cast
operator applied to elements within an :class:`_types.ARRAY` when using
psycopg2 would fail to use the correct type in the case that the datatype
were also embedded within an instance of the :class:`_types.Variant`
adapter.

Additionally, repairs support for the correct CREATE TYPE to be emitted
when using a ``Variant(ARRAY(some_schema_type))``.

Fixes: #6182
Change-Id: I1b9ba7c876980d4650715a0b0801b46bdc72860d

doc/build/changelog/unreleased_13/6182.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/base.py
lib/sqlalchemy/dialects/postgresql/psycopg2.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/type_api.py
test/dialect/postgresql/test_types.py
test/sql/test_metadata.py

diff --git a/doc/build/changelog/unreleased_13/6182.rst b/doc/build/changelog/unreleased_13/6182.rst
new file mode 100644 (file)
index 0000000..f38213c
--- /dev/null
@@ -0,0 +1,12 @@
+.. change::
+    :tags: bug, postgresql, regression
+    :tickets: 6182
+
+    Fixed regression caused by :ticket:`6023` where the PostgreSQL cast
+    operator applied to elements within an :class:`_types.ARRAY` when using
+    psycopg2 would fail to use the correct type in the case that the datatype
+    were also embedded within an instance of the :class:`_types.Variant`
+    adapter.
+
+    Additionally, repairs support for the correct CREATE TYPE to be emitted
+    when using a ``Variant(ARRAY(some_schema_type))``.
\ No newline at end of file
index 0854214d029ac639c375a6290007179455f275d1..97eb07bdb6e0f89b2b62f2959fc545522ac71422 100644 (file)
@@ -1989,7 +1989,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
             return False
 
     def _on_table_create(self, target, bind, checkfirst=False, **kw):
-
         if (
             checkfirst
             or (
index 1969eb844628fa10f36be04d13cfdf97e539ba3e..16f9ecefae5778406f849fe8935b493a9d8ca13d 100644 (file)
@@ -605,12 +605,17 @@ class PGCompiler_psycopg2(PGCompiler):
         )
         # note that if the type has a bind_expression(), we will get a
         # double compile here
-        if not skip_bind_expression and bindparam.type._is_array:
-            text += "::%s" % (
-                elements.TypeClause(bindparam.type)._compiler_dispatch(
-                    self, skip_bind_expression=skip_bind_expression, **kw
-                ),
-            )
+        if not skip_bind_expression and (
+            bindparam.type._is_array or bindparam.type._is_type_decorator
+        ):
+            typ = bindparam.type._unwrapped_dialect_impl(self.dialect)
+
+            if typ._is_array:
+                text += "::%s" % (
+                    elements.TypeClause(typ)._compiler_dispatch(
+                        self, skip_bind_expression=skip_bind_expression, **kw
+                    ),
+                )
         return text
 
 
index 367b2e203788516fb7e2fc7d1eac4771cb0bb64b..7cc50d99c86843f3f66fdba3418b393afeac8f05 100644 (file)
@@ -1221,13 +1221,19 @@ class SchemaType(SchemaEventTarget):
         if variant_mapping is None:
             return True
 
-        if (
-            dialect.name in variant_mapping
-            and variant_mapping[dialect.name] is self
+        # since PostgreSQL is the only DB that has ARRAY this can only
+        # be integration tested by PG-specific tests
+        def _we_are_the_impl(typ):
+            return (
+                typ is self or isinstance(typ, ARRAY) and typ.item_type is self
+            )
+
+        if dialect.name in variant_mapping and _we_are_the_impl(
+            variant_mapping[dialect.name]
         ):
             return True
         elif dialect.name not in variant_mapping:
-            return variant_mapping["_default"] is self
+            return _we_are_the_impl(variant_mapping["_default"])
 
 
 class Enum(Emulated, String, SchemaType):
@@ -2857,16 +2863,16 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
     def compare_values(self, x, y):
         return x == y
 
-    def _set_parent(self, column, **kw):
+    def _set_parent(self, column, outer=False, **kw):
         """Support SchemaEventTarget"""
 
-        if isinstance(self.item_type, SchemaEventTarget):
+        if not outer and isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent(column, **kw)
 
     def _set_parent_with_dispatch(self, parent):
         """Support SchemaEventTarget"""
 
-        super(ARRAY, self)._set_parent_with_dispatch(parent)
+        super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True)
 
         if isinstance(self.item_type, SchemaEventTarget):
             self.item_type._set_parent_with_dispatch(parent)
index bfce00cb53ba1a80f3061857183aef4792c62816..69cd3c5caf4cc99e289404ecc67c20985aec69e0 100644 (file)
@@ -48,6 +48,7 @@ class TypeEngine(Traversible):
     _is_tuple_type = False
     _is_table_value = False
     _is_array = False
+    _is_type_decorator = False
 
     class Comparator(operators.ColumnOperators):
         """Base class for custom comparison operations defined at the
@@ -955,6 +956,8 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
 
     __visit_name__ = "type_decorator"
 
+    _is_type_decorator = True
+
     def __init__(self, *args, **kwargs):
         """Construct a :class:`.TypeDecorator`.
 
@@ -1497,7 +1500,7 @@ class Variant(TypeDecorator):
         else:
             return self.impl
 
-    def _set_parent(self, column, **kw):
+    def _set_parent(self, column, outer=False, **kw):
         """Support SchemaEventTarget"""
 
         if isinstance(self.impl, SchemaEventTarget):
index 343f3a986584f07a8c9ff253f439c58f49362e70..d93f7cc0acb84a826cbb0f060d6bbbd287cc1ccf 100644 (file)
@@ -49,6 +49,7 @@ from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
+from sqlalchemy.sql.type_api import Variant
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
@@ -1696,12 +1697,18 @@ class ArrayRoundTripTest(object):
                 "data_2",
                 self.ARRAY(types.Enum("a", "b", "c", name="my_enum_2")),
             ),
+            Column(
+                "data_3",
+                self.ARRAY(
+                    types.Enum("a", "b", "c", name="my_enum_3")
+                ).with_variant(String(), "other"),
+            ),
         )
 
         t.create(connection)
         eq_(
             set(e["name"] for e in inspect(connection).get_enums()),
-            set(["my_enum_1", "my_enum_2"]),
+            set(["my_enum_1", "my_enum_2", "my_enum_3"]),
         )
         t.drop(connection)
         eq_(inspect(connection).get_enums(), [])
@@ -1844,6 +1851,7 @@ class ArrayRoundTripTest(object):
                 ],
                 testing.requires.hstore,
             ),
+            (postgresql.ENUM(AnEnum), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=True), enum_values),
             (sqltypes.Enum(AnEnum, native_enum=False), enum_values),
         ]
@@ -1864,14 +1872,21 @@ class ArrayRoundTripTest(object):
     def _cls_type_combinations(cls, **kw):
         return ArrayRoundTripTest.__dict__["_type_combinations"](**kw)
 
-    @testing.fixture
-    def type_specific_fixture(self, metadata, connection, type_):
+    @testing.fixture(params=[True, False])
+    def type_specific_fixture(self, request, metadata, connection, type_):
+        use_variant = request.param
         meta = MetaData()
+
+        if use_variant:
+            typ = self.ARRAY(type_).with_variant(String(), "other")
+        else:
+            typ = self.ARRAY(type_)
+
         table = Table(
             "foo",
             meta,
             Column("id", Integer),
-            Column("bar", self.ARRAY(type_)),
+            Column("bar", typ),
         )
 
         meta.create_all(connection)
@@ -1921,10 +1936,14 @@ class ArrayRoundTripTest(object):
 
         new_gen = gen(3)
 
+        if isinstance(table.c.bar.type, Variant):
+            # this is not likely to occur to users but we need to just
+            # exercise this as far as we can
+            expr = type_coerce(table.c.bar, ARRAY(type_))[1:3]
+        else:
+            expr = table.c.bar[1:3]
         connection.execute(
-            table.update()
-            .where(table.c.id == 2)
-            .values({table.c.bar[1:3]: new_gen[1:4]})
+            table.update().where(table.c.id == 2).values({expr: new_gen[1:4]})
         )
 
         rows = connection.execute(
index eb5e305a24ff5e231cdaa041abd7d128663d2028..e2f015b00ea44cf9298d3d692a3d95a0362c3e64 100644 (file)
@@ -2125,6 +2125,12 @@ class SchemaTypeTest(fixtures.TestBase):
         typ = MyType()
         self._test_before_parent_attach(typ)
 
+    def test_before_parent_attach_variant_array_schematype(self):
+
+        target = Enum("one", "two", "three")
+        typ = ARRAY(target).with_variant(String(), "other")
+        self._test_before_parent_attach(typ, evt_target=target)
+
     def _test_before_parent_attach(self, typ, evt_target=None):
         canary = mock.Mock()