From: Mike Bayer Date: Sat, 8 Apr 2023 22:43:31 +0000 (-0400) Subject: fix pg ENUM issues X-Git-Tag: rel_2_0_10~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=8128a8f3638b522778458edb81c81e654927bea4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git fix pg ENUM issues Restored the :paramref:`_postgresql.ENUM.name` parameter as optional in the signature for :class:`_postgresql.ENUM`, as this is chosen automatically from a given pep-435 ``Enum`` type. Fixed issue where the comparison for :class:`_postgresql.ENUM` against a plain string would cast that right-hand side type as VARCHAR, which due to more explicit casting added to dialects such as asyncpg would produce a PostgreSQL type mismatch error. Fixes: #9611 Fixes: #9621 Change-Id: If095544cd1a52016ad2e7cfa2d70c919a94e79c1 --- diff --git a/doc/build/changelog/unreleased_20/9621.rst b/doc/build/changelog/unreleased_20/9621.rst new file mode 100644 index 0000000000..de09479d37 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9621.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, postgresql + :tickets: 9611 + + Restored the :paramref:`_postgresql.ENUM.name` parameter as optional in the + signature for :class:`_postgresql.ENUM`, as this is chosen automatically + from a given pep-435 ``Enum`` type. + + +.. change:: + :tags: bug, postgresql + :tickets: 9621 + + Fixed issue where the comparison for :class:`_postgresql.ENUM` against a + plain string would cast that right-hand side type as VARCHAR, which due to + more explicit casting added to dialects such as asyncpg would produce a + PostgreSQL type mismatch error. + diff --git a/lib/sqlalchemy/dialects/postgresql/named_types.py b/lib/sqlalchemy/dialects/postgresql/named_types.py index e2a683e183..b0427b5699 100644 --- a/lib/sqlalchemy/dialects/postgresql/named_types.py +++ b/lib/sqlalchemy/dialects/postgresql/named_types.py @@ -20,6 +20,7 @@ from ...sql import elements from ...sql import roles from ...sql import sqltypes from ...sql import type_api +from ...sql.base import _NoArg from ...sql.ddl import InvokeCreateDDLBase from ...sql.ddl import InvokeDropDDLBase @@ -244,7 +245,13 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum): DDLGenerator = EnumGenerator DDLDropper = EnumDropper - def __init__(self, *enums, name: str, create_type: bool = True, **kw): + def __init__( + self, + *enums, + name: Union[str, _NoArg, None] = _NoArg.NO_ARG, + create_type: bool = True, + **kw, + ): """Construct an :class:`_postgresql.ENUM`. Arguments are the same as that of @@ -280,7 +287,19 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum): "non-native enum." ) self.create_type = create_type - super().__init__(*enums, name=name, **kw) + if name is not _NoArg.NO_ARG: + kw["name"] = name + super().__init__(*enums, **kw) + + def coerce_compared_value(self, op, value): + super_coerced_type = super().coerce_compared_value(op, value) + if ( + super_coerced_type._type_affinity + is type_api.STRINGTYPE._type_affinity + ): + return self + else: + return super_coerced_type @classmethod def __test_init__(cls): diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 0ee9095418..5f5be3c571 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -16,6 +16,7 @@ from sqlalchemy import Enum from sqlalchemy import exc from sqlalchemy import Float from sqlalchemy import func +from sqlalchemy import insert from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal @@ -491,6 +492,86 @@ class NamedTypeTest( else: assert False + @testing.variation("name", ["noname", "nonename", "explicit_name"]) + @testing.variation("enum_type", ["pg", "plain"]) + def test_native_enum_string_from_pep435(self, name, enum_type): + """test #9611""" + + class MyEnum(_PY_Enum): + one = "one" + two = "two" + + if enum_type.plain: + cls = Enum + elif enum_type.pg: + cls = ENUM + else: + enum_type.fail() + + if name.noname: + e1 = cls(MyEnum) + eq_(e1.name, "myenum") + elif name.nonename: + e1 = cls(MyEnum, name=None) + eq_(e1.name, None) + elif name.explicit_name: + e1 = cls(MyEnum, name="abc") + eq_(e1.name, "abc") + + @testing.variation("backend_type", ["native", "non_native", "pg_native"]) + @testing.variation("enum_type", ["pep435", "str"]) + def test_compare_to_string_round_trip( + self, connection, backend_type, enum_type, metadata + ): + """test #9621""" + + if enum_type.pep435: + + class MyEnum(_PY_Enum): + one = "one" + two = "two" + + if backend_type.pg_native: + typ = ENUM(MyEnum, name="myenum2") + else: + typ = Enum( + MyEnum, + native_enum=bool(backend_type.native), + name="myenum2", + ) + data = [{"someenum": MyEnum.one}, {"someenum": MyEnum.two}] + expected = MyEnum.two + elif enum_type.str: + if backend_type.pg_native: + typ = ENUM("one", "two", name="myenum2") + else: + typ = Enum( + "one", + "two", + native_enum=bool(backend_type.native), + name="myenum2", + ) + data = [{"someenum": "one"}, {"someenum": "two"}] + expected = "two" + else: + enum_type.fail() + + enum_table = Table( + "et2", + metadata, + Column("id", Integer, primary_key=True), + Column("someenum", typ), + ) + metadata.create_all(connection) + + connection.execute(insert(enum_table), data) + expr = select(enum_table.c.someenum).where( + enum_table.c.someenum == "two" + ) + + row = connection.execute(expr).one() + eq_(row, (expected,)) + @testing.combinations( (Enum("one", "two", "three")), (ENUM("one", "two", "three", name=None)),