]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(crdb): configure dumpers to deal correctly with Enum types
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 22 May 2022 00:12:01 +0000 (02:12 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 12 Jul 2022 11:58:34 +0000 (12:58 +0100)
psycopg/psycopg/crdb.py
tests/types/test_enum.py

index e509d78ec663914dde3918100e9b873d688ca2c7..acbd99a1717a840ea5577b5214c7648ed767060a 100644 (file)
@@ -5,13 +5,15 @@ Types configuration specific for CockroachDB.
 # Copyright (C) 2022 The Psycopg Team
 
 import re
+from enum import Enum
 from typing import Any, Optional, Union, TYPE_CHECKING
 
 from . import errors as e
 from .abc import AdaptContext
-from .postgres import adapters as pg_adapters
-from ._adapters_map import AdaptersMap
+from .postgres import adapters as pg_adapters, TEXT_OID
 from .conninfo import ConnectionInfo
+from ._adapters_map import AdaptersMap
+from .types.enum import EnumDumper, EnumBinaryDumper
 
 adapters = AdaptersMap(pg_adapters)
 
@@ -55,6 +57,14 @@ class CrdbConnectionInfo(ConnectionInfo):
         return int(m.group(1)) * 10000 + int(m.group(2)) * 100 + int(m.group(3))
 
 
+class CrdbEnumDumper(EnumDumper):
+    oid = TEXT_OID
+
+
+class CrdbEnumBinaryDumper(EnumBinaryDumper):
+    oid = TEXT_OID
+
+
 def register_crdb_adapters(context: AdaptContext) -> None:
     from .types import string
 
@@ -63,6 +73,8 @@ def register_crdb_adapters(context: AdaptContext) -> None:
     # Dump strings with text oid instead of unknown.
     # Unlike PostgreSQL, CRDB seems able to cast text to most types.
     adapters.register_dumper(str, string.StrDumper)
+    adapters.register_dumper(Enum, CrdbEnumBinaryDumper)
+    adapters.register_dumper(Enum, CrdbEnumDumper)
 
 
 register_crdb_adapters(adapters)
index 0284dcf53fc6f0467d62600d9e9fee4621a23c26..e5b15d13258faefbf4f20c61bb19d0d4575e50a6 100644 (file)
@@ -7,6 +7,8 @@ from psycopg.adapt import PyFormat
 from psycopg.types import TypeInfo
 from psycopg.types.enum import EnumInfo, register_enum
 
+from ..fix_crdb import crdb_encoding
+
 
 class PureTestEnum(Enum):
     FOO = auto()
@@ -34,7 +36,7 @@ class IntTestEnum(int, Enum):
 
 
 enum_cases = [PureTestEnum, StrTestEnum, IntTestEnum]
-encodings = ["utf8", "latin1"]
+encodings = ["utf8", crdb_encoding("latin1")]
 
 
 @pytest.fixture(scope="session", autouse=True)
@@ -49,7 +51,7 @@ def ensure_enum(enum, conn):
     conn.execute(
         sql.SQL(
             """
-            drop type if exists {name} cascade;
+            drop type if exists {name};
             create type {name} as enum ({labels});
             """
         ).format(name=sql.Identifier(name), labels=sql.SQL(",").join(labels))
@@ -116,6 +118,7 @@ def test_enum_loader_nonascii(conn, encoding, fmt_in, fmt_out):
         assert cur.fetchone()[0] == enum[label]
 
 
+@pytest.mark.crdb("skip", reason="encoding")
 @pytest.mark.parametrize("enum", enum_cases)
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
@@ -158,6 +161,7 @@ def test_enum_dumper_nonascii(conn, encoding, fmt_in, fmt_out):
         assert cur.fetchone()[0] == item
 
 
+@pytest.mark.crdb("skip", reason="encoding")
 @pytest.mark.parametrize("enum", enum_cases)
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)