]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
fix pg ENUM issues
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Apr 2023 22:43:31 +0000 (18:43 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 8 Apr 2023 22:45:31 +0000 (18:45 -0400)
Restored the :paramref:`_postgresql.ENUM.name` parameter as optional in the
signature for :class:`_postgresql.ENUM`, as this is chosen automatically
from a given pep-435 ``Enum`` type.

Fixed issue where the comparison for :class:`_postgresql.ENUM` against a
plain string would cast that right-hand side type as VARCHAR, which due to
more explicit casting added to dialects such as asyncpg would produce a
PostgreSQL type mismatch error.

Fixes: #9611
Fixes: #9621
Change-Id: If095544cd1a52016ad2e7cfa2d70c919a94e79c1

doc/build/changelog/unreleased_20/9621.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/named_types.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/9621.rst b/doc/build/changelog/unreleased_20/9621.rst
new file mode 100644 (file)
index 0000000..de09479
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, postgresql
+    :tickets: 9611
+
+    Restored the :paramref:`_postgresql.ENUM.name` parameter as optional in the
+    signature for :class:`_postgresql.ENUM`, as this is chosen automatically
+    from a given pep-435 ``Enum`` type.
+
+
+.. change::
+    :tags: bug, postgresql
+    :tickets: 9621
+
+    Fixed issue where the comparison for :class:`_postgresql.ENUM` against a
+    plain string would cast that right-hand side type as VARCHAR, which due to
+    more explicit casting added to dialects such as asyncpg would produce a
+    PostgreSQL type mismatch error.
+
index e2a683e1832a751d03cc0f3806de52d0ac5b0c94..b0427b56998bd289e0ab446c93cf316c20472e72 100644 (file)
@@ -20,6 +20,7 @@ from ...sql import elements
 from ...sql import roles
 from ...sql import sqltypes
 from ...sql import type_api
+from ...sql.base import _NoArg
 from ...sql.ddl import InvokeCreateDDLBase
 from ...sql.ddl import InvokeDropDDLBase
 
@@ -244,7 +245,13 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
     DDLGenerator = EnumGenerator
     DDLDropper = EnumDropper
 
-    def __init__(self, *enums, name: str, create_type: bool = True, **kw):
+    def __init__(
+        self,
+        *enums,
+        name: Union[str, _NoArg, None] = _NoArg.NO_ARG,
+        create_type: bool = True,
+        **kw,
+    ):
         """Construct an :class:`_postgresql.ENUM`.
 
         Arguments are the same as that of
@@ -280,7 +287,19 @@ class ENUM(NamedType, sqltypes.NativeForEmulated, sqltypes.Enum):
                 "non-native enum."
             )
         self.create_type = create_type
-        super().__init__(*enums, name=name, **kw)
+        if name is not _NoArg.NO_ARG:
+            kw["name"] = name
+        super().__init__(*enums, **kw)
+
+    def coerce_compared_value(self, op, value):
+        super_coerced_type = super().coerce_compared_value(op, value)
+        if (
+            super_coerced_type._type_affinity
+            is type_api.STRINGTYPE._type_affinity
+        ):
+            return self
+        else:
+            return super_coerced_type
 
     @classmethod
     def __test_init__(cls):
index 0ee90954182ba6671820518bbc4802dca974c1c2..5f5be3c571ec824ebb8b2798b84798b9cd37f473 100644 (file)
@@ -16,6 +16,7 @@ from sqlalchemy import Enum
 from sqlalchemy import exc
 from sqlalchemy import Float
 from sqlalchemy import func
+from sqlalchemy import insert
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import literal
@@ -491,6 +492,86 @@ class NamedTypeTest(
         else:
             assert False
 
+    @testing.variation("name", ["noname", "nonename", "explicit_name"])
+    @testing.variation("enum_type", ["pg", "plain"])
+    def test_native_enum_string_from_pep435(self, name, enum_type):
+        """test #9611"""
+
+        class MyEnum(_PY_Enum):
+            one = "one"
+            two = "two"
+
+        if enum_type.plain:
+            cls = Enum
+        elif enum_type.pg:
+            cls = ENUM
+        else:
+            enum_type.fail()
+
+        if name.noname:
+            e1 = cls(MyEnum)
+            eq_(e1.name, "myenum")
+        elif name.nonename:
+            e1 = cls(MyEnum, name=None)
+            eq_(e1.name, None)
+        elif name.explicit_name:
+            e1 = cls(MyEnum, name="abc")
+            eq_(e1.name, "abc")
+
+    @testing.variation("backend_type", ["native", "non_native", "pg_native"])
+    @testing.variation("enum_type", ["pep435", "str"])
+    def test_compare_to_string_round_trip(
+        self, connection, backend_type, enum_type, metadata
+    ):
+        """test #9621"""
+
+        if enum_type.pep435:
+
+            class MyEnum(_PY_Enum):
+                one = "one"
+                two = "two"
+
+            if backend_type.pg_native:
+                typ = ENUM(MyEnum, name="myenum2")
+            else:
+                typ = Enum(
+                    MyEnum,
+                    native_enum=bool(backend_type.native),
+                    name="myenum2",
+                )
+            data = [{"someenum": MyEnum.one}, {"someenum": MyEnum.two}]
+            expected = MyEnum.two
+        elif enum_type.str:
+            if backend_type.pg_native:
+                typ = ENUM("one", "two", name="myenum2")
+            else:
+                typ = Enum(
+                    "one",
+                    "two",
+                    native_enum=bool(backend_type.native),
+                    name="myenum2",
+                )
+            data = [{"someenum": "one"}, {"someenum": "two"}]
+            expected = "two"
+        else:
+            enum_type.fail()
+
+        enum_table = Table(
+            "et2",
+            metadata,
+            Column("id", Integer, primary_key=True),
+            Column("someenum", typ),
+        )
+        metadata.create_all(connection)
+
+        connection.execute(insert(enum_table), data)
+        expr = select(enum_table.c.someenum).where(
+            enum_table.c.someenum == "two"
+        )
+
+        row = connection.execute(expr).one()
+        eq_(row, (expected,))
+
     @testing.combinations(
         (Enum("one", "two", "three")),
         (ENUM("one", "two", "three", name=None)),