]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(enum): add mapping override to `register_enum()`
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 00:48:41 +0000 (02:48 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 03:03:24 +0000 (05:03 +0200)
docs/basic/adapt.rst
psycopg/psycopg/types/enum.py
tests/types/test_enum.py

index b06541380041baac9b889499044c45ca80e58076..c4cf87875aef3047d1313b64dc8fe5a2aa2a6494 100644 (file)
@@ -466,3 +466,52 @@ Example::
     ...     [UserRole.ADMIN, UserRole.GUEST]
     ... ).fetchone()
     [<UserRole.ADMIN: 1>, <UserRole.GUEST: 3>]
+
+If the Python and the PostgreSQL enum don't match 1:1 (for instance if members
+have a different name, or if more than one Python enum should map to the same
+PostgreSQL enum, or vice versa), you can specify the exceptions using the
+`!mapping` parameter.
+
+`!mapping` should be a dictionary with Python enum members as keys and the
+matching PostgreSQL enum labels as values, or a list of `(member, label)`
+pairs with the same meaning (useful when some members are repeated). Order
+matters: if an element on either side is specified more than once, the last
+pair in the sequence will take precedence::
+
+    # Legacy roles, defined in medieval times.
+    >>> conn.execute(
+    ...     "CREATE TYPE abbey_role AS ENUM ('ABBOT', 'SCRIBE', 'MONK', 'GUEST')")
+
+    >>> info = EnumInfo.fetch(conn, "abbey_role")
+    >>> register_enum(info, conn, UserRole, mapping=[
+    ...     (UserRole.ADMIN, "ABBOT"),
+    ...     (UserRole.EDITOR, "SCRIBE"),
+    ...     (UserRole.EDITOR, "MONK")])
+
+    >>> conn.execute("SELECT '{ABBOT,SCRIBE,MONK,GUEST}'::abbey_role[]").fetchone()[0]
+    [<UserRole.ADMIN: 1>,
+     <UserRole.EDITOR: 2>,
+     <UserRole.EDITOR: 2>,
+     <UserRole.GUEST: 3>]
+
+    >>> conn.execute("SELECT %s::text[]", [list(UserRole)]).fetchone()[0]
+    ['ABBOT', 'MONK', 'GUEST']
+
+A particularly useful case is when the PostgreSQL labels match the *values* of
+a `!str`\-based Enum. In this case it is possible to use something like ``{m:
+m.value for m in enum}`` as mapping::
+
+    >>> class LowercaseRole(str, Enum):
+    ...     ADMIN = "admin"
+    ...     EDITOR = "editor"
+    ...     GUEST = "guest"
+
+    >>> conn.execute(
+    ...     "CREATE TYPE lowercase_role AS ENUM ('admin', 'editor', 'guest')")
+
+    >>> info = EnumInfo.fetch(conn, "lowercase_role")
+    >>> register_enum(
+    ...     info, conn, LowercaseRole, mapping={m: m.value for m in LowercaseRole})
+
+    >>> conn.execute("SELECT 'editor'::lowercase_role").fetchone()[0]
+    <LowercaseRole.EDITOR: 'editor'>
index 742fd60f853d9f8d6c4732af081f8f25a6a7141d..89b1870e607ab86a7b5d3feb7aebaa8398cb7127 100644 (file)
@@ -2,21 +2,24 @@
 Adapters for the enum type.
 """
 from enum import Enum
-from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast, TYPE_CHECKING
+from typing import Any, Dict, Generic, Optional, Mapping, Sequence
+from typing import Tuple, Type, TypeVar, Union, cast
 
 from .. import postgres
 from .. import errors as e
 from ..pq import Format
 from ..abc import AdaptContext
 from ..adapt import Buffer, Dumper, Loader
+from .._compat import TypeAlias
 from .._encodings import conn_encoding
 from .._typeinfo import EnumInfo as EnumInfo  # exported here
 
-if TYPE_CHECKING:
-    from ..connection import BaseConnection
-
 E = TypeVar("E", bound=Enum)
 
+EnumDumpMap: TypeAlias = Dict[E, bytes]
+EnumLoadMap: TypeAlias = Dict[bytes, E]
+EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
+
 
 class _BaseEnumLoader(Loader, Generic[E]):
     """
@@ -24,7 +27,7 @@ class _BaseEnumLoader(Loader, Generic[E]):
     """
 
     enum: Type[E]
-    _load_map: Dict[bytes, E]
+    _load_map: EnumLoadMap[E]
 
     def load(self, data: Buffer) -> E:
         if not isinstance(data, bytes):
@@ -46,7 +49,7 @@ class _BaseEnumDumper(Dumper, Generic[E]):
     """
 
     enum: Type[E]
-    _dump_map: Dict[E, bytes]
+    _dump_map: EnumDumpMap[E]
 
     def dump(self, value: E) -> Buffer:
         return self._dump_map[value]
@@ -73,6 +76,8 @@ def register_enum(
     info: EnumInfo,
     context: Optional[AdaptContext] = None,
     enum: Optional[Type[E]] = None,
+    *,
+    mapping: EnumMapping[E] = None,
 ) -> None:
     """Register the adapters to load and dump a enum type.
 
@@ -81,6 +86,8 @@ def register_enum(
         register it globally.
     :param enum: Python enum type matching to the PostgreSQL one. If `!None`,
         a new enum will be generated and exposed as `EnumInfo.enum`.
+    :param mapping: Override the mapping between `!enum` members and `!info`
+        labels.
     """
 
     if not info:
@@ -93,7 +100,7 @@ def register_enum(
     adapters = context.adapters if context else postgres.adapters
     info.register(context)
 
-    load_map = _make_load_map(info, enum, context.connection if context else None)
+    load_map = _make_load_map(info, enum, mapping, context)
     attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}
 
     name = f"{info.name.title()}Loader"
@@ -104,7 +111,7 @@ def register_enum(
     loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY})
     adapters.register_loader(info.oid, loader)
 
-    dump_map = _make_dump_map(info, enum, context.connection if context else None)
+    dump_map = _make_dump_map(info, enum, mapping, context)
     attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map}
 
     name = f"{enum.__name__}Dumper"
@@ -117,10 +124,13 @@ def register_enum(
 
 
 def _make_load_map(
-    info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]"
-) -> Dict[bytes, E]:
-    enc = conn_encoding(conn)
-    rv: Dict[bytes, E] = {}
+    info: EnumInfo,
+    enum: Type[E],
+    mapping: EnumMapping[E],
+    context: Optional[AdaptContext],
+) -> EnumLoadMap[E]:
+    enc = conn_encoding(context.connection if context else None)
+    rv: EnumLoadMap[E] = {}
     for label in info.labels:
         try:
             member = enum[label]
@@ -131,17 +141,34 @@ def _make_load_map(
         else:
             rv[label.encode(enc)] = member
 
+    if mapping:
+        if isinstance(mapping, Mapping):
+            mapping = list(mapping.items())
+
+        for member, label in mapping:
+            rv[label.encode(enc)] = member
+
     return rv
 
 
 def _make_dump_map(
-    info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]"
-) -> Dict[E, bytes]:
-    enc = conn_encoding(conn)
-    rv: Dict[E, bytes] = {}
+    info: EnumInfo,
+    enum: Type[E],
+    mapping: EnumMapping[E],
+    context: Optional[AdaptContext],
+) -> EnumDumpMap[E]:
+    enc = conn_encoding(context.connection if context else None)
+    rv: EnumDumpMap[E] = {}
     for member in enum:
         rv[member] = member.name.encode(enc)
 
+    if mapping:
+        if isinstance(mapping, Mapping):
+            mapping = list(mapping.items())
+
+        for member, label in mapping:
+            rv[member] = label.encode(enc)
+
     return rv
 
 
index 6ee3c6735324dc1ac7a472426087df56ae29908b..2c4abdacd12794c8f891db1ce0880382e588ab50 100644 (file)
@@ -266,3 +266,84 @@ def test_enum_error(conn):
         conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone()
     with pytest.raises(e.DataError):
         conn.execute("select 'BAR'::puretestenum").fetchone()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+    "mapping",
+    [
+        {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"},
+        [
+            (StrTestEnum.ONE, "FOO"),
+            (StrTestEnum.TWO, "BAR"),
+            (StrTestEnum.THREE, "BAZ"),
+        ],
+    ],
+)
+def test_remap(conn, fmt_in, fmt_out, mapping):
+    info = EnumInfo.fetch(conn, "puretestenum")
+    register_enum(info, conn, StrTestEnum, mapping=mapping)
+
+    for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]:
+        cur = conn.execute(f"select %{fmt_in}::text", [StrTestEnum[member]])
+        assert cur.fetchone()[0] == label
+        cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out)
+        assert cur.fetchone()[0] is StrTestEnum[member]
+
+
+def test_remap_rename(conn):
+    enum = Enum("RenamedEnum", "FOO BAR QUX")
+    info = EnumInfo.fetch(conn, "puretestenum")
+    register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"})
+
+    for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]:
+        cur = conn.execute("select %s::text", [enum[member]])
+        assert cur.fetchone()[0] == label
+        cur = conn.execute(f"select '{label}'::puretestenum")
+        assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_python(conn):
+    enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX")
+    info = EnumInfo.fetch(conn, "puretestenum")
+    mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]}
+    register_enum(info, conn, enum, mapping=mapping)
+
+    for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]:
+        cur = conn.execute("select %s::text", [enum[member]])
+        assert cur.fetchone()[0] == label
+
+    for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]:
+        cur = conn.execute(f"select '{label}'::puretestenum")
+        assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_postgres(conn):
+    enum = Enum("SmallerEnum", "FOO")
+    info = EnumInfo.fetch(conn, "puretestenum")
+    mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")]
+    register_enum(info, conn, enum, mapping=mapping)
+
+    cur = conn.execute("select %s::text", [enum.FOO])
+    assert cur.fetchone()[0] == "BAZ"
+
+    for label in PureTestEnum.__members__:
+        cur = conn.execute(f"select '{label}'::puretestenum")
+        assert cur.fetchone()[0] is enum.FOO
+
+
+def test_remap_by_value(conn):
+    enum = Enum(  # type: ignore
+        "ByValue",
+        {m.lower(): m for m in PureTestEnum.__members__},
+    )
+    info = EnumInfo.fetch(conn, "puretestenum")
+    register_enum(info, conn, enum, mapping={m: m.value for m in enum})
+
+    for label in PureTestEnum.__members__:
+        cur = conn.execute("select %s::text", [enum[label.lower()]])
+        assert cur.fetchone()[0] == label
+
+        cur = conn.execute(f"select '{label}'::puretestenum")
+        assert cur.fetchone()[0] is enum[label.lower()]