From: Daniele Varrazzo Date: Thu, 14 Apr 2022 01:08:59 +0000 (+0200) Subject: test(enum): add failing tests to illustrate enum needed improvements X-Git-Tag: 3.1~137^2~20 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e853524f06ad10129b7172bb76c09eefcb5aedbe;p=thirdparty%2Fpsycopg.git test(enum): add failing tests to illustrate enum needed improvements --- diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 0f0023ebc..535dd1473 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -1,9 +1,10 @@ from enum import Enum import pytest -from psycopg.types.enum import EnumInfo, register_enum from psycopg import pq +from psycopg.adapt import PyFormat +from psycopg.types.enum import EnumInfo, register_enum class _TestEnum(str, Enum): @@ -18,13 +19,24 @@ def testenum(svcconn): cur.execute( """ drop type if exists testenum cascade; - create type testenum as enum('ONE', 'TWO', 'THREE'); """ ) return EnumInfo.fetch(svcconn, "testenum") +@pytest.fixture(scope="session") +def nonasciienum(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop type if exists nonasciienum cascade; + create type nonasciienum as enum ('x\xe0'); + """ + ) + return EnumInfo.fetch(svcconn, "nonasciienum") + + def test_fetch_info(conn, testenum): assert testenum.name == "testenum" assert testenum.oid > 0 @@ -33,6 +45,28 @@ def test_fetch_info(conn, testenum): assert testenum.enum_labels == ["ONE", "TWO", "THREE"] +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_enum_insert_generic(conn, testenum, fmt_in): + # No regstration, test for generic enum + conn.execute("create table test_enum_insert (id int primary key, val testenum)") + cur = conn.cursor() + cur.executemany( + f"insert into test_enum_insert (id, val) values (%s, %{fmt_in})", + list(enumerate(_TestEnum)), + ) + cur.execute("select id, val from test_enum_insert order by id") + recs = cur.fetchall() + assert recs == [(0, "ONE"), (1, "TWO"), (2, "THREE")] + + +@pytest.mark.parametrize("fmt_in", PyFormat) +def test_enum_dumper(conn, testenum, fmt_in): + register_enum(testenum, _TestEnum, conn) + + cur = conn.execute(f"select %{fmt_in}", [_TestEnum.ONE]) + assert cur.fetchone()[0] is _TestEnum.ONE + + @pytest.mark.parametrize("fmt_out", pq.Format) def test_enum_loader(conn, testenum, fmt_out): register_enum(testenum, _TestEnum, conn) @@ -123,3 +157,20 @@ def test_enum_array(conn, testenum, fmt_out): assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO] assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE] assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE] + + +@pytest.mark.parametrize("fmt_out", pq.Format) +@pytest.mark.parametrize("fmt_in", PyFormat) +@pytest.mark.parametrize("encoding", ["utf8", "latin1"]) +def test_non_ascii_enum(conn, nonasciienum, fmt_out, fmt_in, encoding): + conn.execute(f"set client_encoding to {encoding}") + info = EnumInfo.fetch(conn, "nonasciienum") + register_enum(info, context=conn) + assert [x.name for x in info.python_type] == ["x\xe0"] + val = list(info.python_type)[0] + + cur = conn.execute("select 'x\xe0'::nonasciienum", binary=fmt_out) + assert cur.fetchone()[0] is val + + cur = conn.execute(f"select %{fmt_in}", [val]) + assert cur.fetchone()[0] is val