]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(enum): introduce an explicit dump/load mapping
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 14:04:32 +0000 (16:04 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 03:03:24 +0000 (05:03 +0200)
This will allow mapping customization; it is also slightly faster
because it doesn't require data encoding/decoding.

psycopg/psycopg/types/enum.py

index a9965ea7c8d8af3cd713cefbcc640cba4d33522a..742fd60f853d9f8d6c4732af081f8f25a6a7141d 100644 (file)
@@ -2,57 +2,64 @@
 Adapters for the enum type.
 """
 from enum import Enum
-from typing import Type, Any, Dict, Generic, Optional, TypeVar, cast
+from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast, TYPE_CHECKING
 
 from .. import postgres
 from .. import errors as e
 from ..pq import Format
 from ..abc import AdaptContext
 from ..adapt import Buffer, Dumper, Loader
+from .._encodings import conn_encoding
 from .._typeinfo import EnumInfo as EnumInfo  # exported here
-from .._encodings import pgconn_encoding
 
+if TYPE_CHECKING:
+    from ..connection import BaseConnection
 
 E = TypeVar("E", bound=Enum)
 
 
-class EnumLoader(Loader, Generic[E]):
-    _encoding = "utf-8"
-    enum: Type[E]
+class _BaseEnumLoader(Loader, Generic[E]):
+    """
+    Dumper for a specific Enum class
+    """
 
-    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
-        super().__init__(oid, context)
-        conn = self.connection
-        if conn:
-            self._encoding = pgconn_encoding(conn.pgconn)
+    enum: Type[E]
+    _load_map: Dict[bytes, E]
 
     def load(self, data: Buffer) -> E:
-        if isinstance(data, memoryview):
-            label = bytes(data).decode(self._encoding)
-        else:
-            label = data.decode(self._encoding)
+        if not isinstance(data, bytes):
+            data = bytes(data)
 
         try:
-            return self.enum[label]
+            return self._load_map[data]
         except KeyError:
+            enc = conn_encoding(self.connection)
+            label = data.decode(enc, "replace")  # type: ignore[union-attr]
             raise e.DataError(
                 f"bad memeber for enum {self.enum.__qualname__}: {label!r}"
             )
 
 
-class EnumBinaryLoader(EnumLoader[E]):
-    format = Format.BINARY
+class _BaseEnumDumper(Dumper, Generic[E]):
+    """
+    Loader for a specific Enum class
+    """
+
+    enum: Type[E]
+    _dump_map: Dict[E, bytes]
+
+    def dump(self, value: E) -> Buffer:
+        return self._dump_map[value]
 
 
 class EnumDumper(Dumper):
-    _encoding = "utf-8"
+    """
+    Dumper for a generic Enum class
+    """
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
-
-        conn = self.connection
-        if conn:
-            self._encoding = pgconn_encoding(conn.pgconn)
+        self._encoding = conn_encoding(self.connection)
 
     def dump(self, value: E) -> Buffer:
         return value.name.encode(self._encoding)
@@ -86,29 +93,58 @@ def register_enum(
     adapters = context.adapters if context else postgres.adapters
     info.register(context)
 
-    attribs: Dict[str, Any] = {"enum": info.enum}
+    load_map = _make_load_map(info, enum, context.connection if context else None)
+    attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}
 
-    loader_base = EnumLoader
     name = f"{info.name.title()}Loader"
-    loader = type(name, (loader_base,), attribs)
+    loader = type(name, (_BaseEnumLoader,), attribs)
     adapters.register_loader(info.oid, loader)
 
-    loader_base = EnumBinaryLoader
     name = f"{info.name.title()}BinaryLoader"
-    loader = type(name, (loader_base,), attribs)
+    loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY})
     adapters.register_loader(info.oid, loader)
 
-    attribs = {"oid": info.oid}
+    dump_map = _make_dump_map(info, enum, context.connection if context else None)
+    attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map}
 
-    name = f"{enum.__name__}BinaryDumper"
-    dumper = type(name, (EnumBinaryDumper,), attribs)
+    name = f"{enum.__name__}Dumper"
+    dumper = type(name, (_BaseEnumDumper,), attribs)
     adapters.register_dumper(info.enum, dumper)
 
-    name = f"{enum.__name__}Dumper"
-    dumper = type(name, (EnumDumper,), attribs)
+    name = f"{enum.__name__}BinaryDumper"
+    dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY})
     adapters.register_dumper(info.enum, dumper)
 
 
+def _make_load_map(
+    info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]"
+) -> Dict[bytes, E]:
+    enc = conn_encoding(conn)
+    rv: Dict[bytes, E] = {}
+    for label in info.labels:
+        try:
+            member = enum[label]
+        except KeyError:
+            # tolerate a missing enum, assuming it won't be used. If it is we
+            # will get a DataError on fetch.
+            pass
+        else:
+            rv[label.encode(enc)] = member
+
+    return rv
+
+
+def _make_dump_map(
+    info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]"
+) -> Dict[E, bytes]:
+    enc = conn_encoding(conn)
+    rv: Dict[E, bytes] = {}
+    for member in enum:
+        rv[member] = member.name.encode(enc)
+
+    return rv
+
+
 def register_default_adapters(context: AdaptContext) -> None:
     context.adapters.register_dumper(Enum, EnumBinaryDumper)
     context.adapters.register_dumper(Enum, EnumDumper)