From 18949c126499d3a083e46527bd8cadc55630619b Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 21 Apr 2022 10:21:18 +0200 Subject: [PATCH] test(enum): refactor in order to write enum-specific type They all get created, tests can choose to use only some of them. --- tests/types/test_enum.py | 38 ++++++++++++++++++++++---------------- 1 file changed, 22 insertions(+), 16 deletions(-) diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 90cc00d9f..182fd8ace 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -32,6 +32,7 @@ class IntTestEnum(int, Enum): enum_cases = [PureTestEnum, StrTestEnum, NonAsciiEnum, IntTestEnum] +ascii_cases = [PureTestEnum, StrTestEnum, IntTestEnum] encodings = ["utf8", "latin1"] @@ -39,10 +40,19 @@ encodings = ["utf8", "latin1"] @pytest.fixture(scope="session", params=enum_cases) def testenum(request, svcconn): enum = request.param + return ensure_enum(enum, svcconn) + + +@pytest.fixture(scope="session", autouse=True) +def make_test_enums(request, svcconn): + for enum in enum_cases: + ensure_enum(enum, svcconn) + + +def ensure_enum(enum, conn): name = enum.__name__.lower() labels = list(enum.__members__.keys()) - cur = svcconn.cursor() - cur.execute( + conn.execute( sql.SQL( """ drop type if exists {name} cascade; @@ -90,16 +100,14 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out): @pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_loader_sqlascii(conn, testenum, fmt_in, fmt_out): - name, enum, labels = testenum - if name == "nonasciienum": - pytest.skip("ascii-only test") - - register_enum(EnumInfo.fetch(conn, name), enum, conn) +@pytest.mark.parametrize("enum", ascii_cases) +def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__.lower()) + register_enum(info, enum, conn) conn.execute("set client_encoding to sql_ascii") - for label in labels: - cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out) + for label in info.labels: + cur = conn.execute(f"select %{fmt_in}::{info.name}", [label], binary=fmt_out) assert cur.fetchone()[0] == enum[label] @@ -119,12 +127,10 @@ def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out): @pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_dumper_sqlascii(conn, testenum, fmt_in, fmt_out): - name, enum, labels = testenum - if name == "nonasciienum": - pytest.skip("ascii-only test") - - register_enum(EnumInfo.fetch(conn, name), enum, conn) +@pytest.mark.parametrize("enum", ascii_cases) +def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__.lower()) + register_enum(info, enum, conn) conn.execute("set client_encoding to sql_ascii") for item in enum: -- 2.47.2