From: Daniele Varrazzo Date: Fri, 22 Apr 2022 00:48:41 +0000 (+0200) Subject: feat(enum): add mapping override to `register_enum()` X-Git-Tag: 3.1~137^2~2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3567311c1ae5f1f9845a8b29645bbc59e44c7324;p=thirdparty%2Fpsycopg.git feat(enum): add mapping override to `register_enum()` --- diff --git a/docs/basic/adapt.rst b/docs/basic/adapt.rst index b06541380..c4cf87875 100644 --- a/docs/basic/adapt.rst +++ b/docs/basic/adapt.rst @@ -466,3 +466,52 @@ Example:: ... [UserRole.ADMIN, UserRole.GUEST] ... ).fetchone() [, ] + +If the Python and the PostgreSQL enum don't match 1:1 (for instance if members +have a different name, or if more than one Python enum should map to the same +PostgreSQL enum, or vice versa), you can specify the exceptions using the +`!mapping` parameter. + +`!mapping` should be a dictionary with Python enum members as keys and the +matching PostgreSQL enum labels as values, or a list of `(member, label)` +pairs with the same meaning (useful when some members are repeated). Order +matters: if an element on either side is specified more than once, the last +pair in the sequence will take precedence:: + + # Legacy roles, defined in medieval times. + >>> conn.execute( + ... "CREATE TYPE abbey_role AS ENUM ('ABBOT', 'SCRIBE', 'MONK', 'GUEST')") + + >>> info = EnumInfo.fetch(conn, "abbey_role") + >>> register_enum(info, conn, UserRole, mapping=[ + ... (UserRole.ADMIN, "ABBOT"), + ... (UserRole.EDITOR, "SCRIBE"), + ... (UserRole.EDITOR, "MONK")]) + + >>> conn.execute("SELECT '{ABBOT,SCRIBE,MONK,GUEST}'::abbey_role[]").fetchone()[0] + [, + , + , + ] + + >>> conn.execute("SELECT %s::text[]", [list(UserRole)]).fetchone()[0] + ['ABBOT', 'MONK', 'GUEST'] + +A particularly useful case is when the PostgreSQL labels match the *values* of +a `!str`\-based Enum. In this case it is possible to use something like ``{m: +m.value for m in enum}`` as mapping:: + + >>> class LowercaseRole(str, Enum): + ... ADMIN = "admin" + ... EDITOR = "editor" + ... GUEST = "guest" + + >>> conn.execute( + ... "CREATE TYPE lowercase_role AS ENUM ('admin', 'editor', 'guest')") + + >>> info = EnumInfo.fetch(conn, "lowercase_role") + >>> register_enum( + ... info, conn, LowercaseRole, mapping={m: m.value for m in LowercaseRole}) + + >>> conn.execute("SELECT 'editor'::lowercase_role").fetchone()[0] + diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index 742fd60f8..89b1870e6 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -2,21 +2,24 @@ Adapters for the enum type. """ from enum import Enum -from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast, TYPE_CHECKING +from typing import Any, Dict, Generic, Optional, Mapping, Sequence +from typing import Tuple, Type, TypeVar, Union, cast from .. import postgres from .. import errors as e from ..pq import Format from ..abc import AdaptContext from ..adapt import Buffer, Dumper, Loader +from .._compat import TypeAlias from .._encodings import conn_encoding from .._typeinfo import EnumInfo as EnumInfo # exported here -if TYPE_CHECKING: - from ..connection import BaseConnection - E = TypeVar("E", bound=Enum) +EnumDumpMap: TypeAlias = Dict[E, bytes] +EnumLoadMap: TypeAlias = Dict[bytes, E] +EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None] + class _BaseEnumLoader(Loader, Generic[E]): """ @@ -24,7 +27,7 @@ class _BaseEnumLoader(Loader, Generic[E]): """ enum: Type[E] - _load_map: Dict[bytes, E] + _load_map: EnumLoadMap[E] def load(self, data: Buffer) -> E: if not isinstance(data, bytes): @@ -46,7 +49,7 @@ class _BaseEnumDumper(Dumper, Generic[E]): """ enum: Type[E] - _dump_map: Dict[E, bytes] + _dump_map: EnumDumpMap[E] def dump(self, value: E) -> Buffer: return self._dump_map[value] @@ -73,6 +76,8 @@ def register_enum( info: EnumInfo, context: Optional[AdaptContext] = None, enum: Optional[Type[E]] = None, + *, + mapping: EnumMapping[E] = None, ) -> None: """Register the adapters to load and dump a enum type. @@ -81,6 +86,8 @@ def register_enum( register it globally. :param enum: Python enum type matching to the PostgreSQL one. If `!None`, a new enum will be generated and exposed as `EnumInfo.enum`. + :param mapping: Override the mapping between `!enum` members and `!info` + labels. """ if not info: @@ -93,7 +100,7 @@ def register_enum( adapters = context.adapters if context else postgres.adapters info.register(context) - load_map = _make_load_map(info, enum, context.connection if context else None) + load_map = _make_load_map(info, enum, mapping, context) attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map} name = f"{info.name.title()}Loader" @@ -104,7 +111,7 @@ def register_enum( loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY}) adapters.register_loader(info.oid, loader) - dump_map = _make_dump_map(info, enum, context.connection if context else None) + dump_map = _make_dump_map(info, enum, mapping, context) attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map} name = f"{enum.__name__}Dumper" @@ -117,10 +124,13 @@ def register_enum( def _make_load_map( - info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]" -) -> Dict[bytes, E]: - enc = conn_encoding(conn) - rv: Dict[bytes, E] = {} + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumLoadMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumLoadMap[E] = {} for label in info.labels: try: member = enum[label] @@ -131,17 +141,34 @@ def _make_load_map( else: rv[label.encode(enc)] = member + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[label.encode(enc)] = member + return rv def _make_dump_map( - info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]" -) -> Dict[E, bytes]: - enc = conn_encoding(conn) - rv: Dict[E, bytes] = {} + info: EnumInfo, + enum: Type[E], + mapping: EnumMapping[E], + context: Optional[AdaptContext], +) -> EnumDumpMap[E]: + enc = conn_encoding(context.connection if context else None) + rv: EnumDumpMap[E] = {} for member in enum: rv[member] = member.name.encode(enc) + if mapping: + if isinstance(mapping, Mapping): + mapping = list(mapping.items()) + + for member, label in mapping: + rv[member] = label.encode(enc) + return rv diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 6ee3c6735..2c4abdacd 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -266,3 +266,84 @@ def test_enum_error(conn): conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone() with pytest.raises(e.DataError): conn.execute("select 'BAR'::puretestenum").fetchone() + + +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize( + "mapping", + [ + {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"}, + [ + (StrTestEnum.ONE, "FOO"), + (StrTestEnum.TWO, "BAR"), + (StrTestEnum.THREE, "BAZ"), + ], + ], +) +def test_remap(conn, fmt_in, fmt_out, mapping): + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, StrTestEnum, mapping=mapping) + + for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]: + cur = conn.execute(f"select %{fmt_in}::text", [StrTestEnum[member]]) + assert cur.fetchone()[0] == label + cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out) + assert cur.fetchone()[0] is StrTestEnum[member] + + +def test_remap_rename(conn): + enum = Enum("RenamedEnum", "FOO BAR QUX") + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"}) + + for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]: + cur = conn.execute("select %s::text", [enum[member]]) + assert cur.fetchone()[0] == label + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[member] + + +def test_remap_more_python(conn): + enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX") + info = EnumInfo.fetch(conn, "puretestenum") + mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]} + register_enum(info, conn, enum, mapping=mapping) + + for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]: + cur = conn.execute("select %s::text", [enum[member]]) + assert cur.fetchone()[0] == label + + for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]: + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[member] + + +def test_remap_more_postgres(conn): + enum = Enum("SmallerEnum", "FOO") + info = EnumInfo.fetch(conn, "puretestenum") + mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")] + register_enum(info, conn, enum, mapping=mapping) + + cur = conn.execute("select %s::text", [enum.FOO]) + assert cur.fetchone()[0] == "BAZ" + + for label in PureTestEnum.__members__: + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum.FOO + + +def test_remap_by_value(conn): + enum = Enum( # type: ignore + "ByValue", + {m.lower(): m for m in PureTestEnum.__members__}, + ) + info = EnumInfo.fetch(conn, "puretestenum") + register_enum(info, conn, enum, mapping={m: m.value for m in enum}) + + for label in PureTestEnum.__members__: + cur = conn.execute("select %s::text", [enum[label.lower()]]) + assert cur.fetchone()[0] == label + + cur = conn.execute(f"select '{label}'::puretestenum") + assert cur.fetchone()[0] is enum[label.lower()]