... [UserRole.ADMIN, UserRole.GUEST]
... ).fetchone()
[<UserRole.ADMIN: 1>, <UserRole.GUEST: 3>]
+
+If the Python and the PostgreSQL enum don't match 1:1 (for instance if members
+have a different name, or if more than one Python enum should map to the same
+PostgreSQL enum, or vice versa), you can specify the exceptions using the
+`!mapping` parameter.
+
+`!mapping` should be a dictionary with Python enum members as keys and the
+matching PostgreSQL enum labels as values, or a list of `(member, label)`
+pairs with the same meaning (useful when some members are repeated). Order
+matters: if an element on either side is specified more than once, the last
+pair in the sequence will take precedence::
+
+ # Legacy roles, defined in medieval times.
+ >>> conn.execute(
+ ... "CREATE TYPE abbey_role AS ENUM ('ABBOT', 'SCRIBE', 'MONK', 'GUEST')")
+
+ >>> info = EnumInfo.fetch(conn, "abbey_role")
+ >>> register_enum(info, conn, UserRole, mapping=[
+ ... (UserRole.ADMIN, "ABBOT"),
+ ... (UserRole.EDITOR, "SCRIBE"),
+ ... (UserRole.EDITOR, "MONK")])
+
+ >>> conn.execute("SELECT '{ABBOT,SCRIBE,MONK,GUEST}'::abbey_role[]").fetchone()[0]
+ [<UserRole.ADMIN: 1>,
+ <UserRole.EDITOR: 2>,
+ <UserRole.EDITOR: 2>,
+ <UserRole.GUEST: 3>]
+
+ >>> conn.execute("SELECT %s::text[]", [list(UserRole)]).fetchone()[0]
+ ['ABBOT', 'MONK', 'GUEST']
+
+A particularly useful case is when the PostgreSQL labels match the *values* of
+a `!str`\-based Enum. In this case it is possible to use something like ``{m:
+m.value for m in enum}`` as mapping::
+
+ >>> class LowercaseRole(str, Enum):
+ ... ADMIN = "admin"
+ ... EDITOR = "editor"
+ ... GUEST = "guest"
+
+ >>> conn.execute(
+ ... "CREATE TYPE lowercase_role AS ENUM ('admin', 'editor', 'guest')")
+
+ >>> info = EnumInfo.fetch(conn, "lowercase_role")
+ >>> register_enum(
+ ... info, conn, LowercaseRole, mapping={m: m.value for m in LowercaseRole})
+
+ >>> conn.execute("SELECT 'editor'::lowercase_role").fetchone()[0]
+ <LowercaseRole.EDITOR: 'editor'>
Adapters for the enum type.
"""
from enum import Enum
-from typing import Any, Dict, Generic, Optional, Type, TypeVar, cast, TYPE_CHECKING
+from typing import Any, Dict, Generic, Optional, Mapping, Sequence
+from typing import Tuple, Type, TypeVar, Union, cast
from .. import postgres
from .. import errors as e
from ..pq import Format
from ..abc import AdaptContext
from ..adapt import Buffer, Dumper, Loader
+from .._compat import TypeAlias
from .._encodings import conn_encoding
from .._typeinfo import EnumInfo as EnumInfo # exported here
-if TYPE_CHECKING:
- from ..connection import BaseConnection
-
E = TypeVar("E", bound=Enum)
+EnumDumpMap: TypeAlias = Dict[E, bytes]
+EnumLoadMap: TypeAlias = Dict[bytes, E]
+EnumMapping: TypeAlias = Union[Mapping[E, str], Sequence[Tuple[E, str]], None]
+
class _BaseEnumLoader(Loader, Generic[E]):
"""
"""
enum: Type[E]
- _load_map: Dict[bytes, E]
+ _load_map: EnumLoadMap[E]
def load(self, data: Buffer) -> E:
if not isinstance(data, bytes):
"""
enum: Type[E]
- _dump_map: Dict[E, bytes]
+ _dump_map: EnumDumpMap[E]
def dump(self, value: E) -> Buffer:
return self._dump_map[value]
info: EnumInfo,
context: Optional[AdaptContext] = None,
enum: Optional[Type[E]] = None,
+ *,
+ mapping: EnumMapping[E] = None,
) -> None:
"""Register the adapters to load and dump a enum type.
register it globally.
:param enum: Python enum type matching to the PostgreSQL one. If `!None`,
a new enum will be generated and exposed as `EnumInfo.enum`.
+ :param mapping: Override the mapping between `!enum` members and `!info`
+ labels.
"""
if not info:
adapters = context.adapters if context else postgres.adapters
info.register(context)
- load_map = _make_load_map(info, enum, context.connection if context else None)
+ load_map = _make_load_map(info, enum, mapping, context)
attribs: Dict[str, Any] = {"enum": info.enum, "_load_map": load_map}
name = f"{info.name.title()}Loader"
loader = type(name, (_BaseEnumLoader,), {**attribs, "format": Format.BINARY})
adapters.register_loader(info.oid, loader)
- dump_map = _make_dump_map(info, enum, context.connection if context else None)
+ dump_map = _make_dump_map(info, enum, mapping, context)
attribs = {"oid": info.oid, "enum": info.enum, "_dump_map": dump_map}
name = f"{enum.__name__}Dumper"
def _make_load_map(
- info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]"
-) -> Dict[bytes, E]:
- enc = conn_encoding(conn)
- rv: Dict[bytes, E] = {}
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumLoadMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumLoadMap[E] = {}
for label in info.labels:
try:
member = enum[label]
else:
rv[label.encode(enc)] = member
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[label.encode(enc)] = member
+
return rv
def _make_dump_map(
- info: EnumInfo, enum: Type[E], conn: "Optional[BaseConnection[Any]]"
-) -> Dict[E, bytes]:
- enc = conn_encoding(conn)
- rv: Dict[E, bytes] = {}
+ info: EnumInfo,
+ enum: Type[E],
+ mapping: EnumMapping[E],
+ context: Optional[AdaptContext],
+) -> EnumDumpMap[E]:
+ enc = conn_encoding(context.connection if context else None)
+ rv: EnumDumpMap[E] = {}
for member in enum:
rv[member] = member.name.encode(enc)
+ if mapping:
+ if isinstance(mapping, Mapping):
+ mapping = list(mapping.items())
+
+ for member, label in mapping:
+ rv[member] = label.encode(enc)
+
return rv
conn.execute("select %s::text", [StrTestEnum.ONE]).fetchone()
with pytest.raises(e.DataError):
conn.execute("select 'BAR'::puretestenum").fetchone()
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize(
+ "mapping",
+ [
+ {StrTestEnum.ONE: "FOO", StrTestEnum.TWO: "BAR", StrTestEnum.THREE: "BAZ"},
+ [
+ (StrTestEnum.ONE, "FOO"),
+ (StrTestEnum.TWO, "BAR"),
+ (StrTestEnum.THREE, "BAZ"),
+ ],
+ ],
+)
+def test_remap(conn, fmt_in, fmt_out, mapping):
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, StrTestEnum, mapping=mapping)
+
+ for member, label in [("ONE", "FOO"), ("TWO", "BAR"), ("THREE", "BAZ")]:
+ cur = conn.execute(f"select %{fmt_in}::text", [StrTestEnum[member]])
+ assert cur.fetchone()[0] == label
+ cur = conn.execute(f"select '{label}'::puretestenum", binary=fmt_out)
+ assert cur.fetchone()[0] is StrTestEnum[member]
+
+
+def test_remap_rename(conn):
+ enum = Enum("RenamedEnum", "FOO BAR QUX")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, enum, mapping={enum.QUX: "BAZ"})
+
+ for member, label in [("FOO", "FOO"), ("BAR", "BAR"), ("QUX", "BAZ")]:
+ cur = conn.execute("select %s::text", [enum[member]])
+ assert cur.fetchone()[0] == label
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_python(conn):
+ enum = Enum("LargerEnum", "FOO BAR BAZ QUX QUUX QUUUX")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ mapping = {enum[m]: "BAZ" for m in ["QUX", "QUUX", "QUUUX"]}
+ register_enum(info, conn, enum, mapping=mapping)
+
+ for member, label in [("FOO", "FOO"), ("BAZ", "BAZ"), ("QUUUX", "BAZ")]:
+ cur = conn.execute("select %s::text", [enum[member]])
+ assert cur.fetchone()[0] == label
+
+ for member, label in [("FOO", "FOO"), ("QUUUX", "BAZ")]:
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[member]
+
+
+def test_remap_more_postgres(conn):
+ enum = Enum("SmallerEnum", "FOO")
+ info = EnumInfo.fetch(conn, "puretestenum")
+ mapping = [(enum.FOO, "BAR"), (enum.FOO, "BAZ")]
+ register_enum(info, conn, enum, mapping=mapping)
+
+ cur = conn.execute("select %s::text", [enum.FOO])
+ assert cur.fetchone()[0] == "BAZ"
+
+ for label in PureTestEnum.__members__:
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum.FOO
+
+
+def test_remap_by_value(conn):
+ enum = Enum( # type: ignore
+ "ByValue",
+ {m.lower(): m for m in PureTestEnum.__members__},
+ )
+ info = EnumInfo.fetch(conn, "puretestenum")
+ register_enum(info, conn, enum, mapping={m: m.value for m in enum})
+
+ for label in PureTestEnum.__members__:
+ cur = conn.execute("select %s::text", [enum[label.lower()]])
+ assert cur.fetchone()[0] == label
+
+ cur = conn.execute(f"select '{label}'::puretestenum")
+ assert cur.fetchone()[0] is enum[label.lower()]