... "coordinates":[-48.23456,20.12345]}')
... """).fetchone()[0]
'0101000020E61000009279E40F061E48C0F2B0506B9A1F3440'
+
+.. _adapt-enum:
+
+Enum types casting
+------------------
+
+Psycopg can adapt PostgreSQL enum types (created with the
+|CREATE TYPE AS ENUM|_ command)
+
+.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)`
+.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html
+
+Before using a enum type it is necessary to get information about it
+using the `~psycopg.types.enum.EnumInfo` class and to register it
+using `~psycopg.types.enum.register_enum()`.
+
+.. autoclass:: psycopg.types.enum.EnumInfo
+
+ `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+ documentation for the generic usage, especially the
+ `~psycopg.types.TypeInfo.fetch()` method.
+
+ .. attribute:: enum_labels
+
+ Contains labels available in the PostgreSQL enum type.
+
+ .. attribute:: python_type
+
+ After `register_composite()` is called, it will contain the python type
+ mapping to the registered enum.
+
+.. autofunction:: psycopg.types.enum.register_enum
+
+ After registering, fetching data of the registered enum will cast
+ PostgreSQL enum labels into corresponding Python enum labels.
+
+ If no ``python_type`` is specified, a `Enum` is created based on
+ PostgreSQL enum labels.
+
+Example::
+
+ >>> from enum import Enum
+ >>> from psycopg.types.enum import EnumInfo, register_enum
+
+ >>> class UserRole(str, Enum):
+ ... ADMIN = "ADMIN"
+ ... EDITOR = "EDITOR"
+ ... GUEST = "GUEST"
+
+ >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')")
+
+ >>> info = EnumInfo.fetch(conn, "user_role")
+ >>> register_enum(info, UserRole, conn)
+
+ >>> some_editor = info.python_type("EDITOR")
+ >>> some_editor
+ <UserRole.EDITOR: 'EDITOR'>
+
+ >>> conn.execute(
+ ... "SELECT pg_typeof(%(editor)s), %(editor)s",
+ ... { "editor": some_editor }
+ ... ).fetchone()
+ ('user_role', <UserRole.EDITOR: 'EDITOR'>)
+
+ >>> conn.execute(
+ ... "SELECT (%s, %s)::user_role[]",
+ ... [UserRole.ADMIN, UserRole.GUEST]
+ ... ).fetchone()
+ [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>]
+
"""
+class EnumInfo(TypeInfo):
+ """Manage information about a enum type"""
+
+ __module__ = "psycopg.types.enum"
+
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ enum_labels: Sequence[str],
+ ):
+ super().__init__(name, oid, array_oid)
+ self.enum_labels = enum_labels
+ # Will be set by register() if the `python_type` is a type
+ self.python_type: Optional[type] = None
+
+ @classmethod
+ def _get_info_query(
+ cls, conn: "Union[Connection[Any], AsyncConnection[Any]]"
+ ) -> str:
+ return """\
+SELECT
+ t.typname AS name, t.oid AS oid, t.typarray AS array_oid,
+ array_agg(x.enumlabel) AS enum_labels
+FROM pg_type t
+LEFT JOIN (
+ SELECT e.enumtypid, e.enumlabel
+ FROM pg_enum e
+ ORDER BY e.enumsortorder
+) x ON x.enumtypid = t.oid
+WHERE t.oid = %(name)s::regtype
+GROUP BY t.typname, t.oid, t.typarray
+"""
+
+
class TypesRegistry:
"""
Container for the information about types in a database.
--- /dev/null
+"""
+Adapters for the enum type.
+"""
+from enum import Enum
+from typing import Optional, TypeVar, Generic, Type
+
+from .string import StrBinaryDumper, StrDumper
+from .. import postgres
+from .._typeinfo import EnumInfo as EnumInfo # exported here
+from ..abc import AdaptContext
+from ..adapt import Buffer, Loader
+from ..pq import Format
+
+E = TypeVar("E", bound=Enum)
+
+
+class EnumLoader(Loader, Generic[E]):
+ format = Format.TEXT
+ python_type: Type[E]
+
+ def __init__(self, oid: int, context: Optional[AdaptContext] = None):
+ super().__init__(oid, context)
+
+ def load(self, data: Buffer) -> E:
+ if isinstance(data, memoryview):
+ data = bytes(data)
+ return self.python_type(data.decode())
+
+
+class EnumBinaryLoader(EnumLoader[E]):
+ format = Format.BINARY
+
+
+class EnumDumper(StrDumper):
+ pass
+
+
+class EnumBinaryDumper(StrBinaryDumper):
+ pass
+
+
+def register_enum(
+ info: EnumInfo,
+ python_type: Optional[Type[E]] = None,
+ context: Optional[AdaptContext] = None,
+) -> None:
+ """Register the adapters to load and dump a enum type.
+
+ :param info: The object with the information about the enum to register.
+ :param python_type: Python enum type matching to the Postgres one. If `!None`,
+ the type will be generated and put into info.python_type.
+ :param context: The context where to register the adapters. If `!None`,
+ register it globally.
+
+ .. note::
+ Only string enums are supported.
+
+ Use binary format if any of your enum labels contains comma:
+ connection.execute(..., binary=True)
+ """
+
+ if not info:
+ raise TypeError("no info passed. Is the requested enum available?")
+
+ if python_type is not None:
+ if {type(item.value) for item in python_type} != {str}:
+ raise TypeError("invalid enum value type (string is the only supported)")
+
+ info.python_type = python_type
+ else:
+ info.python_type = Enum( # type: ignore
+ info.name.title(),
+ {label: label for label in info.enum_labels},
+ )
+
+ adapters = context.adapters if context else postgres.adapters
+
+ info.register(context)
+
+ base = EnumLoader
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {"python_type": info.python_type}
+ loader = type(name, (base,), attribs)
+ adapters.register_loader(info.oid, loader)
+
+ base = EnumBinaryLoader
+ name = f"{info.name.title()}{base.__name__}"
+ attribs = {"python_type": info.python_type}
+ loader = type(name, (base,), attribs)
+ adapters.register_loader(info.oid, loader)
+
+
+def register_default_adapters(context: AdaptContext) -> None:
+ context.adapters.register_dumper(Enum, EnumBinaryDumper)
+ context.adapters.register_dumper(Enum, EnumDumper)
--- /dev/null
+from enum import Enum
+
+import pytest
+from psycopg.types.enum import EnumInfo, register_enum
+
+from psycopg import pq
+
+
+class _TestEnum(str, Enum):
+ ONE = "ONE"
+ TWO = "TWO"
+ THREE = "THREE"
+
+
+@pytest.fixture(scope="session")
+def testenum(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop type if exists testenum cascade;
+
+ create type testenum as enum('ONE', 'TWO', 'THREE');
+ """
+ )
+ return EnumInfo.fetch(svcconn, "testenum")
+
+
+def test_fetch_info(conn, testenum):
+ assert testenum.name == "testenum"
+ assert testenum.oid > 0
+ assert testenum.oid != testenum.array_oid > 0
+ assert len(testenum.enum_labels) == 3
+ assert testenum.enum_labels == ["ONE", "TWO", "THREE"]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader(conn, testenum, fmt_out):
+ register_enum(testenum, _TestEnum, conn)
+
+ cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
+ assert cur.fetchone()[0] == _TestEnum.ONE
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader(conn, testenum, fmt_out):
+ register_enum(testenum, _TestEnum, conn)
+
+ cur = conn.execute(
+ "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
+ binary=fmt_out,
+ )
+ assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_loader_generated(conn, testenum, fmt_out):
+ register_enum(testenum, context=conn)
+
+ cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
+ assert cur.fetchone()[0] == testenum.python_type.ONE
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader_generated(conn, testenum, fmt_out):
+ register_enum(testenum, context=conn)
+
+ cur = conn.execute(
+ "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
+ binary=fmt_out,
+ )
+ assert cur.fetchone()[0] == [testenum.python_type.ONE, testenum.python_type.TWO]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum(conn, testenum, fmt_out):
+ register_enum(testenum, _TestEnum, conn)
+
+ cur = conn.cursor()
+ cur.execute(
+ """
+ drop table if exists testenumtable;
+ create table testenumtable (id serial primary key, value testenum);
+ """
+ )
+
+ cur.execute(
+ "insert into testenumtable (value) values (%s), (%s), (%s)",
+ (
+ _TestEnum.ONE,
+ _TestEnum.TWO,
+ _TestEnum.THREE,
+ ),
+ )
+
+ cur = conn.execute("select value from testenumtable order by id", binary=fmt_out)
+ assert cur.fetchone()[0] == _TestEnum.ONE
+ assert cur.fetchone()[0] == _TestEnum.TWO
+ assert cur.fetchone()[0] == _TestEnum.THREE
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array(conn, testenum, fmt_out):
+ register_enum(testenum, _TestEnum, conn)
+
+ cur = conn.cursor()
+ cur.execute(
+ """
+ drop table if exists testenumtable;
+ create table testenumtable (id serial primary key, values testenum[]);
+ """
+ )
+
+ cur.execute(
+ "insert into testenumtable (values) values (%s), (%s), (%s)",
+ (
+ [_TestEnum.ONE, _TestEnum.TWO],
+ [_TestEnum.TWO, _TestEnum.THREE],
+ [_TestEnum.THREE, _TestEnum.ONE],
+ ),
+ )
+
+ cur = conn.execute("select values from testenumtable order by id", binary=fmt_out)
+ assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
+ assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE]
+ assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE]