from enum import Enum
import pytest
-from psycopg.types.enum import EnumInfo, register_enum
from psycopg import pq
+from psycopg.adapt import PyFormat
+from psycopg.types.enum import EnumInfo, register_enum
class _TestEnum(str, Enum):
cur.execute(
"""
drop type if exists testenum cascade;
-
create type testenum as enum('ONE', 'TWO', 'THREE');
"""
)
return EnumInfo.fetch(svcconn, "testenum")
+@pytest.fixture(scope="session")
+def nonasciienum(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop type if exists nonasciienum cascade;
+ create type nonasciienum as enum ('x\xe0');
+ """
+ )
+ return EnumInfo.fetch(svcconn, "nonasciienum")
+
+
def test_fetch_info(conn, testenum):
assert testenum.name == "testenum"
assert testenum.oid > 0
assert testenum.enum_labels == ["ONE", "TWO", "THREE"]
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_enum_insert_generic(conn, testenum, fmt_in):
+ # No regstration, test for generic enum
+ conn.execute("create table test_enum_insert (id int primary key, val testenum)")
+ cur = conn.cursor()
+ cur.executemany(
+ f"insert into test_enum_insert (id, val) values (%s, %{fmt_in})",
+ list(enumerate(_TestEnum)),
+ )
+ cur.execute("select id, val from test_enum_insert order by id")
+ recs = cur.fetchall()
+ assert recs == [(0, "ONE"), (1, "TWO"), (2, "THREE")]
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_enum_dumper(conn, testenum, fmt_in):
+ register_enum(testenum, _TestEnum, conn)
+
+ cur = conn.execute(f"select %{fmt_in}", [_TestEnum.ONE])
+ assert cur.fetchone()[0] is _TestEnum.ONE
+
+
@pytest.mark.parametrize("fmt_out", pq.Format)
def test_enum_loader(conn, testenum, fmt_out):
register_enum(testenum, _TestEnum, conn)
assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE]
assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("encoding", ["utf8", "latin1"])
+def test_non_ascii_enum(conn, nonasciienum, fmt_out, fmt_in, encoding):
+ conn.execute(f"set client_encoding to {encoding}")
+ info = EnumInfo.fetch(conn, "nonasciienum")
+ register_enum(info, context=conn)
+ assert [x.name for x in info.python_type] == ["x\xe0"]
+ val = list(info.python_type)[0]
+
+ cur = conn.execute("select 'x\xe0'::nonasciienum", binary=fmt_out)
+ assert cur.fetchone()[0] is val
+
+ cur = conn.execute(f"select %{fmt_in}", [val])
+ assert cur.fetchone()[0] is val