From: Daniele Varrazzo Date: Thu, 21 Apr 2022 08:21:18 +0000 (+0200) Subject: test(enum): refactor in order to write enum-specific type X-Git-Tag: 3.1~137^2~7 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=18949c126499d3a083e46527bd8cadc55630619b;p=thirdparty%2Fpsycopg.git test(enum): refactor in order to write enum-specific type They all get created, tests can choose to use only some of them. --- diff --git a/tests/types/test_enum.py b/tests/types/test_enum.py index 90cc00d9f..182fd8ace 100644 --- a/tests/types/test_enum.py +++ b/tests/types/test_enum.py @@ -32,6 +32,7 @@ class IntTestEnum(int, Enum): enum_cases = [PureTestEnum, StrTestEnum, NonAsciiEnum, IntTestEnum] +ascii_cases = [PureTestEnum, StrTestEnum, IntTestEnum] encodings = ["utf8", "latin1"] @@ -39,10 +40,19 @@ encodings = ["utf8", "latin1"] @pytest.fixture(scope="session", params=enum_cases) def testenum(request, svcconn): enum = request.param + return ensure_enum(enum, svcconn) + + +@pytest.fixture(scope="session", autouse=True) +def make_test_enums(request, svcconn): + for enum in enum_cases: + ensure_enum(enum, svcconn) + + +def ensure_enum(enum, conn): name = enum.__name__.lower() labels = list(enum.__members__.keys()) - cur = svcconn.cursor() - cur.execute( + conn.execute( sql.SQL( """ drop type if exists {name} cascade; @@ -90,16 +100,14 @@ def test_enum_loader(conn, testenum, encoding, fmt_in, fmt_out): @pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_loader_sqlascii(conn, testenum, fmt_in, fmt_out): - name, enum, labels = testenum - if name == "nonasciienum": - pytest.skip("ascii-only test") - - register_enum(EnumInfo.fetch(conn, name), enum, conn) +@pytest.mark.parametrize("enum", ascii_cases) +def test_enum_loader_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__.lower()) + register_enum(info, enum, conn) conn.execute("set client_encoding to sql_ascii") - for label in labels: - cur = conn.execute(f"select %{fmt_in}::{name}", [label], binary=fmt_out) + for label in info.labels: + cur = conn.execute(f"select %{fmt_in}::{info.name}", [label], binary=fmt_out) assert cur.fetchone()[0] == enum[label] @@ -119,12 +127,10 @@ def test_enum_dumper(conn, testenum, encoding, fmt_in, fmt_out): @pytest.mark.parametrize("fmt_in", PyFormat) @pytest.mark.parametrize("fmt_out", pq.Format) -def test_enum_dumper_sqlascii(conn, testenum, fmt_in, fmt_out): - name, enum, labels = testenum - if name == "nonasciienum": - pytest.skip("ascii-only test") - - register_enum(EnumInfo.fetch(conn, name), enum, conn) +@pytest.mark.parametrize("enum", ascii_cases) +def test_enum_dumper_sqlascii(conn, enum, fmt_in, fmt_out): + info = EnumInfo.fetch(conn, enum.__name__.lower()) + register_enum(info, enum, conn) conn.execute("set client_encoding to sql_ascii") for item in enum: