]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add enum support
authorVladimir Osokin <ertaquo@gmail.com>
Wed, 13 Apr 2022 11:49:14 +0000 (16:49 +0500)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 14:05:46 +0000 (16:05 +0200)
docs/basic/pgtypes.rst
psycopg/psycopg/_typeinfo.py
psycopg/psycopg/types/enum.py [new file with mode: 0644]
tests/types/test_enum.py [new file with mode: 0644]

index 2f15e93cb7bf6133b5b9e58da40b3c060e05d3fc..6a8d450a3ba398eb468cbd7f9e36170c58f80633 100644 (file)
@@ -365,3 +365,73 @@ connection or cursor), other connections and cursors will be unaffected::
     ...     "coordinates":[-48.23456,20.12345]}')
     ... """).fetchone()[0]
     '0101000020E61000009279E40F061E48C0F2B0506B9A1F3440'
+
+.. _adapt-enum:
+
+Enum types casting
+------------------
+
+Psycopg can adapt PostgreSQL enum types (created with the
+|CREATE TYPE AS ENUM|_ command)
+
+.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)`
+.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html
+
+Before using a enum type it is necessary to get information about it
+using the `~psycopg.types.enum.EnumInfo` class and to register it
+using `~psycopg.types.enum.register_enum()`.
+
+.. autoclass:: psycopg.types.enum.EnumInfo
+
+   `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+   documentation for the generic usage, especially the
+   `~psycopg.types.TypeInfo.fetch()` method.
+
+   .. attribute:: enum_labels
+
+       Contains labels available in the PostgreSQL enum type.
+
+   .. attribute:: python_type
+
+       After `register_composite()` is called, it will contain the python type
+       mapping to the registered enum.
+
+.. autofunction:: psycopg.types.enum.register_enum
+
+   After registering, fetching data of the registered enum will cast
+   PostgreSQL enum labels into corresponding Python enum labels.
+
+   If no ``python_type`` is specified, a `Enum` is created based on
+   PostgreSQL enum labels.
+
+Example::
+
+    >>> from enum import Enum
+    >>> from psycopg.types.enum import EnumInfo, register_enum
+
+    >>> class UserRole(str, Enum):
+    ...     ADMIN = "ADMIN"
+    ...     EDITOR = "EDITOR"
+    ...     GUEST = "GUEST"
+
+    >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')")
+
+    >>> info = EnumInfo.fetch(conn, "user_role")
+    >>> register_enum(info, UserRole, conn)
+
+    >>> some_editor = info.python_type("EDITOR")
+    >>> some_editor
+    <UserRole.EDITOR: 'EDITOR'>
+
+    >>> conn.execute(
+    ...     "SELECT pg_typeof(%(editor)s), %(editor)s",
+    ...     { "editor": some_editor }
+    ... ).fetchone()
+    ('user_role', <UserRole.EDITOR: 'EDITOR'>)
+
+    >>> conn.execute(
+    ...     "SELECT (%s, %s)::user_role[]",
+    ...     [UserRole.ADMIN, UserRole.GUEST]
+    ... ).fetchone()
+    [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>]
+
index 187dc795cf7ae2472dfa90cd6469e0ec844bc353..87bfbbe0d29e1b45d166e010444082bc99e40842 100644 (file)
@@ -285,6 +285,42 @@ WHERE t.oid = %(name)s::regtype
 """
 
 
+class EnumInfo(TypeInfo):
+    """Manage information about a enum type"""
+
+    __module__ = "psycopg.types.enum"
+
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        enum_labels: Sequence[str],
+    ):
+        super().__init__(name, oid, array_oid)
+        self.enum_labels = enum_labels
+        # Will be set by register() if the `python_type` is a type
+        self.python_type: Optional[type] = None
+
+    @classmethod
+    def _get_info_query(
+        cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+    ) -> str:
+        return """\
+SELECT
+    t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+    array_agg(x.enumlabel) AS enum_labels
+FROM pg_type t
+LEFT JOIN (
+    SELECT e.enumtypid, e.enumlabel
+    FROM pg_enum e
+    ORDER BY e.enumsortorder
+) x ON x.enumtypid = t.oid
+WHERE t.oid = %(name)s::regtype
+GROUP BY t.typname, t.oid, t.typarray
+"""
+
+
 class TypesRegistry:
     """
     Container for the information about types in a database.
diff --git a/psycopg/psycopg/types/enum.py b/psycopg/psycopg/types/enum.py
new file mode 100644 (file)
index 0000000..0c412a0
--- /dev/null
@@ -0,0 +1,95 @@
+"""
+Adapters for the enum type.
+"""
+from enum import Enum
+from typing import Optional, TypeVar, Generic, Type
+
+from .string import StrBinaryDumper, StrDumper
+from .. import postgres
+from .._typeinfo import EnumInfo as EnumInfo  # exported here
+from ..abc import AdaptContext
+from ..adapt import Buffer, Loader
+from ..pq import Format
+
+E = TypeVar("E", bound=Enum)
+
+
+class EnumLoader(Loader, Generic[E]):
+    format = Format.TEXT
+    python_type: Type[E]
+
+    def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+        super().__init__(oid, context)
+
+    def load(self, data: Buffer) -> E:
+        if isinstance(data, memoryview):
+            data = bytes(data)
+        return self.python_type(data.decode())
+
+
+class EnumBinaryLoader(EnumLoader[E]):
+    format = Format.BINARY
+
+
+class EnumDumper(StrDumper):
+    pass
+
+
+class EnumBinaryDumper(StrBinaryDumper):
+    pass
+
+
+def register_enum(
+    info: EnumInfo,
+    python_type: Optional[Type[E]] = None,
+    context: Optional[AdaptContext] = None,
+) -> None:
+    """Register the adapters to load and dump a enum type.
+
+    :param info: The object with the information about the enum to register.
+    :param python_type: Python enum type matching to the Postgres one. If `!None`,
+        the type will be generated and put into info.python_type.
+    :param context: The context where to register the adapters. If `!None`,
+        register it globally.
+
+    .. note::
+        Only string enums are supported.
+
+        Use binary format if any of your enum labels contains comma:
+            connection.execute(..., binary=True)
+    """
+
+    if not info:
+        raise TypeError("no info passed. Is the requested enum available?")
+
+    if python_type is not None:
+        if {type(item.value) for item in python_type} != {str}:
+            raise TypeError("invalid enum value type (string is the only supported)")
+
+        info.python_type = python_type
+    else:
+        info.python_type = Enum(  # type: ignore
+            info.name.title(),
+            {label: label for label in info.enum_labels},
+        )
+
+    adapters = context.adapters if context else postgres.adapters
+
+    info.register(context)
+
+    base = EnumLoader
+    name = f"{info.name.title()}{base.__name__}"
+    attribs = {"python_type": info.python_type}
+    loader = type(name, (base,), attribs)
+    adapters.register_loader(info.oid, loader)
+
+    base = EnumBinaryLoader
+    name = f"{info.name.title()}{base.__name__}"
+    attribs = {"python_type": info.python_type}
+    loader = type(name, (base,), attribs)
+    adapters.register_loader(info.oid, loader)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+    context.adapters.register_dumper(Enum, EnumBinaryDumper)
+    context.adapters.register_dumper(Enum, EnumDumper)
diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py
new file mode 100644 (file)
index 0000000..0f0023e
--- /dev/null
@@ -0,0 +1,125 @@
+from enum import Enum
+
+import pytest
+from psycopg.types.enum import EnumInfo, register_enum
+
+from psycopg import pq
+
+
+class _TestEnum(str, Enum):
+    ONE = "ONE"
+    TWO = "TWO"
+    THREE = "THREE"
+
+
+@pytest.fixture(scope="session")
+def testenum(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        drop type if exists testenum cascade;
+
+        create type testenum as enum('ONE', 'TWO', 'THREE');
+        """
+    )
+    return EnumInfo.fetch(svcconn, "testenum")
+
+
+def test_fetch_info(conn, testenum):
+    assert testenum.name == "testenum"
+    assert testenum.oid > 0
+    assert testenum.oid != testenum.array_oid > 0
+    assert len(testenum.enum_labels) == 3
+    assert testenum.enum_labels == ["ONE", "TWO", "THREE"]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader(conn, testenum, fmt_out):
+    register_enum(testenum, _TestEnum, conn)
+
+    cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
+    assert cur.fetchone()[0] == _TestEnum.ONE
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader(conn, testenum, fmt_out):
+    register_enum(testenum, _TestEnum, conn)
+
+    cur = conn.execute(
+        "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
+        binary=fmt_out,
+    )
+    assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader_generated(conn, testenum, fmt_out):
+    register_enum(testenum, context=conn)
+
+    cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
+    assert cur.fetchone()[0] == testenum.python_type.ONE
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader_generated(conn, testenum, fmt_out):
+    register_enum(testenum, context=conn)
+
+    cur = conn.execute(
+        "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
+        binary=fmt_out,
+    )
+    assert cur.fetchone()[0] == [testenum.python_type.ONE, testenum.python_type.TWO]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum(conn, testenum, fmt_out):
+    register_enum(testenum, _TestEnum, conn)
+
+    cur = conn.cursor()
+    cur.execute(
+        """
+        drop table if exists testenumtable;
+        create table testenumtable (id serial primary key, value testenum);
+        """
+    )
+
+    cur.execute(
+        "insert into testenumtable (value) values (%s), (%s), (%s)",
+        (
+            _TestEnum.ONE,
+            _TestEnum.TWO,
+            _TestEnum.THREE,
+        ),
+    )
+
+    cur = conn.execute("select value from testenumtable order by id", binary=fmt_out)
+    assert cur.fetchone()[0] == _TestEnum.ONE
+    assert cur.fetchone()[0] == _TestEnum.TWO
+    assert cur.fetchone()[0] == _TestEnum.THREE
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array(conn, testenum, fmt_out):
+    register_enum(testenum, _TestEnum, conn)
+
+    cur = conn.cursor()
+    cur.execute(
+        """
+        drop table if exists testenumtable;
+        create table testenumtable (id serial primary key, values testenum[]);
+        """
+    )
+
+    cur.execute(
+        "insert into testenumtable (values) values (%s), (%s), (%s)",
+        (
+            [_TestEnum.ONE, _TestEnum.TWO],
+            [_TestEnum.TWO, _TestEnum.THREE],
+            [_TestEnum.THREE, _TestEnum.ONE],
+        ),
+    )
+
+    cur = conn.execute("select values from testenumtable order by id", binary=fmt_out)
+    assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
+    assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE]
+    assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE]