From ab01f893f8c489e2fe981699e022c76e0318ec77 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Thu, 1 Apr 2021 12:26:06 -0400 Subject: [PATCH] Correct for Variant + ARRAY cases in psycopg2 Fixed regression caused by :ticket:`6023` where the PostgreSQL cast operator applied to elements within an :class:`_types.ARRAY` when using psycopg2 would fail to use the correct type in the case that the datatype were also embedded within an instance of the :class:`_types.Variant` adapter. Additionally, repairs support for the correct CREATE TYPE to be emitted when using a ``Variant(ARRAY(some_schema_type))``. Fixes: #6182 Change-Id: I1b9ba7c876980d4650715a0b0801b46bdc72860d --- doc/build/changelog/unreleased_13/6182.rst | 12 +++++++ lib/sqlalchemy/dialects/postgresql/base.py | 1 - .../dialects/postgresql/psycopg2.py | 17 ++++++---- lib/sqlalchemy/sql/sqltypes.py | 20 +++++++---- lib/sqlalchemy/sql/type_api.py | 5 ++- test/dialect/postgresql/test_types.py | 33 +++++++++++++++---- test/sql/test_metadata.py | 6 ++++ 7 files changed, 72 insertions(+), 22 deletions(-) create mode 100644 doc/build/changelog/unreleased_13/6182.rst diff --git a/doc/build/changelog/unreleased_13/6182.rst b/doc/build/changelog/unreleased_13/6182.rst new file mode 100644 index 0000000000..f38213c8f9 --- /dev/null +++ b/doc/build/changelog/unreleased_13/6182.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, postgresql, regression + :tickets: 6182 + + Fixed regression caused by :ticket:`6023` where the PostgreSQL cast + operator applied to elements within an :class:`_types.ARRAY` when using + psycopg2 would fail to use the correct type in the case that the datatype + were also embedded within an instance of the :class:`_types.Variant` + adapter. + + Additionally, repairs support for the correct CREATE TYPE to be emitted + when using a ``Variant(ARRAY(some_schema_type))``. \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 0854214d02..97eb07bdb6 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1989,7 +1989,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( checkfirst or ( diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 1969eb8446..16f9ecefae 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -605,12 +605,17 @@ class PGCompiler_psycopg2(PGCompiler): ) # note that if the type has a bind_expression(), we will get a # double compile here - if not skip_bind_expression and bindparam.type._is_array: - text += "::%s" % ( - elements.TypeClause(bindparam.type)._compiler_dispatch( - self, skip_bind_expression=skip_bind_expression, **kw - ), - ) + if not skip_bind_expression and ( + bindparam.type._is_array or bindparam.type._is_type_decorator + ): + typ = bindparam.type._unwrapped_dialect_impl(self.dialect) + + if typ._is_array: + text += "::%s" % ( + elements.TypeClause(typ)._compiler_dispatch( + self, skip_bind_expression=skip_bind_expression, **kw + ), + ) return text diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 367b2e2037..7cc50d99c8 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1221,13 +1221,19 @@ class SchemaType(SchemaEventTarget): if variant_mapping is None: return True - if ( - dialect.name in variant_mapping - and variant_mapping[dialect.name] is self + # since PostgreSQL is the only DB that has ARRAY this can only + # be integration tested by PG-specific tests + def _we_are_the_impl(typ): + return ( + typ is self or isinstance(typ, ARRAY) and typ.item_type is self + ) + + if dialect.name in variant_mapping and _we_are_the_impl( + variant_mapping[dialect.name] ): return True elif dialect.name not in variant_mapping: - return variant_mapping["_default"] is self + return _we_are_the_impl(variant_mapping["_default"]) class Enum(Emulated, String, SchemaType): @@ -2857,16 +2863,16 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): def compare_values(self, x, y): return x == y - def _set_parent(self, column, **kw): + def _set_parent(self, column, outer=False, **kw): """Support SchemaEventTarget""" - if isinstance(self.item_type, SchemaEventTarget): + if not outer and isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent(column, **kw) def _set_parent_with_dispatch(self, parent): """Support SchemaEventTarget""" - super(ARRAY, self)._set_parent_with_dispatch(parent) + super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True) if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index bfce00cb53..69cd3c5caf 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -48,6 +48,7 @@ class TypeEngine(Traversible): _is_tuple_type = False _is_table_value = False _is_array = False + _is_type_decorator = False class Comparator(operators.ColumnOperators): """Base class for custom comparison operations defined at the @@ -955,6 +956,8 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): __visit_name__ = "type_decorator" + _is_type_decorator = True + def __init__(self, *args, **kwargs): """Construct a :class:`.TypeDecorator`. @@ -1497,7 +1500,7 @@ class Variant(TypeDecorator): else: return self.impl - def _set_parent(self, column, **kw): + def _set_parent(self, column, outer=False, **kw): """Support SchemaEventTarget""" if isinstance(self.impl, SchemaEventTarget): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 343f3a9865..d93f7cc0ac 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -49,6 +49,7 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes +from sqlalchemy.sql.type_api import Variant from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -1696,12 +1697,18 @@ class ArrayRoundTripTest(object): "data_2", self.ARRAY(types.Enum("a", "b", "c", name="my_enum_2")), ), + Column( + "data_3", + self.ARRAY( + types.Enum("a", "b", "c", name="my_enum_3") + ).with_variant(String(), "other"), + ), ) t.create(connection) eq_( set(e["name"] for e in inspect(connection).get_enums()), - set(["my_enum_1", "my_enum_2"]), + set(["my_enum_1", "my_enum_2", "my_enum_3"]), ) t.drop(connection) eq_(inspect(connection).get_enums(), []) @@ -1844,6 +1851,7 @@ class ArrayRoundTripTest(object): ], testing.requires.hstore, ), + (postgresql.ENUM(AnEnum), enum_values), (sqltypes.Enum(AnEnum, native_enum=True), enum_values), (sqltypes.Enum(AnEnum, native_enum=False), enum_values), ] @@ -1864,14 +1872,21 @@ class ArrayRoundTripTest(object): def _cls_type_combinations(cls, **kw): return ArrayRoundTripTest.__dict__["_type_combinations"](**kw) - @testing.fixture - def type_specific_fixture(self, metadata, connection, type_): + @testing.fixture(params=[True, False]) + def type_specific_fixture(self, request, metadata, connection, type_): + use_variant = request.param meta = MetaData() + + if use_variant: + typ = self.ARRAY(type_).with_variant(String(), "other") + else: + typ = self.ARRAY(type_) + table = Table( "foo", meta, Column("id", Integer), - Column("bar", self.ARRAY(type_)), + Column("bar", typ), ) meta.create_all(connection) @@ -1921,10 +1936,14 @@ class ArrayRoundTripTest(object): new_gen = gen(3) + if isinstance(table.c.bar.type, Variant): + # this is not likely to occur to users but we need to just + # exercise this as far as we can + expr = type_coerce(table.c.bar, ARRAY(type_))[1:3] + else: + expr = table.c.bar[1:3] connection.execute( - table.update() - .where(table.c.id == 2) - .values({table.c.bar[1:3]: new_gen[1:4]}) + table.update().where(table.c.id == 2).values({expr: new_gen[1:4]}) ) rows = connection.execute( diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index eb5e305a24..e2f015b00e 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -2125,6 +2125,12 @@ class SchemaTypeTest(fixtures.TestBase): typ = MyType() self._test_before_parent_attach(typ) + def test_before_parent_attach_variant_array_schematype(self): + + target = Enum("one", "two", "three") + typ = ARRAY(target).with_variant(String(), "other") + self._test_before_parent_attach(typ, evt_target=target) + def _test_before_parent_attach(self, typ, evt_target=None): canary = mock.Mock() -- 2.47.2