]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(enum): add failing tests to illustrate enum needed improvements
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 14 Apr 2022 01:08:59 +0000 (03:08 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 14:05:46 +0000 (16:05 +0200)
tests/types/test_enum.py

index 0f0023ebcf12be807c511c1470223eebc03c6f24..535dd1473181647b0490609c40f3834965db952c 100644 (file)
@@ -1,9 +1,10 @@
 from enum import Enum
 
 import pytest
-from psycopg.types.enum import EnumInfo, register_enum
 
 from psycopg import pq
+from psycopg.adapt import PyFormat
+from psycopg.types.enum import EnumInfo, register_enum
 
 
 class _TestEnum(str, Enum):
@@ -18,13 +19,24 @@ def testenum(svcconn):
     cur.execute(
         """
         drop type if exists testenum cascade;
-
         create type testenum as enum('ONE', 'TWO', 'THREE');
         """
     )
     return EnumInfo.fetch(svcconn, "testenum")
 
 
+@pytest.fixture(scope="session")
+def nonasciienum(svcconn):
+    cur = svcconn.cursor()
+    cur.execute(
+        """
+        drop type if exists nonasciienum cascade;
+        create type nonasciienum as enum ('x\xe0');
+        """
+    )
+    return EnumInfo.fetch(svcconn, "nonasciienum")
+
+
 def test_fetch_info(conn, testenum):
     assert testenum.name == "testenum"
     assert testenum.oid > 0
@@ -33,6 +45,28 @@ def test_fetch_info(conn, testenum):
     assert testenum.enum_labels == ["ONE", "TWO", "THREE"]
 
 
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_enum_insert_generic(conn, testenum, fmt_in):
+    # No regstration, test for generic enum
+    conn.execute("create table test_enum_insert (id int primary key, val testenum)")
+    cur = conn.cursor()
+    cur.executemany(
+        f"insert into test_enum_insert (id, val) values (%s, %{fmt_in})",
+        list(enumerate(_TestEnum)),
+    )
+    cur.execute("select id, val from test_enum_insert order by id")
+    recs = cur.fetchall()
+    assert recs == [(0, "ONE"), (1, "TWO"), (2, "THREE")]
+
+
+@pytest.mark.parametrize("fmt_in", PyFormat)
+def test_enum_dumper(conn, testenum, fmt_in):
+    register_enum(testenum, _TestEnum, conn)
+
+    cur = conn.execute(f"select %{fmt_in}", [_TestEnum.ONE])
+    assert cur.fetchone()[0] is _TestEnum.ONE
+
+
 @pytest.mark.parametrize("fmt_out", pq.Format)
 def test_enum_loader(conn, testenum, fmt_out):
     register_enum(testenum, _TestEnum, conn)
@@ -123,3 +157,20 @@ def test_enum_array(conn, testenum, fmt_out):
     assert cur.fetchone()[0] == [_TestEnum.ONE, _TestEnum.TWO]
     assert cur.fetchone()[0] == [_TestEnum.TWO, _TestEnum.THREE]
     assert cur.fetchone()[0] == [_TestEnum.THREE, _TestEnum.ONE]
+
+
+@pytest.mark.parametrize("fmt_out", pq.Format)
+@pytest.mark.parametrize("fmt_in", PyFormat)
+@pytest.mark.parametrize("encoding", ["utf8", "latin1"])
+def test_non_ascii_enum(conn, nonasciienum, fmt_out, fmt_in, encoding):
+    conn.execute(f"set client_encoding to {encoding}")
+    info = EnumInfo.fetch(conn, "nonasciienum")
+    register_enum(info, context=conn)
+    assert [x.name for x in info.python_type] == ["x\xe0"]
+    val = list(info.python_type)[0]
+
+    cur = conn.execute("select 'x\xe0'::nonasciienum", binary=fmt_out)
+    assert cur.fetchone()[0] is val
+
+    cur = conn.execute(f"select %{fmt_in}", [val])
+    assert cur.fetchone()[0] is val