From: Daniele Varrazzo Date: Fri, 16 Jul 2021 01:15:45 +0000 (+0200) Subject: Add fixture to restore the state of the global adapters X-Git-Tag: 3.0.dev1~6 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=4ce5e66149c26e7be482b0073cb64a3fe4d1f77f;p=thirdparty%2Fpsycopg.git Add fixture to restore the state of the global adapters --- diff --git a/psycopg/psycopg/_typeinfo.py b/psycopg/psycopg/_typeinfo.py index 2a38723cd..f9d2d5b76 100644 --- a/psycopg/psycopg/_typeinfo.py +++ b/psycopg/psycopg/_typeinfo.py @@ -242,10 +242,13 @@ class TypesRegistry: self._by_range_subtype = template._by_range_subtype self._own_state = False else: - self._by_oid = {} - self._by_name = {} - self._by_range_subtype = {} - self._own_state = True + self.clear() + + def clear(self) -> None: + self._by_oid = {} + self._by_name = {} + self._by_range_subtype = {} + self._own_state = True def add(self, info: TypeInfo) -> None: self._ensure_own_state() diff --git a/tests/conftest.py b/tests/conftest.py index 8f3c36e2a..ff2be7986 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,6 +9,7 @@ pytest_plugins = ( "tests.fix_pq", "tests.fix_proxy", "tests.fix_faker", + "tests.fix_psycopg", ) diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py new file mode 100644 index 000000000..74b82b3cb --- /dev/null +++ b/tests/fix_psycopg.py @@ -0,0 +1,21 @@ +from copy import deepcopy + +import pytest + + +@pytest.fixture +def global_adapters(): + """Restore the global adapters after a test has changed them.""" + from psycopg import adapters + + dumpers = deepcopy(adapters._dumpers) + loaders = deepcopy(adapters._loaders) + types = list(adapters.types) + + yield None + + adapters._dumpers = dumpers + adapters._loaders = loaders + adapters.types.clear() + for t in types: + adapters.types.add(t) diff --git a/tests/test_adapt.py b/tests/test_adapt.py index 06841ca62..d0ab11c0f 100644 --- a/tests/test_adapt.py +++ b/tests/test_adapt.py @@ -60,20 +60,16 @@ def test_register_dumper_by_class_name(conn): assert conn.adapters.get_dumper(MyStr, Format.TEXT) is dumper -def test_dump_global_ctx(dsn): - try: - psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb")) - psycopg.adapters.register_dumper(MyStr, make_dumper("gt")) - conn = psycopg.connect(dsn) - cur = conn.execute("select %s", [MyStr("hello")]) - assert cur.fetchone() == ("hellogt",) - cur = conn.execute("select %b", [MyStr("hello")]) - assert cur.fetchone() == ("hellogb",) - cur = conn.execute("select %t", [MyStr("hello")]) - assert cur.fetchone() == ("hellogt",) - finally: - for fmt in Format: - psycopg.adapters._dumpers[fmt].pop(MyStr, None) +def test_dump_global_ctx(dsn, global_adapters): + psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb")) + psycopg.adapters.register_dumper(MyStr, make_dumper("gt")) + conn = psycopg.connect(dsn) + cur = conn.execute("select %s", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) + cur = conn.execute("select %b", [MyStr("hello")]) + assert cur.fetchone() == ("hellogb",) + cur = conn.execute("select %t", [MyStr("hello")]) + assert cur.fetchone() == ("hellogt",) def test_dump_connection_ctx(conn): @@ -203,20 +199,14 @@ def test_register_loader_by_type_name(conn): assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader -def test_load_global_ctx(dsn): - from psycopg.types import string - - try: - psycopg.adapters.register_loader("text", make_loader("gt")) - psycopg.adapters.register_loader("text", make_bin_loader("gb")) - conn = psycopg.connect(dsn) - cur = conn.cursor(binary=False).execute("select 'hello'::text") - assert cur.fetchone() == ("hellogt",) - cur = conn.cursor(binary=True).execute("select 'hello'::text") - assert cur.fetchone() == ("hellogb",) - finally: - psycopg.adapters.register_loader("text", string.TextLoader) - psycopg.adapters.register_loader("text", string.TextBinaryLoader) +def test_load_global_ctx(dsn, global_adapters): + psycopg.adapters.register_loader("text", make_loader("gt")) + psycopg.adapters.register_loader("text", make_bin_loader("gb")) + conn = psycopg.connect(dsn) + cur = conn.cursor(binary=False).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogt",) + cur = conn.cursor(binary=True).execute("select 'hello'::text") + assert cur.fetchone() == ("hellogb",) def test_load_connection_ctx(conn):