From 97809e8485c46f3a22bca328dcbde6965757f1a6 Mon Sep 17 00:00:00 2001 From: Vladimir Osokin Date: Wed, 13 Apr 2022 16:49:14 +0500 Subject: [PATCH] feat: add enum support --- docs/basic/pgtypes.rst | 70 +++++++++++++++++++ psycopg/psycopg/_typeinfo.py | 36 ++++++++++ psycopg/psycopg/types/enum.py | 95 ++++++++++++++++++++++++++ tests/types/test_enum.py | 125 ++++++++++++++++++++++++++++++++++ 4 files changed, 326 insertions(+) create mode 100644 psycopg/psycopg/types/enum.py create mode 100644 tests/types/test_enum.py diff --git a/docs/basic/pgtypes.rst b/docs/basic/pgtypes.rst index 2f15e93cb..6a8d450a3 100644 --- a/docs/basic/pgtypes.rst +++ b/docs/basic/pgtypes.rst @@ -365,3 +365,73 @@ connection or cursor), other connections and cursors will be unaffected:: ... "coordinates":[-48.23456,20.12345]}') ... """).fetchone()[0] '0101000020E61000009279E40F061E48C0F2B0506B9A1F3440' + +.. _adapt-enum: + +Enum types casting +------------------ + +Psycopg can adapt PostgreSQL enum types (created with the +|CREATE TYPE AS ENUM|_ command) + +.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)` +.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html + +Before using a enum type it is necessary to get information about it +using the `~psycopg.types.enum.EnumInfo` class and to register it +using `~psycopg.types.enum.register_enum()`. + +.. autoclass:: psycopg.types.enum.EnumInfo + + `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its + documentation for the generic usage, especially the + `~psycopg.types.TypeInfo.fetch()` method. + + .. attribute:: enum_labels + + Contains labels available in the PostgreSQL enum type. + + .. attribute:: python_type + + After `register_composite()` is called, it will contain the python type + mapping to the registered enum. + +.. autofunction:: psycopg.types.enum.register_enum + + After registering, fetching data of the registered enum will cast + PostgreSQL enum labels into corresponding Python enum labels. + + If no ``python_type`` is specified, a `Enum` is created based on + PostgreSQL enum labels. + +Example:: + + >>> from enum import Enum + >>> from psycopg.types.enum import EnumInfo, register_enum + + >>> class UserRole(str, Enum): + ... ADMIN = "ADMIN" + ... EDITOR = "EDITOR" + ... GUEST = "GUEST" + + >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')") + + >>> info = EnumInfo.fetch(conn, "user_role") + >>> register_enum(info, UserRole, conn) + + >>> some_editor = info.python_type("EDITOR") + >>> some_editor + + + >>> conn.execute( + ... "SELECT pg_typeof(%(editor)s), %(editor)s", + ... { "editor": some_editor } + ... ).fetchone() + ('user_role', ) + + >>> conn.execute( + ... "SELECT (%s, %s)::user_role[]", + ... [UserRole.ADMIN, UserRole.GUEST] + ... ).fetchone() + [, ] + diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index 187dc795c..87bfbbe0d 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -285,6 +285,42 @@ WHERE t.oid = %(name)s::regtype """ +class EnumInfo(TypeInfo): + """Manage information about a enum type""" + + __module__ = "psycopg.types.enum" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + enum_labels: Sequence[str], + ): + super().__init__(name, oid, array_oid) + self.enum_labels = enum_labels + # Will be set by register() if the `python_type` is a type + self.python_type: Optional[type] = None + + @classmethod + def _get_info_query( + cls, conn: "Union[Connection[Any], AsyncConnection[Any]]" + ) -> str: + return """\ +SELECT + t.typname AS name, t.oid AS oid, t.typarray AS array_oid, + array_agg(x.enumlabel) AS enum_labels +FROM pg_type t +LEFT JOIN ( + SELECT e.enumtypid, e.enumlabel + FROM pg_enum e + ORDER BY e.enumsortorder +) x ON x.enumtypid = t.oid +WHERE t.oid = %(name)s::regtype +GROUP BY t.typname, t.oid, t.typarray +""" + + class TypesRegistry: """ Container for the information about types in a database. diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py new file mode 100644 index 000000000..0c412a041 --- /dev/null +++ b/psycopg/psycopg/types/enum.py @@ -0,0 +1,95 @@ +""" +Adapters for the enum type. +""" +from enum import Enum +from typing import Optional, TypeVar, Generic, Type + +from .string import StrBinaryDumper, StrDumper +from .. import postgres +from .._typeinfo import EnumInfo as EnumInfo # exported here +from ..abc import AdaptContext +from ..adapt import Buffer, Loader +from ..pq import Format + +E = TypeVar("E", bound=Enum) + + +class EnumLoader(Loader, Generic[E]): + format = Format.TEXT + python_type: Type[E] + + def __init__(self, oid: int, context: Optional[AdaptContext] = None): + super().__init__(oid, context) + + def load(self, data: Buffer) -> E: + if isinstance(data, memoryview): + data = bytes(data) + return self.python_type(data.decode()) + + +class EnumBinaryLoader(EnumLoader[E]): + format = Format.BINARY + + +class EnumDumper(StrDumper): + pass + + +class EnumBinaryDumper(StrBinaryDumper): + pass + + +def register_enum( + info: EnumInfo, + python_type: Optional[Type[E]] = None, + context: Optional[AdaptContext] = None, +) -> None: + """Register the adapters to load and dump a enum type. + + :param info: The object with the information about the enum to register. + :param python_type: Python enum type matching to the Postgres one. If `!None`, + the type will be generated and put into info.python_type. + :param context: The context where to register the adapters. If `!None`, + register it globally. + + .. note:: + Only string enums are supported. + + Use binary format if any of your enum labels contains comma: + connection.execute(..., binary=True) + """ + + if not info: + raise TypeError("no info passed. Is the requested enum available?") + + if python_type is not None: + if {type(item.value) for item in python_type} != {str}: + raise TypeError("invalid enum value type (string is the only supported)") + + info.python_type = python_type + else: + info.python_type = Enum( # type: ignore + info.name.title(), + {label: label for label in info.enum_labels}, + ) + + adapters = context.adapters if context else postgres.adapters + + info.register(context) + + base = EnumLoader + name = f"{info.name.title()}{base.__name__}" + attribs = {"python_type": info.python_type} + loader = type(name, (base,), attribs) + adapters.register_loader(info.oid, loader) + + base = EnumBinaryLoader + name = f"{info.name.title()}{base.__name__}" + attribs = {"python_type": info.python_type} + loader = type(name, (base,), attribs) + adapters.register_loader(info.oid, loader) + + +def register_default_adapters(context: AdaptContext) -> None: + context.adapters.register_dumper(Enum, EnumBinaryDumper) + context.adapters.register_dumper(Enum, EnumDumper) diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py new file mode 100644 index 000000000..0f0023ebc --- /dev/null +++ b/tests/types/test_enum.py @@ -0,0 +1,125 @@ +from enum import Enum + +import pytest +from psycopg.types.enum import EnumInfo, register_enum + +from psycopg import pq + + +class _TestEnum(str, Enum): + ONE = "ONE" + TWO = "TWO" + THREE = "THREE" + + +@pytest.fixture(scope="session") +def testenum(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop type if exists testenum cascade; + + create type testenum as enum('ONE', 'TWO', 'THREE'); + """ + ) + return EnumInfo.fetch(svcconn, "testenum") + + +def test_fetch_info(conn, testenum): + assert testenum.name == "testenum" + assert testenum.oid > 0 + assert testenum.oid != testenum.array_oid > 0 + assert len(testenum.enum_labels) == 3 + assert testenum.enum_labels == ["ONE", "TWO", "THREE"] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader(conn, testenum, fmt_out): + register_enum(testenum, _TestEnum, conn) + + cur = conn.execute("select 'ONE'::testenum", binary=fmt_out) + assert cur.fetchone()[0] == _TestEnum.ONE + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array_loader(conn, testenum, fmt_out): + register_enum(testenum, _TestEnum, conn) + + cur = conn.execute( + "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]", + binary=fmt_out, + ) + assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_loader_generated(conn, testenum, fmt_out): + register_enum(testenum, context=conn) + + cur = conn.execute("select 'ONE'::testenum", binary=fmt_out) + assert cur.fetchone()[0] == testenum.python_type.ONE + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array_loader_generated(conn, testenum, fmt_out): + register_enum(testenum, context=conn) + + cur = conn.execute( + "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]", + binary=fmt_out, + ) + assert cur.fetchone()[0] == [testenum.python_type.ONE, testenum.python_type.TWO] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum(conn, testenum, fmt_out): + register_enum(testenum, _TestEnum, conn) + + cur = conn.cursor() + cur.execute( + """ + drop table if exists testenumtable; + create table testenumtable (id serial primary key, value testenum); + """ + ) + + cur.execute( + "insert into testenumtable (value) values (%s), (%s), (%s)", + ( + _TestEnum.ONE, + _TestEnum.TWO, + _TestEnum.THREE, + ), + ) + + cur = conn.execute("select value from testenumtable order by id", binary=fmt_out) + assert cur.fetchone()[0] == _TestEnum.ONE + assert cur.fetchone()[0] == _TestEnum.TWO + assert cur.fetchone()[0] == _TestEnum.THREE + + +@pytest.mark.parametrize("fmt_out", pq.Format) +def test_enum_array(conn, testenum, fmt_out): + register_enum(testenum, _TestEnum, conn) + + cur = conn.cursor() + cur.execute( + """ + drop table if exists testenumtable; + create table testenumtable (id serial primary key, values testenum[]); + """ + ) + + cur.execute( + "insert into testenumtable (values) values (%s), (%s), (%s)", + ( + [_TestEnum.ONE, _TestEnum.TWO], + [_TestEnum.TWO, _TestEnum.THREE], + [_TestEnum.THREE, _TestEnum.ONE], + ), + ) + + cur = conn.execute("select values from testenumtable order by id", binary=fmt_out) + assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO] + assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE] + assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE] -- 2.47.2