- Fix possible spurious connection timeout in systems with very long uptimes
in C extension (:ticket:`#1280`).
+- Fix client-side adaptation of enums whose name require quotes
+ (:ticket:`#1298`).
Current release
name: str,
oid: int,
array_oid: int,
+ # A bit ugly: this should have been a keyword-only argument, but it has
+ # been it the wild accepting a positional argument for too long to fix.
labels: Sequence[str],
+ *,
+ regtype: str = "",
):
- super().__init__(name, oid, array_oid)
+ super().__init__(name, oid, array_oid, regtype=regtype)
self.labels = labels
# Will be set by register_enum()
self.enum: type[Enum] | None = None
@classmethod
def _get_info_query(cls, conn: BaseConnection[Any]) -> QueryNoTemplate:
return sql.SQL("""\
-SELECT name, oid, array_oid, array_agg(label) AS labels
+SELECT name, oid, array_oid, regtype, array_agg(label) AS labels
FROM (
SELECT
t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
- e.enumlabel AS label
+ t.oid::regtype::text AS regtype, e.enumlabel AS label
FROM pg_type t
LEFT JOIN pg_enum e
ON e.enumtypid = t.oid
WHERE t.oid = {regtype}
ORDER BY e.enumsortorder
) x
-GROUP BY name, oid, array_oid
+GROUP BY name, oid, array_oid, regtype
""").format(regtype=cls._to_regtype(conn))
assert sql.Literal("foo").as_string(conn) == "'foo'"
@pytest.mark.crdb_skip("composite") # create type, actually
- @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar"])
+ @pytest.mark.parametrize("name", ["a-b", f"{eur}", "order", "foo bar", "FooBar"])
def test_invalid_name(self, conn, name):
if conn.info.parameter_status("is_superuser") != "on":
pytest.skip("not a superuser")
import pytest
+from psycopg import ClientCursor
from psycopg import errors as e
from psycopg import pq, sql
from psycopg.adapt import PyFormat
THREE = 3
+class CamelCaseEnum(int, Enum):
+ one = 1
+ TWO = 2
+ Three = 3
+
+
enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum]
encodings = ["utf8", crdb_encoding("latin1")]
def make_test_enums(request, svcconn):
for enum in enum_cases + [NonAsciiEnum]:
ensure_enum(enum, svcconn)
+ ensure_enum(CamelCaseEnum, svcconn, name="CamelCaseEnum")
-def ensure_enum(enum, conn):
- name = enum.__name__.lower()
+def ensure_enum(enum, conn, name=""):
+ if not name:
+ name = enum.__name__.lower()
labels = list(enum.__members__)
conn.execute(sql.SQL("""
drop type if exists {name};
assert cur.fetchone()[0] == enum[label]
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_enum_quoted_name(conn, fmt_in):
+ enum = CamelCaseEnum
+
+ info = EnumInfo.fetch(conn, sql.Identifier(enum.__name__))
+ register_enum(info, conn, enum=enum)
+
+ cur = ClientCursor(conn)
+ for value in enum:
+ cur.execute(f"select %{fmt_in.value}", [value])
+ assert next(cur)[0] is value
+
+
@pytest.mark.crdb_skip("encoding")
@pytest.mark.parametrize("enum", enum_cases)
@pytest.mark.parametrize("fmt_in", PyFormat)