]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: fix client-side representation of enums requiring quotes
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 28 Apr 2026 12:01:40 +0000 (14:01 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 28 Apr 2026 12:16:29 +0000 (14:16 +0200)
The issue seems limited to the enum type: every other EnumInfo subclass
correctly pass the regtype to the base class init.

Close #1298

docs/news.rst
psycopg/psycopg/types/enum.py
tests/test_sql.py
tests/types/test_enum.py

index bccc0b3cf46824f7f4eff6f7fbc9f55314b163a8..8c7a512a785a6033a3d8b915d169a95978133c93 100644 (file)
@@ -15,6 +15,8 @@ Psycopg 3.3.4 (unreleased)
 
 - Fix possible spurious connection timeout in systems with very long uptimes
   in C extension (:ticket:`#1280`).
+- Fix client-side adaptation of enums whose name require quotes
+  (:ticket:`#1298`).
 
 
 Current release
index 6af13c0dfefb2992ce79de3339ad3b85ddca92e0..efebe9e633f28dd8629e7507a04c0aeccf93436a 100644 (file)
@@ -43,9 +43,13 @@ class EnumInfo(TypeInfo):
         name: str,
         oid: int,
         array_oid: int,
+        # A bit ugly: this should have been a keyword-only argument, but it has
+        # been it the wild accepting a positional argument for too long to fix.
         labels: Sequence[str],
+        *,
+        regtype: str = "",
     ):
-        super().__init__(name, oid, array_oid)
+        super().__init__(name, oid, array_oid, regtype=regtype)
         self.labels = labels
         # Will be set by register_enum()
         self.enum: type[Enum] | None = None
@@ -53,18 +57,18 @@ class EnumInfo(TypeInfo):
     @classmethod
     def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate:
         return sql.SQL("""\
-SELECT name, oid, array_oid, array_agg(label) AS labels
+SELECT name, oid, array_oid, regtype, array_agg(label) AS labels
 FROM (
     SELECT
         t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
-        e.enumlabel AS label
+        t.oid::regtype::text AS regtype, e.enumlabel AS label
     FROM pg_type t
     LEFT JOIN  pg_enum e
     ON e.enumtypid = t.oid
     WHERE t.oid = {regtype}
     ORDER BY e.enumsortorder
 ) x
-GROUP BY name, oid, array_oid
+GROUP BY name, oid, array_oid, regtype
 """).format(regtype=cls._to_regtype(conn))
 
 
index 6d949ad1797e962d6b0410a45f8d64103fc42a64..fdabb56d52328f0336f71d7e8e26e9c8605c78a8 100644 (file)
@@ -401,7 +401,7 @@ class TestLiteral:
         assert sql.Literal("foo").as_string(conn) == "'foo'"
 
     @pytest.mark.crdb_skip("composite")  # create type, actually
-    @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"])
+    @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar", "FooBar"])
     def test_invalid_name(self, conn, name):
         if conn.info.parameter_status("is_superuser") != "on":
             pytest.skip("not a superuser")
index 028dd007b4e9e2f9f5e8531ef0f1b61a6e76418c..93e6529dbf3832d3082ad9297f7217edc89e82f8 100644 (file)
@@ -2,6 +2,7 @@ from enum import Enum, auto
 
 import pytest
 
+from psycopg import ClientCursor
 from psycopg import errors as e
 from psycopg import pq, sql
 from psycopg.adapt import PyFormat
@@ -36,6 +37,12 @@ class IntTestEnum(int, Enum):
     THREE = 3
 
 
+class CamelCaseEnum(int, Enum):
+    one = 1
+    TWO = 2
+    Three = 3
+
+
 enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum]
 encodings = ["utf8", crdb_encoding("latin1")]
 
@@ -44,10 +51,12 @@ encodings = ["utf8", crdb_encoding("latin1")]
 def make_test_enums(request, svcconn):
     for enum in enum_cases + [NonAsciiEnum]:
         ensure_enum(enum, svcconn)
+    ensure_enum(CamelCaseEnum, svcconn, name="CamelCaseEnum")
 
 
-def ensure_enum(enum, conn):
-    name = enum.__name__.lower()
+def ensure_enum(enum, conn, name=""):
+    if not name:
+        name = enum.__name__.lower()
     labels = list(enum.__members__)
     conn.execute(sql.SQL("""
             drop type if exists {name};
@@ -114,6 +123,19 @@ def test_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out):
         assert cur.fetchone()[0] == enum[label]
 
 
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_enum_quoted_name(conn, fmt_in):
+    enum = CamelCaseEnum
+
+    info = EnumInfo.fetch(conn, sql.Identifier(enum.__name__))
+    register_enum(info, conn, enum=enum)
+
+    cur = ClientCursor(conn)
+    for value in enum:
+        cur.execute(f"select %{fmt_in.value}", [value])
+        assert next(cur)[0] is value
+
+
 @pytest.mark.crdb_skip("encoding")
 @pytest.mark.parametrize("enum", enum_cases)
 @pytest.mark.parametrize("fmt_in", PyFormat)