From: Mike Bayer Date: Thu, 1 Apr 2021 16:26:06 +0000 (-0400) Subject: Correct for Variant + ARRAY cases in psycopg2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=a241f8c7539720ebcfe0ae240b593036e5ff9b65;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Correct for Variant + ARRAY cases in psycopg2 Fixed regression caused by :ticket:`6023` where the PostgreSQL cast operator applied to elements within an :class:`_types.ARRAY` when using psycopg2 would fail to use the correct type in the case that the datatype were also embedded within an instance of the :class:`_types.Variant` adapter. Additionally, repairs support for the correct CREATE TYPE to be emitted when using a ``Variant(ARRAY(some_schema_type))``. Fixes: #6182 Change-Id: I1b9ba7c876980d4650715a0b0801b46bdc72860d (cherry picked from commit 43e98c75bde96ef27daeaf7fbfbea30b7eb7c295) --- diff --git a/doc/build/changelog/unreleased_13/6182.rst b/doc/build/changelog/unreleased_13/6182.rst new file mode 100644 index 0000000000..f38213c8f9 --- /dev/null +++ b/doc/build/changelog/unreleased_13/6182.rst @@ -0,0 +1,12 @@ +.. change:: + :tags: bug, postgresql, regression + :tickets: 6182 + + Fixed regression caused by :ticket:`6023` where the PostgreSQL cast + operator applied to elements within an :class:`_types.ARRAY` when using + psycopg2 would fail to use the correct type in the case that the datatype + were also embedded within an instance of the :class:`_types.Variant` + adapter. + + Additionally, repairs support for the correct CREATE TYPE to be emitted + when using a ``Variant(ARRAY(some_schema_type))``. \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 93ee1fa1ab..6c3f71a014 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -1642,7 +1642,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( checkfirst or ( diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index 9b75e25dcc..1744f5b044 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -652,12 +652,17 @@ class PGCompiler_psycopg2(PGCompiler): ) # note that if the type has a bind_expression(), we will get a # double compile here - if not skip_bind_expression and bindparam.type._is_array: - text += "::%s" % ( - elements.TypeClause(bindparam.type)._compiler_dispatch( - self, skip_bind_expression=skip_bind_expression, **kw - ), - ) + if not skip_bind_expression and ( + bindparam.type._is_array or bindparam.type._is_type_decorator + ): + typ = bindparam.type._unwrapped_dialect_impl(self.dialect) + + if typ._is_array: + text += "::%s" % ( + elements.TypeClause(typ)._compiler_dispatch( + self, skip_bind_expression=skip_bind_expression, **kw + ), + ) return text diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index fe2ca9a09c..44608c9ca3 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1179,13 +1179,19 @@ class SchemaType(SchemaEventTarget): if variant_mapping is None: return True - if ( - dialect.name in variant_mapping - and variant_mapping[dialect.name] is self + # since PostgreSQL is the only DB that has ARRAY this can only + # be integration tested by PG-specific tests + def _we_are_the_impl(typ): + return ( + typ is self or isinstance(typ, ARRAY) and typ.item_type is self + ) + + if dialect.name in variant_mapping and _we_are_the_impl( + variant_mapping[dialect.name] ): return True elif dialect.name not in variant_mapping: - return variant_mapping["_default"] is self + return _we_are_the_impl(variant_mapping["_default"]) class Enum(Emulated, String, SchemaType): @@ -2745,16 +2751,16 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): def compare_values(self, x, y): return x == y - def _set_parent(self, column, **kw): + def _set_parent(self, column, outer=False, **kw): """Support SchemaEventTarget""" - if isinstance(self.item_type, SchemaEventTarget): + if not outer and isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent(column, **kw) def _set_parent_with_dispatch(self, parent, **kw): """Support SchemaEventTarget""" - super(ARRAY, self)._set_parent_with_dispatch(parent) + super(ARRAY, self)._set_parent_with_dispatch(parent, outer=True) if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 7cf6971cde..241bba0ed7 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -46,6 +46,7 @@ class TypeEngine(Visitable): _sqla_type = True _isnull = False _is_array = False + _is_type_decorator = False class Comparator(operators.ColumnOperators): """Base class for custom comparison operations defined at the @@ -884,6 +885,8 @@ class TypeDecorator(SchemaEventTarget, TypeEngine): __visit_name__ = "type_decorator" + _is_type_decorator = True + def __init__(self, *args, **kwargs): """Construct a :class:`.TypeDecorator`. @@ -1412,7 +1415,7 @@ class Variant(TypeDecorator): else: return self.impl - def _set_parent(self, column, **kw): + def _set_parent(self, column, outer=False, **kw): """Support SchemaEventTarget""" if isinstance(self.impl, SchemaEventTarget): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 559cac8af1..5d0f5933f1 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -50,6 +50,8 @@ from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes +from sqlalchemy.sql import type_coerce +from sqlalchemy.sql.type_api import Variant from sqlalchemy.testing import engines from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises @@ -1740,12 +1742,18 @@ class ArrayRoundTripTest(object): "data_2", self.ARRAY(types.Enum("a", "b", "c", name="my_enum_2")), ), + Column( + "data_3", + self.ARRAY( + types.Enum("a", "b", "c", name="my_enum_3") + ).with_variant(String(), "other"), + ), ) t.create(testing.db) eq_( set(e["name"] for e in inspect(testing.db).get_enums()), - set(["my_enum_1", "my_enum_2"]), + set(["my_enum_1", "my_enum_2", "my_enum_3"]), ) t.drop(testing.db) eq_(inspect(testing.db).get_enums(), []) @@ -1888,6 +1896,7 @@ class ArrayRoundTripTest(object): ], testing.requires.hstore, ), + (postgresql.ENUM(AnEnum), enum_values), (sqltypes.Enum(AnEnum, native_enum=True), enum_values), ( sqltypes.Enum( @@ -1918,14 +1927,21 @@ class ArrayRoundTripTest(object): m.drop_all(testing.db) - @testing.fixture - def type_specific_fixture(self, metadata, connection, type_): + @testing.fixture(params=[True, False]) + def type_specific_fixture(self, request, metadata, connection, type_): + use_variant = request.param meta = MetaData() + + if use_variant: + typ = self.ARRAY(type_).with_variant(String(), "other") + else: + typ = self.ARRAY(type_) + table = Table( "foo", meta, Column("id", Integer), - Column("bar", self.ARRAY(type_)), + Column("bar", typ), ) meta.create_all(connection) @@ -1975,10 +1991,14 @@ class ArrayRoundTripTest(object): new_gen = gen(3) + if isinstance(table.c.bar.type, Variant): + # this is not likely to occur to users but we need to just + # exercise this as far as we can + expr = type_coerce(table.c.bar, ARRAY(type_))[1:3] + else: + expr = table.c.bar[1:3] connection.execute( - table.update() - .where(table.c.id == 2) - .values({table.c.bar[1:3]: new_gen[1:4]}) + table.update().where(table.c.id == 2).values({expr: new_gen[1:4]}) ) rows = connection.execute( diff --git a/test/sql/test_metadata.py b/test/sql/test_metadata.py index 65c5def205..5cd97c5ef8 100644 --- a/test/sql/test_metadata.py +++ b/test/sql/test_metadata.py @@ -2047,6 +2047,12 @@ class SchemaTypeTest(fixtures.TestBase): typ = MyType() self._test_before_parent_attach(typ) + def test_before_parent_attach_variant_array_schematype(self): + + target = Enum("one", "two", "three") + typ = ARRAY(target).with_variant(String(), "other") + self._test_before_parent_attach(typ, evt_target=target) + def _test_before_parent_attach(self, typ, evt_target=None): canary = mock.Mock()