From: Daniele Varrazzo Date: Mon, 18 Apr 2022 17:15:35 +0000 (+0200) Subject: feat(enum): dump enums by keys instead of values X-Git-Tag: 3.1~137^2~12 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=278d0c382e3b71fd34084f249e3cc0853d90d954;p=thirdparty%2Fpsycopg.git feat(enum): dump enums by keys instead of values This removes the need for enums to be str-based. It also removes the asymmetry whereby automatically generated enums were *not* string based, but pure enums. --- diff --git a/docs/basic/adapt.rst b/docs/basic/adapt.rst index d5c392bda..139c5effc 100644 --- a/docs/basic/adapt.rst +++ b/docs/basic/adapt.rst @@ -389,9 +389,6 @@ Before using a enum type it is necessary to get information about it using the `~psycopg.types.enum.EnumInfo` class and to register it using `~psycopg.types.enum.register_enum()`. -If you use enum array and your enum labels contains comma, you should use -the binary format to return values. See :ref:`binary-data` for details. - .. autoclass:: psycopg.types.enum.EnumInfo `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its @@ -417,31 +414,31 @@ the binary format to return values. See :ref:`binary-data` for details. Example:: - >>> from enum import Enum + >>> from enum import Enum, auto >>> from psycopg.types.enum import EnumInfo, register_enum - >>> class UserRole(str, Enum): - ... ADMIN = "ADMIN" - ... EDITOR = "EDITOR" - ... GUEST = "GUEST" + >>> class UserRole(Enum): + ... ADMIN = auto() + ... EDITOR = auto() + ... GUEST = auto() - >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')") + >>> conn.execute("CREATE TYPE user_role AS ENUM ('ADMIN', 'EDITOR', 'GUEST')") >>> info = EnumInfo.fetch(conn, "user_role") >>> register_enum(info, UserRole, conn) - >>> some_editor = info.python_type("EDITOR") + >>> some_editor = info.python_type.EDITOR >>> some_editor - + >>> conn.execute( ... "SELECT pg_typeof(%(editor)s), %(editor)s", - ... { "editor": some_editor } + ... {"editor": some_editor} ... ).fetchone() - ('user_role', ) + ('user_role', ) >>> conn.execute( - ... "SELECT (%s, %s)::user_role[]", + ... "SELECT ARRAY[%s, %s]", ... [UserRole.ADMIN, UserRole.GUEST] ... ).fetchone() - [, ] + [, ] diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index 38e15ad3a..aa99ab52f 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -2,10 +2,9 @@ Adapters for the enum type. """ from enum import Enum -from typing import Optional, TypeVar, Generic, Type, Dict, Any +from typing import Type, Any, Dict, Generic, Optional, TypeVar, cast from ..adapt import Dumper, Loader -from .string import StrBinaryDumper, StrDumper from .. import postgres from .._encodings import pgconn_encoding from .._typeinfo import EnumInfo as EnumInfo # exported here @@ -18,7 +17,6 @@ E = TypeVar("E", bound=Enum) class EnumLoader(Loader, Generic[E]): - format = Format.TEXT _encoding = "utf-8" python_type: Type[E] @@ -34,19 +32,29 @@ class EnumLoader(Loader, Generic[E]): else: label = data.decode(self._encoding) - return self.python_type(label) + return self.python_type[label] class EnumBinaryLoader(EnumLoader[E]): format = Format.BINARY -class EnumDumper(StrDumper): - pass +class EnumDumper(Dumper): + _encoding = "utf-8" + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + + conn = self.connection + if conn: + self._encoding = pgconn_encoding(conn.pgconn) + def dump(self, value: E) -> Buffer: + return value.name.encode(self._encoding) -class EnumBinaryDumper(StrBinaryDumper): - pass + +class EnumBinaryDumper(EnumDumper): + format = Format.BINARY def register_enum( @@ -61,30 +69,19 @@ def register_enum( the type will be generated and put into info.python_type. :param context: The context where to register the adapters. If `!None`, register it globally. - - .. note:: - Only string enums are supported. - - Use binary format if you use enum array and enum labels contains comma: - connection.execute(..., binary=True) """ if not info: raise TypeError("no info passed. Is the requested enum available?") - if python_type is not None: - if {type(item.value) for item in python_type} != {str}: - raise TypeError("invalid enum value type (string is the only supported)") - - info.python_type = python_type - else: - info.python_type = Enum( # type: ignore - info.name.title(), - {label: label for label in info.enum_labels}, + if python_type is None: + python_type = cast( + Type[E], + Enum(info.name.title(), {label: label for label in info.enum_labels}), ) + info.python_type = python_type adapters = context.adapters if context else postgres.adapters - info.register(context) attribs: Dict[str, Any] = {"python_type": info.python_type} diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 8a1ffc7f4..eda761ee5 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -1,4 +1,4 @@ -from enum import Enum +from enum import Enum, auto import pytest @@ -7,28 +7,39 @@ from psycopg.adapt import PyFormat from psycopg.types.enum import EnumInfo, register_enum +class PureTestEnum(Enum): + FOO = auto() + BAR = auto() + BAZ = auto() + + class StrTestEnum(str, Enum): ONE = "ONE" TWO = "TWO" THREE = "THREE" -class NonAsciiEnum(str, Enum): - XE0 = "x\xe0" - XE1 = "x\xe1" +NonAsciiEnum = Enum( + "NonAsciiEnum", {"X\xe0": "x\xe0", "X\xe1": "x\xe1", "COMMA": "foo,bar"}, type=str +) + + +class IntTestEnum(int, Enum): + ONE = 1 + TWO = 2 + THREE = 3 -enum_cases = [ - ("strtestenum", StrTestEnum, [item.value for item in StrTestEnum]), - ("nonasciienum", NonAsciiEnum, [item.value for item in NonAsciiEnum]), -] +enum_cases = [PureTestEnum, StrTestEnum, NonAsciiEnum, IntTestEnum] encodings = ["utf8", "latin1"] @pytest.fixture(scope="session", params=enum_cases) def testenum(request, svcconn): - name, enum, labels = request.param + enum = request.param + name = enum.__name__.lower() + labels = list(enum.__members__.keys()) cur = svcconn.cursor() cur.execute( sql.SQL( @@ -52,6 +63,16 @@ def test_fetch_info(conn, testenum): assert info.enum_labels == labels +def test_register_makes_a_type(conn, testenum): + name, enum, labels = testenum + info = EnumInfo.fetch(conn, name) + assert info + assert info.python_type is None + register_enum(info, context=conn) + assert info.python_type is not None + assert [e.name for e in info.python_type] == [e.name for e in enum] + + @pytest.mark.parametrize("encoding", encodings) @pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) @@ -63,7 +84,7 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out): for label in labels: cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out) - assert cur.fetchone()[0] == enum(label) + assert cur.fetchone()[0] == enum[label] @pytest.mark.parametrize("fmt_in", PyFormat) @@ -78,7 +99,7 @@ def test_enum_loader_sqlascii(conn, testenum, fmt_in, fmt_out): for label in labels: cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out) - assert cur.fetchone()[0] == enum(label) + assert cur.fetchone()[0] == enum[label] @pytest.mark.parametrize("encoding", encodings) @@ -190,7 +211,7 @@ async def test_enum_async(aconn, testenum, encoding, fmt_in, fmt_out): async with aconn.cursor(binary=fmt_out) as cur: for label in labels: cur = await cur.execute(f"select %{fmt_in}::{name}", [label]) - assert (await cur.fetchone())[0] == enum(label) + assert (await cur.fetchone())[0] == enum[label] cur = await cur.execute(f"select %{fmt_in}", [list(enum)]) assert (await cur.fetchone())[0] == list(enum)