From 98cb5a63b7e035a33923777890c5d656e428d3d4 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 21 Apr 2022 16:04:32 +0200 Subject: [PATCH] refactor(enum): introduce an explicit dump/load mapping This will allow mapping customization; it is also slightly faster because it doesn't require data encoding/decoding. --- psycopg/psycopg/types/enum.py | 100 +++++++++++++++++++++++----------- 1 file changed, 68 insertions(+), 32 deletions(-) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index a9965ea7c..742fd60f8 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -2,57 +2,64 @@ Adapters for the enum type. """ from enum import Enum -from typing import Type, Any, Dict, Generic, Optional, TypeVar, cast +from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast, TYPE_CHECKING from .. import postgres from .. import errors as e from ..pq import Format from ..abc import AdaptContext from ..adapt import Buffer, Dumper, Loader +from .._encodings import conn_encoding from .._typeinfo import EnumInfo as EnumInfo # exported here -from .._encodings import pgconn_encoding +if TYPE_CHECKING: + from ..connection import BaseConnection E = TypeVar("E", bound=Enum) -class EnumLoader(Loader, Generic[E]): - _encoding = "utf-8" - enum: Type[E] +class _BaseEnumLoader(Loader, Generic[E]): + """ + Dumper for a specific Enum class + """ - def __init__(self, oid: int, context: Optional[AdaptContext] = None): - super().__init__(oid, context) - conn = self.connection - if conn: - self._encoding = pgconn_encoding(conn.pgconn) + enum: Type[E] + _load_map: Dict[bytes, E] def load(self, data: Buffer) -> E: - if isinstance(data, memoryview): - label = bytes(data).decode(self._encoding) - else: - label = data.decode(self._encoding) + if not isinstance(data, bytes): + data = bytes(data) try: - return self.enum[label] + return self._load_map[data] except KeyError: + enc = conn_encoding(self.connection) + label = data.decode(enc, "replace") # type: ignore[union-attr] raise e.DataError( f"bad memeber for enum {self.enum.__qualname__}: {label!r}" ) -class EnumBinaryLoader(EnumLoader[E]): - format = Format.BINARY +class _BaseEnumDumper(Dumper, Generic[E]): + """ + Loader for a specific Enum class + """ + + enum: Type[E] + _dump_map: Dict[E, bytes] + + def dump(self, value: E) -> Buffer: + return self._dump_map[value] class EnumDumper(Dumper): - _encoding = "utf-8" + """ + Dumper for a generic Enum class + """ def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) - - conn = self.connection - if conn: - self._encoding = pgconn_encoding(conn.pgconn) + self._encoding = conn_encoding(self.connection) def dump(self, value: E) -> Buffer: return value.name.encode(self._encoding) @@ -86,29 +93,58 @@ def register_enum( adapters = context.adapters if context else postgres.adapters info.register(context) - attribs: Dict[str, Any] = {"enum": info.enum} + load_map = _make_load_map(info, enum, context.connection if context else None) + attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map} - loader_base = EnumLoader name = f"{info.name.title()}Loader" - loader = type(name, (loader_base,), attribs) + loader = type(name, (_BaseEnumLoader,), attribs) adapters.register_loader(info.oid, loader) - loader_base = EnumBinaryLoader name = f"{info.name.title()}BinaryLoader" - loader = type(name, (loader_base,), attribs) + loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY}) adapters.register_loader(info.oid, loader) - attribs = {"oid": info.oid} + dump_map = _make_dump_map(info, enum, context.connection if context else None) + attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map} - name = f"{enum.__name__}BinaryDumper" - dumper = type(name, (EnumBinaryDumper,), attribs) + name = f"{enum.__name__}Dumper" + dumper = type(name, (_BaseEnumDumper,), attribs) adapters.register_dumper(info.enum, dumper) - name = f"{enum.__name__}Dumper" - dumper = type(name, (EnumDumper,), attribs) + name = f"{enum.__name__}BinaryDumper" + dumper = type(name, (_BaseEnumDumper,), {**attribs, "format": Format.BINARY}) adapters.register_dumper(info.enum, dumper) +def _make_load_map( + info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]" +) -> Dict[bytes, E]: + enc = conn_encoding(conn) + rv: Dict[bytes, E] = {} + for label in info.labels: + try: + member = enum[label] + except KeyError: + # tolerate a missing enum, assuming it won't be used. If it is we + # will get a DataError on fetch. + pass + else: + rv[label.encode(enc)] = member + + return rv + + +def _make_dump_map( + info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]" +) -> Dict[E, bytes]: + enc = conn_encoding(conn) + rv: Dict[E, bytes] = {} + for member in enum: + rv[member] = member.name.encode(enc) + + return rv + + def register_default_adapters(context: AdaptContext) -> None: context.adapters.register_dumper(Enum, EnumBinaryDumper) context.adapters.register_dumper(Enum, EnumDumper) -- 2.47.2