]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(enum): code review fixes
authorVladimir Osokin <ertaquo@gmail.com>
Thu, 14 Apr 2022 10:37:56 +0000 (15:37 +0500)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 14:05:46 +0000 (16:05 +0200)
* moved documentation to `docs/basic/adapt.rst`;
* use encoding in `EnumLoader`;
* oid-based dumpers;
* rewritten tests;
* minor fixes.

docs/basic/adapt.rst
docs/basic/pgtypes.rst
docs/news.rst
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/postgres.py
psycopg/psycopg/types/enum.py
tests/types/test_enum.py

index 47f92cc7aab543dd38fdb09430dbbd23b845ff84..d5c392bda91516a94ec4c1b316d6130c8cb93c68 100644 (file)
@@ -371,3 +371,77 @@ address types`__:
 
     >>> conn.execute("select '::ffff:1.2.3.0/120'::cidr").fetchone()[0]
     IPv6Network('::ffff:102:300/120')
+
+.. _adapt-enum:
+
+Enum adaptation
+---------------
+
+.. versionadded:: 3.1
+
+Psycopg can adapt PostgreSQL enum types (created with the
+|CREATE TYPE AS ENUM|_ command)
+
+.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)`
+.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html
+
+Before using a enum type it is necessary to get information about it
+using the `~psycopg.types.enum.EnumInfo` class and to register it
+using `~psycopg.types.enum.register_enum()`.
+
+If you use enum array and your enum labels contains comma, you should use
+the binary format to return values. See :ref:`binary-data` for details.
+
+.. autoclass:: psycopg.types.enum.EnumInfo
+
+   `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+   documentation for the generic usage, especially the
+   `~psycopg.types.TypeInfo.fetch()` method.
+
+   .. attribute:: enum_labels
+
+       Contains labels available in the PostgreSQL enum type.
+
+   .. attribute:: python_type
+
+       After `register_enum()` is called, it will contain the python type
+       mapping to the registered enum.
+
+.. autofunction:: psycopg.types.enum.register_enum
+
+   After registering, fetching data of the registered enum will cast
+   PostgreSQL enum labels into corresponding Python enum labels.
+
+   If no ``python_type`` is specified, a `Enum` is created based on
+   PostgreSQL enum labels.
+
+Example::
+
+    >>> from enum import Enum
+    >>> from psycopg.types.enum import EnumInfo, register_enum
+
+    >>> class UserRole(str, Enum):
+    ...     ADMIN = "ADMIN"
+    ...     EDITOR = "EDITOR"
+    ...     GUEST = "GUEST"
+
+    >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')")
+
+    >>> info = EnumInfo.fetch(conn, "user_role")
+    >>> register_enum(info, UserRole, conn)
+
+    >>> some_editor = info.python_type("EDITOR")
+    >>> some_editor
+    <UserRole.EDITOR: 'EDITOR'>
+
+    >>> conn.execute(
+    ...     "SELECT pg_typeof(%(editor)s), %(editor)s",
+    ...     { "editor": some_editor }
+    ... ).fetchone()
+    ('user_role', <UserRole.EDITOR: 'EDITOR'>)
+
+    >>> conn.execute(
+    ...     "SELECT (%s, %s)::user_role[]",
+    ...     [UserRole.ADMIN, UserRole.GUEST]
+    ... ).fetchone()
+    [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>]
index 6a8d450a3ba398eb468cbd7f9e36170c58f80633..471187707ace5ddd5b1230da88d223166408fa7b 100644 (file)
@@ -366,72 +366,3 @@ connection or cursor), other connections and cursors will be unaffected::
     ... """).fetchone()[0]
     '0101000020E61000009279E40F061E48C0F2B0506B9A1F3440'
 
-.. _adapt-enum:
-
-Enum types casting
-------------------
-
-Psycopg can adapt PostgreSQL enum types (created with the
-|CREATE TYPE AS ENUM|_ command)
-
-.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)`
-.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html
-
-Before using a enum type it is necessary to get information about it
-using the `~psycopg.types.enum.EnumInfo` class and to register it
-using `~psycopg.types.enum.register_enum()`.
-
-.. autoclass:: psycopg.types.enum.EnumInfo
-
-   `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its
-   documentation for the generic usage, especially the
-   `~psycopg.types.TypeInfo.fetch()` method.
-
-   .. attribute:: enum_labels
-
-       Contains labels available in the PostgreSQL enum type.
-
-   .. attribute:: python_type
-
-       After `register_composite()` is called, it will contain the python type
-       mapping to the registered enum.
-
-.. autofunction:: psycopg.types.enum.register_enum
-
-   After registering, fetching data of the registered enum will cast
-   PostgreSQL enum labels into corresponding Python enum labels.
-
-   If no ``python_type`` is specified, a `Enum` is created based on
-   PostgreSQL enum labels.
-
-Example::
-
-    >>> from enum import Enum
-    >>> from psycopg.types.enum import EnumInfo, register_enum
-
-    >>> class UserRole(str, Enum):
-    ...     ADMIN = "ADMIN"
-    ...     EDITOR = "EDITOR"
-    ...     GUEST = "GUEST"
-
-    >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')")
-
-    >>> info = EnumInfo.fetch(conn, "user_role")
-    >>> register_enum(info, UserRole, conn)
-
-    >>> some_editor = info.python_type("EDITOR")
-    >>> some_editor
-    <UserRole.EDITOR: 'EDITOR'>
-
-    >>> conn.execute(
-    ...     "SELECT pg_typeof(%(editor)s), %(editor)s",
-    ...     { "editor": some_editor }
-    ... ).fetchone()
-    ('user_role', <UserRole.EDITOR: 'EDITOR'>)
-
-    >>> conn.execute(
-    ...     "SELECT (%s, %s)::user_role[]",
-    ...     [UserRole.ADMIN, UserRole.GUEST]
-    ... ).fetchone()
-    [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>]
-
index d7a2cf8cc3613ef1ed6e616404b264e7a54891a7..fe9daaccb07b190ded76e9299b3a2c64c7e8b596 100644 (file)
@@ -24,6 +24,7 @@ Psycopg 3.1 (unreleased)
 - Allow `bytearray`/`memoryview` data too as `Copy.write()` input
   (:ticket:`#254`).
 - Drop support for Python 3.6.
+- Add `enum` support (:ticket:`#274`).
 
 
 Psycopg 3.0.12 (unreleased)
index 87bfbbe0d29e1b45d166e010444082bc99e40842..dfe770edc5c727699f46bbd75595509e9a37eac2 100644 (file)
@@ -6,7 +6,7 @@ information to the adapters if needed.
 """
 
 # Copyright (C) 2020 The Psycopg Team
-
+from enum import Enum
 from typing import Any, Dict, Iterator, Optional, overload
 from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING
 
@@ -299,8 +299,8 @@ class EnumInfo(TypeInfo):
     ):
         super().__init__(name, oid, array_oid)
         self.enum_labels = enum_labels
-        # Will be set by register() if the `python_type` is a type
-        self.python_type: Optional[type] = None
+        # Will be set by register_enum()
+        self.python_type: Optional[Type[Enum]] = None
 
     @classmethod
     def _get_info_query(
index 50ab959481a3870c2775769fc07a64f5a4fbc037..29816c24817c7ac9464dcb160f1a611cbca2e145 100644 (file)
@@ -107,13 +107,14 @@ TEXT_ARRAY_OID = types["text"].array_oid
 
 def register_default_adapters(context: AdaptContext) -> None:
 
-    from .types import array, bool, composite, datetime, json, multirange
+    from .types import array, bool, composite, datetime, enum, json, multirange
     from .types import net, none, numeric, range, string, uuid
 
     array.register_default_adapters(context)
     bool.register_default_adapters(context)
     composite.register_default_adapters(context)
     datetime.register_default_adapters(context)
+    enum.register_default_adapters(context)
     json.register_default_adapters(context)
     multirange.register_default_adapters(context)
     net.register_default_adapters(context)
index 0c412a0418574fb4195cb93c9a72228385471cbc..3c3ef79f53ceba8fcd422e012e7b8fc419f70baf 100644 (file)
@@ -2,29 +2,39 @@
 Adapters for the enum type.
 """
 from enum import Enum
-from typing import Optional, TypeVar, Generic, Type
+from typing import Optional, TypeVar, Generic, Type, Dict, Any
 
+from ..adapt import Dumper, Loader
 from .string import StrBinaryDumper, StrDumper
 from .. import postgres
+from .._encodings import pgconn_encoding
 from .._typeinfo import EnumInfo as EnumInfo  # exported here
 from ..abc import AdaptContext
-from ..adapt import Buffer, Loader
+from ..adapt import Buffer
 from ..pq import Format
 
+
 E = TypeVar("E", bound=Enum)
 
 
 class EnumLoader(Loader, Generic[E]):
     format = Format.TEXT
+    _encoding = "utf-8"
     python_type: Type[E]
 
     def __init__(self, oid: int, context: Optional[AdaptContext] = None):
         super().__init__(oid, context)
+        conn = self.connection
+        if conn:
+            self._encoding = pgconn_encoding(conn.pgconn)
 
     def load(self, data: Buffer) -> E:
         if isinstance(data, memoryview):
-            data = bytes(data)
-        return self.python_type(data.decode())
+            label = bytes(data).decode(self._encoding)
+        else:
+            label = data.decode(self._encoding)
+
+        return self.python_type(label)
 
 
 class EnumBinaryLoader(EnumLoader[E]):
@@ -32,7 +42,8 @@ class EnumBinaryLoader(EnumLoader[E]):
 
 
 class EnumDumper(StrDumper):
-    pass
+    def dump(self, obj: str) -> bytes:
+        return super().dump(obj)
 
 
 class EnumBinaryDumper(StrBinaryDumper):
@@ -55,7 +66,7 @@ def register_enum(
     .. note::
         Only string enums are supported.
 
-        Use binary format if any of your enum labels contains comma:
+        Use binary format if you use enum array and enum labels contains comma:
             connection.execute(..., binary=True)
     """
 
@@ -77,18 +88,30 @@ def register_enum(
 
     info.register(context)
 
-    base = EnumLoader
-    name = f"{info.name.title()}{base.__name__}"
-    attribs = {"python_type": info.python_type}
-    loader = type(name, (base,), attribs)
+    attribs: Dict[str, Any] = {"python_type": info.python_type}
+
+    loader_base = EnumLoader
+    name = f"{info.name.title()}{loader_base.__name__}"
+    loader = type(name, (loader_base,), attribs)
     adapters.register_loader(info.oid, loader)
 
-    base = EnumBinaryLoader
-    name = f"{info.name.title()}{base.__name__}"
-    attribs = {"python_type": info.python_type}
-    loader = type(name, (base,), attribs)
+    loader_base = EnumBinaryLoader
+    name = f"{info.name.title()}{loader_base.__name__}"
+    loader = type(name, (loader_base,), attribs)
     adapters.register_loader(info.oid, loader)
 
+    attribs = {"oid": info.oid}
+
+    dumper_base: Type[Dumper] = EnumBinaryDumper
+    name = f"{info.name.title()}{dumper_base.__name__}"
+    dumper = type(name, (dumper_base,), attribs)
+    adapters.register_dumper(info.python_type, dumper)
+
+    dumper_base = EnumDumper
+    name = f"{info.name.title()}{dumper_base.__name__}"
+    dumper = type(name, (dumper_base,), attribs)
+    adapters.register_dumper(info.python_type, dumper)
+
 
 def register_default_adapters(context: AdaptContext) -> None:
     context.adapters.register_dumper(Enum, EnumBinaryDumper)
index 535dd1473181647b0490609c40f3834965db952c..c69988c2b1df3d118bb94c499d53017bf581cf36 100644 (file)
@@ -2,175 +2,134 @@ from enum import Enum
 
 import pytest
 
-from psycopg import pq
+from psycopg import pq, sql
 from psycopg.adapt import PyFormat
 from psycopg.types.enum import EnumInfo, register_enum
 
 
-class _TestEnum(str, Enum):
+class StrTestEnum(str, Enum):
     ONE = "ONE"
     TWO = "TWO"
     THREE = "THREE"
 
 
-@pytest.fixture(scope="session")
-def testenum(svcconn):
-    cur = svcconn.cursor()
-    cur.execute(
-        """
-        drop type if exists testenum cascade;
-        create type testenum as enum('ONE', 'TWO', 'THREE');
-        """
-    )
-    return EnumInfo.fetch(svcconn, "testenum")
+class NonAscciEnum(str, Enum):
+    XE0 = "x\xe0"
+    XE1 = "x\xe1"
+
 
+enum_cases = [
+    ("strtestenum", StrTestEnum, [item.value for item in StrTestEnum]),
+    ("nonasccienum", NonAscciEnum, [item.value for item in NonAscciEnum]),
+]
+
+encodings = ["utf8", "latin1"]
+
+
+@pytest.fixture(scope="session", params=enum_cases)
+def testenum(request, svcconn):
+    name, enum, labels = request.param
+    quoted_labels = [sql.quote(label) for label in labels]
 
-@pytest.fixture(scope="session")
-def nonasciienum(svcconn):
     cur = svcconn.cursor()
     cur.execute(
-        """
-        drop type if exists nonasciienum cascade;
-        create type nonasciienum as enum ('x\xe0');
+        f"""
+        drop type if exists {name} cascade;
+        create type {name} as enum({','.join(quoted_labels)});
         """
     )
-    return EnumInfo.fetch(svcconn, "nonasciienum")
+    return name, enum, labels
 
 
 def test_fetch_info(conn, testenum):
-    assert testenum.name == "testenum"
-    assert testenum.oid > 0
-    assert testenum.oid != testenum.array_oid > 0
-    assert len(testenum.enum_labels) == 3
-    assert testenum.enum_labels == ["ONE", "TWO", "THREE"]
-
+    name, enum, labels = testenum
 
-@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_enum_insert_generic(conn, testenum, fmt_in):
-    # No regstration, test for generic enum
-    conn.execute("create table test_enum_insert (id int primary key, val testenum)")
-    cur = conn.cursor()
-    cur.executemany(
-        f"insert into test_enum_insert (id, val) values (%s, %{fmt_in})",
-        list(enumerate(_TestEnum)),
-    )
-    cur.execute("select id, val from test_enum_insert order by id")
-    recs = cur.fetchall()
-    assert recs == [(0, "ONE"), (1, "TWO"), (2, "THREE")]
+    info = EnumInfo.fetch(conn, name)
+    assert info.name == name
+    assert info.oid > 0
+    assert info.oid != info.array_oid > 0
+    assert len(info.enum_labels) == len(labels)
+    assert info.enum_labels == labels
 
 
+@pytest.mark.parametrize("encoding", encodings)
 @pytest.mark.parametrize("fmt_in", PyFormat)
-def test_enum_dumper(conn, testenum, fmt_in):
-    register_enum(testenum, _TestEnum, conn)
-
-    cur = conn.execute(f"select %{fmt_in}", [_TestEnum.ONE])
-    assert cur.fetchone()[0] is _TestEnum.ONE
-
-
 @pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_loader(conn, testenum, fmt_out):
-    register_enum(testenum, _TestEnum, conn)
-
-    cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
-    assert cur.fetchone()[0] == _TestEnum.ONE
-
+def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
+    conn.execute(f"set client_encoding to {encoding}")
 
-@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_array_loader(conn, testenum, fmt_out):
-    register_enum(testenum, _TestEnum, conn)
+    name, enum, labels = testenum
+    register_enum(EnumInfo.fetch(conn, name), enum, conn)
 
-    cur = conn.execute(
-        "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
-        binary=fmt_out,
-    )
-    assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
+    for label in labels:
+        cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
+        assert cur.fetchone()[0] == enum(label)
 
 
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_loader_generated(conn, testenum, fmt_out):
-    register_enum(testenum, context=conn)
+def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out):
+    conn.execute(f"set client_encoding to {encoding}")
+
+    name, enum, labels = testenum
+    register_enum(EnumInfo.fetch(conn, name), enum, conn)
 
-    cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
-    assert cur.fetchone()[0] == testenum.python_type.ONE
+    for item in enum:
+        cur = conn.execute(f"select %{fmt_in}", [item], binary=fmt_out)
+        assert cur.fetchone()[0] == item
 
 
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_array_loader_generated(conn, testenum, fmt_out):
-    register_enum(testenum, context=conn)
+def test_generic_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
+    conn.execute(f"set client_encoding to {encoding}")
 
-    cur = conn.execute(
-        "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
-        binary=fmt_out,
-    )
-    assert cur.fetchone()[0] == [testenum.python_type.ONE, testenum.python_type.TWO]
+    name, enum, labels = testenum
+    info = EnumInfo.fetch(conn, name)
+    register_enum(info, None, conn)
 
+    for label in labels:
+        cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
+        assert cur.fetchone()[0] == info.python_type(label)
 
-@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum(conn, testenum, fmt_out):
-    register_enum(testenum, _TestEnum, conn)
 
-    cur = conn.cursor()
-    cur.execute(
-        """
-        drop table if exists testenumtable;
-        create table testenumtable (id serial primary key, value testenum);
-        """
-    )
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out):
+    conn.execute(f"set client_encoding to {encoding}")
 
-    cur.execute(
-        "insert into testenumtable (value) values (%s), (%s), (%s)",
-        (
-            _TestEnum.ONE,
-            _TestEnum.TWO,
-            _TestEnum.THREE,
-        ),
-    )
+    name, enum, labels = testenum
+    register_enum(EnumInfo.fetch(conn, name), enum, conn)
 
-    cur = conn.execute("select value from testenumtable order by id", binary=fmt_out)
-    assert cur.fetchone()[0] == _TestEnum.ONE
-    assert cur.fetchone()[0] == _TestEnum.TWO
-    assert cur.fetchone()[0] == _TestEnum.THREE
+    cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out)
+    assert cur.fetchone()[0] == list(enum)
 
 
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_array(conn, testenum, fmt_out):
-    register_enum(testenum, _TestEnum, conn)
-
-    cur = conn.cursor()
-    cur.execute(
-        """
-        drop table if exists testenumtable;
-        create table testenumtable (id serial primary key, values testenum[]);
-        """
-    )
+def test_enum_array_dumper(conn, testenum, encoding, fmt_in, fmt_out):
+    conn.execute(f"set client_encoding to {encoding}")
 
-    cur.execute(
-        "insert into testenumtable (values) values (%s), (%s), (%s)",
-        (
-            [_TestEnum.ONE, _TestEnum.TWO],
-            [_TestEnum.TWO, _TestEnum.THREE],
-            [_TestEnum.THREE, _TestEnum.ONE],
-        ),
-    )
+    name, enum, labels = testenum
+    register_enum(EnumInfo.fetch(conn, name), enum, conn)
 
-    cur = conn.execute("select values from testenumtable order by id", binary=fmt_out)
-    assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
-    assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE]
-    assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE]
+    cur = conn.execute(f"select %{fmt_in}", [list(enum)], binary=fmt_out)
+    assert cur.fetchone()[0] == list(enum)
 
 
-@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("encoding", encodings)
 @pytest.mark.parametrize("fmt_in", PyFormat)
-@pytest.mark.parametrize("encoding", ["utf8", "latin1"])
-def test_non_ascii_enum(conn, nonasciienum, fmt_out, fmt_in, encoding):
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out):
     conn.execute(f"set client_encoding to {encoding}")
-    info = EnumInfo.fetch(conn, "nonasciienum")
-    register_enum(info, context=conn)
-    assert [x.name for x in info.python_type] == ["x\xe0"]
-    val = list(info.python_type)[0]
 
-    cur = conn.execute("select 'x\xe0'::nonasciienum", binary=fmt_out)
-    assert cur.fetchone()[0] is val
+    name, enum, labels = testenum
+    info = EnumInfo.fetch(conn, name)
+    register_enum(info, enum, conn)
 
-    cur = conn.execute(f"select %{fmt_in}", [val])
-    assert cur.fetchone()[0] is val
+    cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out)
+    assert cur.fetchone()[0] == list(info.python_type)