]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add fixture to restore the state of the global adapters
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Jul 2021 01:15:45 +0000 (03:15 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Jul 2021 23:39:51 +0000 (01:39 +0200)
psycopg/psycopg/_typeinfo.py
tests/conftest.py
tests/fix_psycopg.py [new file with mode: 0644]
tests/test_adapt.py

index 2a38723cdd463d32ed4a8693c63cc99ef77f5e40..f9d2d5b763e1d5141d98980f7ceb337334d3ed77 100644 (file)
@@ -242,10 +242,13 @@ class TypesRegistry:
             self._by_range_subtype = template._by_range_subtype
             self._own_state = False
         else:
-            self._by_oid = {}
-            self._by_name = {}
-            self._by_range_subtype = {}
-            self._own_state = True
+            self.clear()
+
+    def clear(self) -> None:
+        self._by_oid = {}
+        self._by_name = {}
+        self._by_range_subtype = {}
+        self._own_state = True
 
     def add(self, info: TypeInfo) -> None:
         self._ensure_own_state()
index 8f3c36e2a55d74a65dabd4aab2dc3453e7f2402c..ff2be79863aafe2ef54c36f9af3cced54c73d984 100644 (file)
@@ -9,6 +9,7 @@ pytest_plugins = (
     "tests.fix_pq",
     "tests.fix_proxy",
     "tests.fix_faker",
+    "tests.fix_psycopg",
 )
 
 
diff --git a/tests/fix_psycopg.py b/tests/fix_psycopg.py
new file mode 100644 (file)
index 0000000..74b82b3
--- /dev/null
@@ -0,0 +1,21 @@
+from copy import deepcopy
+
+import pytest
+
+
+@pytest.fixture
+def global_adapters():
+    """Restore the global adapters after a test has changed them."""
+    from psycopg import adapters
+
+    dumpers = deepcopy(adapters._dumpers)
+    loaders = deepcopy(adapters._loaders)
+    types = list(adapters.types)
+
+    yield None
+
+    adapters._dumpers = dumpers
+    adapters._loaders = loaders
+    adapters.types.clear()
+    for t in types:
+        adapters.types.add(t)
index 06841ca62b5709e79e1d1d99b7419bffc5100178..d0ab11c0ffb22c7385825ac872159c814155759b 100644 (file)
@@ -60,20 +60,16 @@ def test_register_dumper_by_class_name(conn):
     assert conn.adapters.get_dumper(MyStr, Format.TEXT) is dumper
 
 
-def test_dump_global_ctx(dsn):
-    try:
-        psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
-        psycopg.adapters.register_dumper(MyStr, make_dumper("gt"))
-        conn = psycopg.connect(dsn)
-        cur = conn.execute("select %s", [MyStr("hello")])
-        assert cur.fetchone() == ("hellogt",)
-        cur = conn.execute("select %b", [MyStr("hello")])
-        assert cur.fetchone() == ("hellogb",)
-        cur = conn.execute("select %t", [MyStr("hello")])
-        assert cur.fetchone() == ("hellogt",)
-    finally:
-        for fmt in Format:
-            psycopg.adapters._dumpers[fmt].pop(MyStr, None)
+def test_dump_global_ctx(dsn, global_adapters):
+    psycopg.adapters.register_dumper(MyStr, make_bin_dumper("gb"))
+    psycopg.adapters.register_dumper(MyStr, make_dumper("gt"))
+    conn = psycopg.connect(dsn)
+    cur = conn.execute("select %s", [MyStr("hello")])
+    assert cur.fetchone() == ("hellogt",)
+    cur = conn.execute("select %b", [MyStr("hello")])
+    assert cur.fetchone() == ("hellogb",)
+    cur = conn.execute("select %t", [MyStr("hello")])
+    assert cur.fetchone() == ("hellogt",)
 
 
 def test_dump_connection_ctx(conn):
@@ -203,20 +199,14 @@ def test_register_loader_by_type_name(conn):
     assert conn.adapters.get_loader(TEXT_OID, pq.Format.TEXT) is loader
 
 
-def test_load_global_ctx(dsn):
-    from psycopg.types import string
-
-    try:
-        psycopg.adapters.register_loader("text", make_loader("gt"))
-        psycopg.adapters.register_loader("text", make_bin_loader("gb"))
-        conn = psycopg.connect(dsn)
-        cur = conn.cursor(binary=False).execute("select 'hello'::text")
-        assert cur.fetchone() == ("hellogt",)
-        cur = conn.cursor(binary=True).execute("select 'hello'::text")
-        assert cur.fetchone() == ("hellogb",)
-    finally:
-        psycopg.adapters.register_loader("text", string.TextLoader)
-        psycopg.adapters.register_loader("text", string.TextBinaryLoader)
+def test_load_global_ctx(dsn, global_adapters):
+    psycopg.adapters.register_loader("text", make_loader("gt"))
+    psycopg.adapters.register_loader("text", make_bin_loader("gb"))
+    conn = psycopg.connect(dsn)
+    cur = conn.cursor(binary=False).execute("select 'hello'::text")
+    assert cur.fetchone() == ("hellogt",)
+    cur = conn.cursor(binary=True).execute("select 'hello'::text")
+    assert cur.fetchone() == ("hellogb",)
 
 
 def test_load_connection_ctx(conn):