]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat(enum): dump enums by keys instead of values
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 18 Apr 2022 17:15:35 +0000 (19:15 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 02:30:26 +0000 (04:30 +0200)
This removes the need for enums to be str-based. It also removes the
asymmetry whereby automatically generated enums were *not* string based,
but pure enums.

docs/basic/adapt.rst
psycopg/psycopg/types/enum.py
tests/types/test_enum.py

index d5c392bda91516a94ec4c1b316d6130c8cb93c68..139c5effc5eedaa6c1ab8a6ad165bdd9557a236f 100644 (file)
@@ -389,9 +389,6 @@ Before using a enum type it is necessary to get information about it
 using the `~psycopg.types.enum.EnumInfo` class and to register it
 using `~psycopg.types.enum.register_enum()`.
 
-If you use enum array and your enum labels contains comma, you should use
-the binary format to return values. See :ref:`binary-data` for details.
-
 .. autoclass:: psycopg.types.enum.EnumInfo
 
    `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its
@@ -417,31 +414,31 @@ the binary format to return values. See :ref:`binary-data` for details.
 
 Example::
 
-    >>> from enum import Enum
+    >>> from enum import Enum, auto
     >>> from psycopg.types.enum import EnumInfo, register_enum
 
-    >>> class UserRole(str, Enum):
-    ...     ADMIN = "ADMIN"
-    ...     EDITOR = "EDITOR"
-    ...     GUEST = "GUEST"
+    >>> class UserRole(Enum):
+    ...     ADMIN = auto()
+    ...     EDITOR = auto()
+    ...     GUEST = auto()
 
-    >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')")
+    >>> conn.execute("CREATE TYPE user_role AS ENUM ('ADMIN', 'EDITOR', 'GUEST')")
 
     >>> info = EnumInfo.fetch(conn, "user_role")
     >>> register_enum(info, UserRole, conn)
 
-    >>> some_editor = info.python_type("EDITOR")
+    >>> some_editor = info.python_type.EDITOR
     >>> some_editor
-    <UserRole.EDITOR: 'EDITOR'>
+    <UserRole.EDITOR: 2>
 
     >>> conn.execute(
     ...     "SELECT pg_typeof(%(editor)s), %(editor)s",
-    ...     { "editor": some_editor }
+    ...     {"editor": some_editor}
     ... ).fetchone()
-    ('user_role', <UserRole.EDITOR: 'EDITOR'>)
+    ('user_role', <UserRole.EDITOR: 2>)
 
     >>> conn.execute(
-    ...     "SELECT (%s, %s)::user_role[]",
+    ...     "SELECT ARRAY[%s, %s]",
     ...     [UserRole.ADMIN, UserRole.GUEST]
     ... ).fetchone()
-    [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>]
+    [<UserRole.ADMIN: 1>, <UserRole.GUEST: 3>]
index 38e15ad3a4c3271e5baecbae3166294076fff36c..aa99ab52f97d44233d68397fc9452ad7d991e661 100644 (file)
@@ -2,10 +2,9 @@
 Adapters for the enum type.
 """
 from enum import Enum
-from typing import Optional, TypeVar, Generic, Type, Dict, Any
+from typing import Type, Any, Dict, Generic, Optional, TypeVar, cast
 
 from ..adapt import Dumper, Loader
-from .string import StrBinaryDumper, StrDumper
 from .. import postgres
 from .._encodings import pgconn_encoding
 from .._typeinfo import EnumInfo as EnumInfo  # exported here
@@ -18,7 +17,6 @@ E = TypeVar("E", bound=Enum)
 
 
 class EnumLoader(Loader, Generic[E]):
-    format = Format.TEXT
     _encoding = "utf-8"
     python_type: Type[E]
 
@@ -34,19 +32,29 @@ class EnumLoader(Loader, Generic[E]):
         else:
             label = data.decode(self._encoding)
 
-        return self.python_type(label)
+        return self.python_type[label]
 
 
 class EnumBinaryLoader(EnumLoader[E]):
     format = Format.BINARY
 
 
-class EnumDumper(StrDumper):
-    pass
+class EnumDumper(Dumper):
+    _encoding = "utf-8"
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+
+        conn = self.connection
+        if conn:
+            self._encoding = pgconn_encoding(conn.pgconn)
 
+    def dump(self, value: E) -> Buffer:
+        return value.name.encode(self._encoding)
 
-class EnumBinaryDumper(StrBinaryDumper):
-    pass
+
+class EnumBinaryDumper(EnumDumper):
+    format = Format.BINARY
 
 
 def register_enum(
@@ -61,30 +69,19 @@ def register_enum(
         the type will be generated and put into info.python_type.
     :param context: The context where to register the adapters. If `!None`,
         register it globally.
-
-    .. note::
-        Only string enums are supported.
-
-        Use binary format if you use enum array and enum labels contains comma:
-            connection.execute(..., binary=True)
     """
 
     if not info:
         raise TypeError("no info passed. Is the requested enum available?")
 
-    if python_type is not None:
-        if {type(item.value) for item in python_type} != {str}:
-            raise TypeError("invalid enum value type (string is the only supported)")
-
-        info.python_type = python_type
-    else:
-        info.python_type = Enum(  # type: ignore
-            info.name.title(),
-            {label: label for label in info.enum_labels},
+    if python_type is None:
+        python_type = cast(
+            Type[E],
+            Enum(info.name.title(), {label: label for label in info.enum_labels}),
         )
 
+    info.python_type = python_type
     adapters = context.adapters if context else postgres.adapters
-
     info.register(context)
 
     attribs: Dict[str, Any] = {"python_type": info.python_type}
index 8a1ffc7f4af92a3bdf0577aded19c1ff58d43db6..eda761ee58d2620dbe651cad2542888cd55e6ba0 100644 (file)
@@ -1,4 +1,4 @@
-from enum import Enum
+from enum import Enum, auto
 
 import pytest
 
@@ -7,28 +7,39 @@ from psycopg.adapt import PyFormat
 from psycopg.types.enum import EnumInfo, register_enum
 
 
+class PureTestEnum(Enum):
+    FOO = auto()
+    BAR = auto()
+    BAZ = auto()
+
+
 class StrTestEnum(str, Enum):
     ONE = "ONE"
     TWO = "TWO"
     THREE = "THREE"
 
 
-class NonAsciiEnum(str, Enum):
-    XE0 = "x\xe0"
-    XE1 = "x\xe1"
+NonAsciiEnum = Enum(
+    "NonAsciiEnum", {"X\xe0": "x\xe0", "X\xe1": "x\xe1", "COMMA": "foo,bar"}, type=str
+)
+
+
+class IntTestEnum(int, Enum):
+    ONE = 1
+    TWO = 2
+    THREE = 3
 
 
-enum_cases = [
-    ("strtestenum", StrTestEnum, [item.value for item in StrTestEnum]),
-    ("nonasciienum", NonAsciiEnum, [item.value for item in NonAsciiEnum]),
-]
+enum_cases = [PureTestEnum, StrTestEnum, NonAsciiEnum, IntTestEnum]
 
 encodings = ["utf8", "latin1"]
 
 
 @pytest.fixture(scope="session", params=enum_cases)
 def testenum(request, svcconn):
-    name, enum, labels = request.param
+    enum = request.param
+    name = enum.__name__.lower()
+    labels = list(enum.__members__.keys())
     cur = svcconn.cursor()
     cur.execute(
         sql.SQL(
@@ -52,6 +63,16 @@ def test_fetch_info(conn, testenum):
     assert info.enum_labels == labels
 
 
+def test_register_makes_a_type(conn, testenum):
+    name, enum, labels = testenum
+    info = EnumInfo.fetch(conn, name)
+    assert info
+    assert info.python_type is None
+    register_enum(info, context=conn)
+    assert info.python_type is not None
+    assert [e.name for e in info.python_type] == [e.name for e in enum]
+
+
 @pytest.mark.parametrize("encoding", encodings)
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
@@ -63,7 +84,7 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
 
     for label in labels:
         cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
-        assert cur.fetchone()[0] == enum(label)
+        assert cur.fetchone()[0] == enum[label]
 
 
 @pytest.mark.parametrize("fmt_in", PyFormat)
@@ -78,7 +99,7 @@ def test_enum_loader_sqlascii(conn, testenum, fmt_in, fmt_out):
 
     for label in labels:
         cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
-        assert cur.fetchone()[0] == enum(label)
+        assert cur.fetchone()[0] == enum[label]
 
 
 @pytest.mark.parametrize("encoding", encodings)
@@ -190,7 +211,7 @@ async def test_enum_async(aconn, testenum, encoding, fmt_in, fmt_out):
     async with aconn.cursor(binary=fmt_out) as cur:
         for label in labels:
             cur = await cur.execute(f"select %{fmt_in}::{name}", [label])
-            assert (await cur.fetchone())[0] == enum(label)
+            assert (await cur.fetchone())[0] == enum[label]
 
         cur = await cur.execute(f"select %{fmt_in}", [list(enum)])
         assert (await cur.fetchone())[0] == list(enum)