From: Vladimir Osokin Date: Thu, 14 Apr 2022 10:37:56 +0000 (+0500) Subject: fix(enum): code review fixes X-Git-Tag: 3.1~137^2~19 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a9b57a7ef6478202478ca67a3c4603e2b786c1a3;p=thirdparty%2Fpsycopg.git fix(enum): code review fixes * moved documentation to `docs/basic/adapt.rst`; * use encoding in `EnumLoader`; * oid-based dumpers; * rewritten tests; * minor fixes. --- diff --git a/docs/basic/adapt.rst b/docs/basic/adapt.rst index 47f92cc7a..d5c392bda 100644 --- a/docs/basic/adapt.rst +++ b/docs/basic/adapt.rst @@ -371,3 +371,77 @@ address types`__: >>> conn.execute("select '::ffff:1.2.3.0/120'::cidr").fetchone()[0] IPv6Network('::ffff:102:300/120') + +.. _adapt-enum: + +Enum adaptation +--------------- + +.. versionadded:: 3.1 + +Psycopg can adapt PostgreSQL enum types (created with the +|CREATE TYPE AS ENUM|_ command). + +.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)` +.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html + +Before using an enum type it is necessary to get information about it +using the `~psycopg.types.enum.EnumInfo` class and to register it +using `~psycopg.types.enum.register_enum()`. + +If you use enum arrays and your enum labels contain commas, you should use +the binary format to return values. See :ref:`binary-data` for details. + +.. autoclass:: psycopg.types.enum.EnumInfo + + `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its + documentation for the generic usage, especially the + `~psycopg.types.TypeInfo.fetch()` method. + + .. attribute:: enum_labels + + Contains labels available in the PostgreSQL enum type. + + .. attribute:: python_type + + After `register_enum()` is called, it will contain the Python type + mapping to the registered enum. + +.. 
autofunction:: psycopg.types.enum.register_enum + + After registering, fetching data of the registered enum will cast + PostgreSQL enum labels into corresponding Python enum labels. + + If no ``python_type`` is specified, an `Enum` is created based on + PostgreSQL enum labels. + +Example:: + + >>> from enum import Enum + >>> from psycopg.types.enum import EnumInfo, register_enum + + >>> class UserRole(str, Enum): + ... ADMIN = "ADMIN" + ... EDITOR = "EDITOR" + ... GUEST = "GUEST" + + >>> conn.execute("CREATE TYPE user_role AS ENUM ('ADMIN', 'EDITOR', 'GUEST')") + + >>> info = EnumInfo.fetch(conn, "user_role") + >>> register_enum(info, UserRole, conn) + + >>> some_editor = info.python_type("EDITOR") + >>> some_editor + <UserRole.EDITOR: 'EDITOR'> + + >>> conn.execute( + ... "SELECT pg_typeof(%(editor)s), %(editor)s", + ... { "editor": some_editor } + ... ).fetchone() + ('user_role', <UserRole.EDITOR: 'EDITOR'>) + + >>> conn.execute( + ... "SELECT (%s, %s)::user_role[]", + ... [UserRole.ADMIN, UserRole.GUEST] + ... ).fetchone() + [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>] diff --git a/docs/basic/pgtypes.rst b/docs/basic/pgtypes.rst index 6a8d450a3..471187707 100644 --- a/docs/basic/pgtypes.rst +++ b/docs/basic/pgtypes.rst @@ -366,72 +366,3 @@ connection or cursor), other connections and cursors will be unaffected:: ... """).fetchone()[0] '0101000020E61000009279E40F061E48C0F2B0506B9A1F3440' -.. _adapt-enum: - -Enum types casting ------------------- - -Psycopg can adapt PostgreSQL enum types (created with the -|CREATE TYPE AS ENUM|_ command) - -.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)` -.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html - -Before using a enum type it is necessary to get information about it -using the `~psycopg.types.enum.EnumInfo` class and to register it -using `~psycopg.types.enum.register_enum()`. - -.. 
autoclass:: psycopg.types.enum.EnumInfo - - `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its - documentation for the generic usage, especially the - `~psycopg.types.TypeInfo.fetch()` method. - - .. attribute:: enum_labels - - Contains labels available in the PostgreSQL enum type. - - .. attribute:: python_type - - After `register_composite()` is called, it will contain the python type - mapping to the registered enum. - -.. autofunction:: psycopg.types.enum.register_enum - - After registering, fetching data of the registered enum will cast - PostgreSQL enum labels into corresponding Python enum labels. - - If no ``python_type`` is specified, a `Enum` is created based on - PostgreSQL enum labels. - -Example:: - - >>> from enum import Enum - >>> from psycopg.types.enum import EnumInfo, register_enum - - >>> class UserRole(str, Enum): - ... ADMIN = "ADMIN" - ... EDITOR = "EDITOR" - ... GUEST = "GUEST" - - >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')") - - >>> info = EnumInfo.fetch(conn, "user_role") - >>> register_enum(info, UserRole, conn) - - >>> some_editor = info.python_type("EDITOR") - >>> some_editor - <UserRole.EDITOR: 'EDITOR'> - - >>> conn.execute( - ... "SELECT pg_typeof(%(editor)s), %(editor)s", - ... { "editor": some_editor } - ... ).fetchone() - ('user_role', <UserRole.EDITOR: 'EDITOR'>) - - >>> conn.execute( - ... "SELECT (%s, %s)::user_role[]", - ... [UserRole.ADMIN, UserRole.GUEST] - ... ).fetchone() - [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>] - diff --git a/docs/news.rst b/docs/news.rst index d7a2cf8cc..fe9daaccb 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -24,6 +24,7 @@ Psycopg 3.1 (unreleased) - Allow `bytearray`/`memoryview` data too as `Copy.write()` input (:ticket:`#254`). - Drop support for Python 3.6. +- Add `enum` support (:ticket:`#274`). 
Psycopg 3.0.12 (unreleased) diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index 87bfbbe0d..dfe770edc 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -6,7 +6,7 @@ information to the adapters if needed. """ # Copyright (C) 2020 The Psycopg Team - +from enum import Enum from typing import Any, Dict, Iterator, Optional, overload from typing import Sequence, Tuple, Type, TypeVar, Union, TYPE_CHECKING @@ -299,8 +299,8 @@ class EnumInfo(TypeInfo): ): super().__init__(name, oid, array_oid) self.enum_labels = enum_labels - # Will be set by register() if the `python_type` is a type - self.python_type: Optional[type] = None + # Will be set by register_enum() + self.python_type: Optional[Type[Enum]] = None @classmethod def _get_info_query( diff --git a/psycopg/psycopg/postgres.py b/psycopg/psycopg/postgres.py index 50ab95948..29816c248 100644 --- a/psycopg/psycopg/postgres.py +++ b/psycopg/psycopg/postgres.py @@ -107,13 +107,14 @@ TEXT_ARRAY_OID = types["text"].array_oid def register_default_adapters(context: AdaptContext) -> None: - from .types import array, bool, composite, datetime, json, multirange + from .types import array, bool, composite, datetime, enum, json, multirange from .types import net, none, numeric, range, string, uuid array.register_default_adapters(context) bool.register_default_adapters(context) composite.register_default_adapters(context) datetime.register_default_adapters(context) + enum.register_default_adapters(context) json.register_default_adapters(context) multirange.register_default_adapters(context) net.register_default_adapters(context) diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py index 0c412a041..3c3ef79f5 100644 --- a/psycopg/psycopg/types/enum.py +++ b/psycopg/psycopg/types/enum.py @@ -2,29 +2,39 @@ Adapters for the enum type. 
""" from enum import Enum -from typing import Optional, TypeVar, Generic, Type +from typing import Optional, TypeVar, Generic, Type, Dict, Any +from ..adapt import Dumper, Loader from .string import StrBinaryDumper, StrDumper from .. import postgres +from .._encodings import pgconn_encoding from .._typeinfo import EnumInfo as EnumInfo # exported here from ..abc import AdaptContext -from ..adapt import Buffer, Loader +from ..adapt import Buffer from ..pq import Format + E = TypeVar("E", bound=Enum) class EnumLoader(Loader, Generic[E]): format = Format.TEXT + _encoding = "utf-8" python_type: Type[E] def __init__(self, oid: int, context: Optional[AdaptContext] = None): super().__init__(oid, context) + conn = self.connection + if conn: + self._encoding = pgconn_encoding(conn.pgconn) def load(self, data: Buffer) -> E: if isinstance(data, memoryview): - data = bytes(data) - return self.python_type(data.decode()) + label = bytes(data).decode(self._encoding) + else: + label = data.decode(self._encoding) + + return self.python_type(label) class EnumBinaryLoader(EnumLoader[E]): @@ -32,7 +42,8 @@ class EnumBinaryLoader(EnumLoader[E]): class EnumDumper(StrDumper): - pass + def dump(self, obj: str) -> bytes: + return super().dump(obj) class EnumBinaryDumper(StrBinaryDumper): @@ -55,7 +66,7 @@ def register_enum( .. note:: Only string enums are supported. 
- Use binary format if any of your enum labels contains comma: + Use binary format if you use enum array and enum labels contains comma: connection.execute(..., binary=True) """ @@ -77,18 +88,30 @@ def register_enum( info.register(context) - base = EnumLoader - name = f"{info.name.title()}{base.__name__}" - attribs = {"python_type": info.python_type} - loader = type(name, (base,), attribs) + attribs: Dict[str, Any] = {"python_type": info.python_type} + + loader_base = EnumLoader + name = f"{info.name.title()}{loader_base.__name__}" + loader = type(name, (loader_base,), attribs) adapters.register_loader(info.oid, loader) - base = EnumBinaryLoader - name = f"{info.name.title()}{base.__name__}" - attribs = {"python_type": info.python_type} - loader = type(name, (base,), attribs) + loader_base = EnumBinaryLoader + name = f"{info.name.title()}{loader_base.__name__}" + loader = type(name, (loader_base,), attribs) adapters.register_loader(info.oid, loader) + attribs = {"oid": info.oid} + + dumper_base: Type[Dumper] = EnumBinaryDumper + name = f"{info.name.title()}{dumper_base.__name__}" + dumper = type(name, (dumper_base,), attribs) + adapters.register_dumper(info.python_type, dumper) + + dumper_base = EnumDumper + name = f"{info.name.title()}{dumper_base.__name__}" + dumper = type(name, (dumper_base,), attribs) + adapters.register_dumper(info.python_type, dumper) + def register_default_adapters(context: AdaptContext) -> None: context.adapters.register_dumper(Enum, EnumBinaryDumper) diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 535dd1473..c69988c2b 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -2,175 +2,134 @@ from enum import Enum import pytest -from psycopg import pq +from psycopg import pq, sql from psycopg.adapt import PyFormat from psycopg.types.enum import EnumInfo, register_enum -class _TestEnum(str, Enum): +class StrTestEnum(str, Enum): ONE = "ONE" TWO = "TWO" THREE = "THREE" -@pytest.fixture(scope="session") -def 
testenum(svcconn): - cur = svcconn.cursor() - cur.execute( - """ - drop type if exists testenum cascade; - create type testenum as enum('ONE', 'TWO', 'THREE'); - """ - ) - return EnumInfo.fetch(svcconn, "testenum") +class NonAscciEnum(str, Enum): + XE0 = "x\xe0" + XE1 = "x\xe1" + +enum_cases = [ + ("strtestenum", StrTestEnum, [item.value for item in StrTestEnum]), + ("nonasccienum", NonAscciEnum, [item.value for item in NonAscciEnum]), +] + +encodings = ["utf8", "latin1"] + + +@pytest.fixture(scope="session", params=enum_cases) +def testenum(request, svcconn): + name, enum, labels = request.param + quoted_labels = [sql.quote(label) for label in labels] -@pytest.fixture(scope="session") -def nonasciienum(svcconn): cur = svcconn.cursor() cur.execute( - """ - drop type if exists nonasciienum cascade; - create type nonasciienum as enum ('x\xe0'); + f""" + drop type if exists {name} cascade; + create type {name} as enum({','.join(quoted_labels)}); """ ) - return EnumInfo.fetch(svcconn, "nonasciienum") + return name, enum, labels def test_fetch_info(conn, testenum): - assert testenum.name == "testenum" - assert testenum.oid > 0 - assert testenum.oid != testenum.array_oid > 0 - assert len(testenum.enum_labels) == 3 - assert testenum.enum_labels == ["ONE", "TWO", "THREE"] - + name, enum, labels = testenum -@pytest.mark.parametrize("fmt_in", PyFormat) -def test_enum_insert_generic(conn, testenum, fmt_in): - # No regstration, test for generic enum - conn.execute("create table test_enum_insert (id int primary key, val testenum)") - cur = conn.cursor() - cur.executemany( - f"insert into test_enum_insert (id, val) values (%s, %{fmt_in})", - list(enumerate(_TestEnum)), - ) - cur.execute("select id, val from test_enum_insert order by id") - recs = cur.fetchall() - assert recs == [(0, "ONE"), (1, "TWO"), (2, "THREE")] + info = EnumInfo.fetch(conn, name) + assert info.name == name + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.enum_labels) == 
len(labels) + assert info.enum_labels == labels +@pytest.mark.parametrize("encoding", encodings) @pytest.mark.parametrize("fmt_in", PyFormat) -def test_enum_dumper(conn, testenum, fmt_in): - register_enum(testenum, _TestEnum, conn) - - cur = conn.execute(f"select %{fmt_in}", [_TestEnum.ONE]) - assert cur.fetchone()[0] is _TestEnum.ONE - - @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_loader(conn, testenum, fmt_out): - register_enum(testenum, _TestEnum, conn) - - cur = conn.execute("select 'ONE'::testenum", binary=fmt_out) - assert cur.fetchone()[0] == _TestEnum.ONE - +def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") -@pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_array_loader(conn, testenum, fmt_out): - register_enum(testenum, _TestEnum, conn) + name, enum, labels = testenum + register_enum(EnumInfo.fetch(conn, name), enum, conn) - cur = conn.execute( - "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]", - binary=fmt_out, - ) - assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO] + for label in labels: + cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out) + assert cur.fetchone()[0] == enum(label) +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_loader_generated(conn, testenum, fmt_out): - register_enum(testenum, context=conn) +def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") + + name, enum, labels = testenum + register_enum(EnumInfo.fetch(conn, name), enum, conn) - cur = conn.execute("select 'ONE'::testenum", binary=fmt_out) - assert cur.fetchone()[0] == testenum.python_type.ONE + for item in enum: + cur = conn.execute(f"select %{fmt_in}", [item], binary=fmt_out) + assert cur.fetchone()[0] == item +@pytest.mark.parametrize("encoding", encodings) 
+@pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_array_loader_generated(conn, testenum, fmt_out): - register_enum(testenum, context=conn) +def test_generic_enum_loader(conn, testenum, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") - cur = conn.execute( - "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]", - binary=fmt_out, - ) - assert cur.fetchone()[0] == [testenum.python_type.ONE, testenum.python_type.TWO] + name, enum, labels = testenum + info = EnumInfo.fetch(conn, name) + register_enum(info, None, conn) + for label in labels: + cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out) + assert cur.fetchone()[0] == info.python_type(label) -@pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum(conn, testenum, fmt_out): - register_enum(testenum, _TestEnum, conn) - cur = conn.cursor() - cur.execute( - """ - drop table if exists testenumtable; - create table testenumtable (id serial primary key, value testenum); - """ - ) +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") - cur.execute( - "insert into testenumtable (value) values (%s), (%s), (%s)", - ( - _TestEnum.ONE, - _TestEnum.TWO, - _TestEnum.THREE, - ), - ) + name, enum, labels = testenum + register_enum(EnumInfo.fetch(conn, name), enum, conn) - cur = conn.execute("select value from testenumtable order by id", binary=fmt_out) - assert cur.fetchone()[0] == _TestEnum.ONE - assert cur.fetchone()[0] == _TestEnum.TWO - assert cur.fetchone()[0] == _TestEnum.THREE + cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out) + assert cur.fetchone()[0] == list(enum) +@pytest.mark.parametrize("encoding", encodings) +@pytest.mark.parametrize("fmt_in", PyFormat) 
@pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_array(conn, testenum, fmt_out): - register_enum(testenum, _TestEnum, conn) - - cur = conn.cursor() - cur.execute( - """ - drop table if exists testenumtable; - create table testenumtable (id serial primary key, values testenum[]); - """ - ) +def test_enum_array_dumper(conn, testenum, encoding, fmt_in, fmt_out): + conn.execute(f"set client_encoding to {encoding}") - cur.execute( - "insert into testenumtable (values) values (%s), (%s), (%s)", - ( - [_TestEnum.ONE, _TestEnum.TWO], - [_TestEnum.TWO, _TestEnum.THREE], - [_TestEnum.THREE, _TestEnum.ONE], - ), - ) + name, enum, labels = testenum + register_enum(EnumInfo.fetch(conn, name), enum, conn) - cur = conn.execute("select values from testenumtable order by id", binary=fmt_out) - assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO] - assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE] - assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE] + cur = conn.execute(f"select %{fmt_in}", [list(enum)], binary=fmt_out) + assert cur.fetchone()[0] == list(enum) -@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("encoding", encodings) @pytest.mark.parametrize("fmt_in", PyFormat) -@pytest.mark.parametrize("encoding", ["utf8", "latin1"]) -def test_non_ascii_enum(conn, nonasciienum, fmt_out, fmt_in, encoding): +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_generic_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out): conn.execute(f"set client_encoding to {encoding}") - info = EnumInfo.fetch(conn, "nonasciienum") - register_enum(info, context=conn) - assert [x.name for x in info.python_type] == ["x\xe0"] - val = list(info.python_type)[0] - cur = conn.execute("select 'x\xe0'::nonasciienum", binary=fmt_out) - assert cur.fetchone()[0] is val + name, enum, labels = testenum + info = EnumInfo.fetch(conn, name) + register_enum(info, enum, conn) - cur = conn.execute(f"select %{fmt_in}", [val]) - assert 
cur.fetchone()[0] is val + cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out) + assert cur.fetchone()[0] == list(info.python_type)