]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
test(enum): refactor in order to write enum-specific type
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Apr 2022 08:21:18 +0000 (10:21 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 22 Apr 2022 03:03:24 +0000 (05:03 +0200)
They all get created, tests can choose to use only some of them.

tests/types/test_enum.py

index 90cc00d9f4de4335a8274df0b84e16df110d4589..182fd8ace10b9488c7048855dd962358021de9e6 100644 (file)
@@ -32,6 +32,7 @@ class IntTestEnum(int, Enum):
 
 
 enum_cases = [PureTestEnum, StrTestEnum, NonAsciiEnum, IntTestEnum]
+ascii_cases = [PureTestEnum, StrTestEnum, IntTestEnum]
 
 encodings = ["utf8", "latin1"]
 
@@ -39,10 +40,19 @@ encodings = ["utf8", "latin1"]
 @pytest.fixture(scope="session", params=enum_cases)
 def testenum(request, svcconn):
     enum = request.param
+    return ensure_enum(enum, svcconn)
+
+
+@pytest.fixture(scope="session", autouse=True)
+def make_test_enums(request, svcconn):
+    for enum in enum_cases:
+        ensure_enum(enum, svcconn)
+
+
+def ensure_enum(enum, conn):
     name = enum.__name__.lower()
     labels = list(enum.__members__.keys())
-    cur = svcconn.cursor()
-    cur.execute(
+    conn.execute(
         sql.SQL(
             """
             drop type if exists {name} cascade;
@@ -90,16 +100,14 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out):
 
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_loader_sqlascii(conn, testenum, fmt_in, fmt_out):
-    name, enum, labels = testenum
-    if name == "nonasciienum":
-        pytest.skip("ascii-only test")
-
-    register_enum(EnumInfo.fetch(conn, name), enum, conn)
+@pytest.mark.parametrize("enum", ascii_cases)
+def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out):
+    info = EnumInfo.fetch(conn, enum.__name__.lower())
+    register_enum(info, enum, conn)
     conn.execute("set client_encoding to sql_ascii")
 
-    for label in labels:
-        cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out)
+    for label in info.labels:
+        cur = conn.execute(f"select %{fmt_in}::{info.name}", [label], binary=fmt_out)
         assert cur.fetchone()[0] == enum[label]
 
 
@@ -119,12 +127,10 @@ def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out):
 
 @pytest.mark.parametrize("fmt_in", PyFormat)
 @pytest.mark.parametrize("fmt_out", pq.Format)
-def test_enum_dumper_sqlascii(conn, testenum, fmt_in, fmt_out):
-    name, enum, labels = testenum
-    if name == "nonasciienum":
-        pytest.skip("ascii-only test")
-
-    register_enum(EnumInfo.fetch(conn, name), enum, conn)
+@pytest.mark.parametrize("enum", ascii_cases)
+def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out):
+    info = EnumInfo.fetch(conn, enum.__name__.lower())
+    register_enum(info, enum, conn)
     conn.execute("set client_encoding to sql_ascii")
 
     for item in enum: