>>> conn.execute("select '::ffff:1.2.3.0/120'::cidr").fetchone()[0]
IPv6Network('::ffff:102:300/120')
+
+.. _adapt-enum:
+
+Enum adaptation
+---------------
+
+.. versionadded:: 3.1
+
+Psycopg can adapt PostgreSQL enum types (created with the
+|CREATE TYPE AS ENUM|_ command)
+
+.. |CREATE TYPE AS ENUM| replace:: :sql:`CREATE TYPE ... AS ENUM (...)`
+.. _CREATE TYPE AS ENUM: https://www.postgresql.org/docs/current/static/datatype-enum.html
+
+Before using a enum type it is necessary to get information about it
+using the `~psycopg.types.enum.EnumInfo` class and to register it
+using `~psycopg.types.enum.register_enum()`.
+
+If you use enum array and your enum labels contains comma, you should use
+the binary format to return values. See :ref:`binary-data` for details.
+
+.. autoclass:: psycopg.types.enum.EnumInfo
+
+ `!EnumInfo` is a `~psycopg.types.TypeInfo` subclass: check its
+ documentation for the generic usage, especially the
+ `~psycopg.types.TypeInfo.fetch()` method.
+
+ .. attribute:: enum_labels
+
+ Contains labels available in the PostgreSQL enum type.
+
+ .. attribute:: python_type
+
+ After `register_enum()` is called, it will contain the python type
+ mapping to the registered enum.
+
+.. autofunction:: psycopg.types.enum.register_enum
+
+ After registering, fetching data of the registered enum will cast
+ PostgreSQL enum labels into corresponding Python enum labels.
+
+ If no ``python_type`` is specified, a `Enum` is created based on
+ PostgreSQL enum labels.
+
+Example::
+
+ >>> from enum import Enum
+ >>> from psycopg.types.enum import EnumInfo, register_enum
+
+ >>> class UserRole(str, Enum):
+ ... ADMIN = "ADMIN"
+ ... EDITOR = "EDITOR"
+ ... GUEST = "GUEST"
+
+ >>> conn.execute("CREATE TYPE user_role AS ('ADMIN', 'EDITOR', 'GUEST')")
+
+ >>> info = EnumInfo.fetch(conn, "user_role")
+ >>> register_enum(info, UserRole, conn)
+
+ >>> some_editor = info.python_type("EDITOR")
+ >>> some_editor
+ <UserRole.EDITOR: 'EDITOR'>
+
+ >>> conn.execute(
+ ... "SELECT pg_typeof(%(editor)s), %(editor)s",
+ ... { "editor": some_editor }
+ ... ).fetchone()
+ ('user_role', <UserRole.EDITOR: 'EDITOR'>)
+
+ >>> conn.execute(
+ ... "SELECT (%s, %s)::user_role[]",
+ ... [UserRole.ADMIN, UserRole.GUEST]
+ ... ).fetchone()
+ [<UserRole.ADMIN: 'ADMIN'>, <UserRole.GUEST: 'GUEST'>]
Adapters for the enum type.
"""
from enum import Enum
-from typing import Optional, TypeVar, Generic, Type
+from typing import Optional, TypeVar, Generic, Type, Dict, Any
+from ..adapt import Dumper, Loader
from .string import StrBinaryDumper, StrDumper
from .. import postgres
+from .._encodings import pgconn_encoding
from .._typeinfo import EnumInfo as EnumInfo # exported here
from ..abc import AdaptContext
-from ..adapt import Buffer, Loader
+from ..adapt import Buffer
from ..pq import Format
+
E = TypeVar("E", bound=Enum)
class EnumLoader(Loader, Generic[E]):
format = Format.TEXT
+ _encoding = "utf-8"
python_type: Type[E]
def __init__(self, oid: int, context: Optional[AdaptContext] = None):
super().__init__(oid, context)
+ conn = self.connection
+ if conn:
+ self._encoding = pgconn_encoding(conn.pgconn)
def load(self, data: Buffer) -> E:
if isinstance(data, memoryview):
- data = bytes(data)
- return self.python_type(data.decode())
+ label = bytes(data).decode(self._encoding)
+ else:
+ label = data.decode(self._encoding)
+
+ return self.python_type(label)
class EnumBinaryLoader(EnumLoader[E]):
class EnumDumper(StrDumper):
- pass
+ def dump(self, obj: str) -> bytes:
+ return super().dump(obj)
class EnumBinaryDumper(StrBinaryDumper):
.. note::
Only string enums are supported.
- Use binary format if any of your enum labels contains comma:
+ Use binary format if you use enum array and enum labels contains comma:
connection.execute(..., binary=True)
"""
info.register(context)
- base = EnumLoader
- name = f"{info.name.title()}{base.__name__}"
- attribs = {"python_type": info.python_type}
- loader = type(name, (base,), attribs)
+ attribs: Dict[str, Any] = {"python_type": info.python_type}
+
+ loader_base = EnumLoader
+ name = f"{info.name.title()}{loader_base.__name__}"
+ loader = type(name, (loader_base,), attribs)
adapters.register_loader(info.oid, loader)
- base = EnumBinaryLoader
- name = f"{info.name.title()}{base.__name__}"
- attribs = {"python_type": info.python_type}
- loader = type(name, (base,), attribs)
+ loader_base = EnumBinaryLoader
+ name = f"{info.name.title()}{loader_base.__name__}"
+ loader = type(name, (loader_base,), attribs)
adapters.register_loader(info.oid, loader)
+ attribs = {"oid": info.oid}
+
+ dumper_base: Type[Dumper] = EnumBinaryDumper
+ name = f"{info.name.title()}{dumper_base.__name__}"
+ dumper = type(name, (dumper_base,), attribs)
+ adapters.register_dumper(info.python_type, dumper)
+
+ dumper_base = EnumDumper
+ name = f"{info.name.title()}{dumper_base.__name__}"
+ dumper = type(name, (dumper_base,), attribs)
+ adapters.register_dumper(info.python_type, dumper)
+
def register_default_adapters(context: AdaptContext) -> None:
context.adapters.register_dumper(Enum, EnumBinaryDumper)
import pytest
-from psycopg import pq
+from psycopg import pq, sql
from psycopg.adapt import PyFormat
from psycopg.types.enum import EnumInfo, register_enum
-class _TestEnum(str, Enum):
+class StrTestEnum(str, Enum):
ONE = "ONE"
TWO = "TWO"
THREE = "THREE"
-@pytest.fixture(scope="session")
-def testenum(svcconn):
- cur = svcconn.cursor()
- cur.execute(
- """
- drop type if exists testenum cascade;
- create type testenum as enum('ONE', 'TWO', 'THREE');
- """
- )
- return EnumInfo.fetch(svcconn, "testenum")
+class NonAscciEnum(str, Enum):
+ XE0 = "x\xe0"
+ XE1 = "x\xe1"
+
+enum_cases = [
+ ("strtestenum", StrTestEnum, [item.value for item in StrTestEnum]),
+ ("nonasccienum", NonAscciEnum, [item.value for item in NonAscciEnum]),
+]
+
+encodings = ["utf8", "latin1"]
+
+
+@pytest.fixture(scope="session", params=enum_cases)
+def testenum(request, svcconn):
+ name, enum, labels = request.param
+ quoted_labels = [sql.quote(label) for label in labels]
-@pytest.fixture(scope="session")
-def nonasciienum(svcconn):
cur = svcconn.cursor()
cur.execute(
- """
- drop type if exists nonasciienum cascade;
- create type nonasciienum as enum ('x\xe0');
+ f"""
+ drop type if exists {name} cascade;
+ create type {name} as enum({','.join(quoted_labels)});
"""
)
- return EnumInfo.fetch(svcconn, "nonasciienum")
+ return name, enum, labels
def test_fetch_info(conn, testenum):
- assert testenum.name == "testenum"
- assert testenum.oid > 0
- assert testenum.oid != testenum.array_oid > 0
- assert len(testenum.enum_labels) == 3
- assert testenum.enum_labels == ["ONE", "TWO", "THREE"]
-
+ name, enum, labels = testenum
-@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_enum_insert_generic(conn, testenum, fmt_in):
- # No regstration, test for generic enum
- conn.execute("create table test_enum_insert (id int primary key, val testenum)")
- cur = conn.cursor()
- cur.executemany(
- f"insert into test_enum_insert (id, val) values (%s, %{fmt_in})",
- list(enumerate(_TestEnum)),
- )
- cur.execute("select id, val from test_enum_insert order by id")
- recs = cur.fetchall()
- assert recs == [(0, "ONE"), (1, "TWO"), (2, "THREE")]
+ info = EnumInfo.fetch(conn, name)
+ assert info.name == name
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.enum_labels) == len(labels)
+ assert info.enum_labels == labels
+@pytest.mark.parametrize("encoding", encodings)
@pytest.mark.parametrize("fmt_in", PyFormat)
-def test_enum_dumper(conn, testenum, fmt_in):
- register_enum(testenum, _TestEnum, conn)
-
- cur = conn.execute(f"select %{fmt_in}", [_TestEnum.ONE])
- assert cur.fetchone()[0] is _TestEnum.ONE
-
-
@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_loader(conn, testenum, fmt_out):
- register_enum(testenum, _TestEnum, conn)
-
- cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
- assert cur.fetchone()[0] == _TestEnum.ONE
-
+def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
-@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_array_loader(conn, testenum, fmt_out):
- register_enum(testenum, _TestEnum, conn)
+ name, enum, labels = testenum
+ register_enum(EnumInfo.fetch(conn, name), enum, conn)
- cur = conn.execute(
- "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
- binary=fmt_out,
- )
- assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
+ for label in labels:
+ cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
+ assert cur.fetchone()[0] == enum(label)
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_loader_generated(conn, testenum, fmt_out):
- register_enum(testenum, context=conn)
+def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
+
+ name, enum, labels = testenum
+ register_enum(EnumInfo.fetch(conn, name), enum, conn)
- cur = conn.execute("select 'ONE'::testenum", binary=fmt_out)
- assert cur.fetchone()[0] == testenum.python_type.ONE
+ for item in enum:
+ cur = conn.execute(f"select %{fmt_in}", [item], binary=fmt_out)
+ assert cur.fetchone()[0] == item
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_array_loader_generated(conn, testenum, fmt_out):
- register_enum(testenum, context=conn)
+def test_generic_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
- cur = conn.execute(
- "select ARRAY['ONE'::testenum, 'TWO'::testenum]::testenum[]",
- binary=fmt_out,
- )
- assert cur.fetchone()[0] == [testenum.python_type.ONE, testenum.python_type.TWO]
+ name, enum, labels = testenum
+ info = EnumInfo.fetch(conn, name)
+ register_enum(info, None, conn)
+ for label in labels:
+ cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
+ assert cur.fetchone()[0] == info.python_type(label)
-@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum(conn, testenum, fmt_out):
- register_enum(testenum, _TestEnum, conn)
- cur = conn.cursor()
- cur.execute(
- """
- drop table if exists testenumtable;
- create table testenumtable (id serial primary key, value testenum);
- """
- )
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
- cur.execute(
- "insert into testenumtable (value) values (%s), (%s), (%s)",
- (
- _TestEnum.ONE,
- _TestEnum.TWO,
- _TestEnum.THREE,
- ),
- )
+ name, enum, labels = testenum
+ register_enum(EnumInfo.fetch(conn, name), enum, conn)
- cur = conn.execute("select value from testenumtable order by id", binary=fmt_out)
- assert cur.fetchone()[0] == _TestEnum.ONE
- assert cur.fetchone()[0] == _TestEnum.TWO
- assert cur.fetchone()[0] == _TestEnum.THREE
+ cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out)
+ assert cur.fetchone()[0] == list(enum)
+@pytest.mark.parametrize("encoding", encodings)
+@pytest.mark.parametrize("fmt_in", PyFormat)
@pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_array(conn, testenum, fmt_out):
- register_enum(testenum, _TestEnum, conn)
-
- cur = conn.cursor()
- cur.execute(
- """
- drop table if exists testenumtable;
- create table testenumtable (id serial primary key, values testenum[]);
- """
- )
+def test_enum_array_dumper(conn, testenum, encoding, fmt_in, fmt_out):
+ conn.execute(f"set client_encoding to {encoding}")
- cur.execute(
- "insert into testenumtable (values) values (%s), (%s), (%s)",
- (
- [_TestEnum.ONE, _TestEnum.TWO],
- [_TestEnum.TWO, _TestEnum.THREE],
- [_TestEnum.THREE, _TestEnum.ONE],
- ),
- )
+ name, enum, labels = testenum
+ register_enum(EnumInfo.fetch(conn, name), enum, conn)
- cur = conn.execute("select values from testenumtable order by id", binary=fmt_out)
- assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
- assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE]
- assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE]
+ cur = conn.execute(f"select %{fmt_in}", [list(enum)], binary=fmt_out)
+ assert cur.fetchone()[0] == list(enum)
-@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("encoding", encodings)
@pytest.mark.parametrize("fmt_in", PyFormat)
-@pytest.mark.parametrize("encoding", ["utf8", "latin1"])
-def test_non_ascii_enum(conn, nonasciienum, fmt_out, fmt_in, encoding):
+@pytest.mark.parametrize("fmt_out", pq.Format)
+def test_generic_enum_array_loader(conn, testenum, encoding, fmt_in, fmt_out):
conn.execute(f"set client_encoding to {encoding}")
- info = EnumInfo.fetch(conn, "nonasciienum")
- register_enum(info, context=conn)
- assert [x.name for x in info.python_type] == ["x\xe0"]
- val = list(info.python_type)[0]
- cur = conn.execute("select 'x\xe0'::nonasciienum", binary=fmt_out)
- assert cur.fetchone()[0] is val
+ name, enum, labels = testenum
+ info = EnumInfo.fetch(conn, name)
+ register_enum(info, enum, conn)
- cur = conn.execute(f"select %{fmt_in}", [val])
- assert cur.fetchone()[0] is val
+ cur = conn.execute(f"select %{fmt_in}::{name}[]", [labels], binary=fmt_out)
+ assert cur.fetchone()[0] == list(info.python_type)