]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
tests: fix allow to skip running the pool tests again
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 11 Apr 2024 23:15:34 +0000 (01:15 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 11 Apr 2024 23:15:34 +0000 (01:15 +0200)
`pytest -m not pool` should allow to skip the pool test. However,
because of the attribute access at import time to define the test
marker, import failed as well.

Convert the marker to strings which will be used in a getattr by the
fixture. Extend async-to-sync to convert the pool classes names from the
string too.

tests/pool/test_pool_common.py
tests/pool/test_pool_common_async.py
tools/async_to_sync.py

index 0611d12d7f5f91a5f79b9c2ca64d9ae71caa1408..2f5baf4e5617fa388794d98d1909ea95e5f3403e 100644 (file)
@@ -19,9 +19,9 @@ except ImportError:
     pass
 
 
-@pytest.fixture(params=[pool.ConnectionPool, pool.NullConnectionPool])
+@pytest.fixture(params=["ConnectionPool", "NullConnectionPool"])
 def pool_cls(request):
-    return request.param
+    return getattr(pool, request.param)
 
 
 def test_defaults(pool_cls, dsn):
index 593510cf08f2c4b110be16f4ad329156af7a71cf..e566fc85dc2aeb081790b32e089b515b6aaf43ee 100644 (file)
@@ -19,9 +19,9 @@ if True:  # ASYNC
     pytestmark = [pytest.mark.anyio]
 
 
-@pytest.fixture(params=[pool.AsyncConnectionPool, pool.AsyncNullConnectionPool])
+@pytest.fixture(params=["AsyncConnectionPool", "AsyncNullConnectionPool"])
 def pool_cls(request):
-    return request.param
+    return getattr(pool, request.param)
 
 
 async def test_defaults(pool_cls, dsn):
index 2a6f8ab77c7d4afbbcb2e0ffbdde52f5be8bfbea..4c1ed2d8aa1e51b40f5fc563a25f75d295e08823 100755 (executable)
@@ -373,6 +373,8 @@ class RenameAsyncToSync(ast.NodeTransformer):
 
     def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
         self._fix_docstring(node.body)
+        if node.decorator_list:
+            self._fix_decorator(node.decorator_list)
         self.generic_visit(node)
         return node
 
@@ -386,6 +388,17 @@ class RenameAsyncToSync(ast.NodeTransformer):
             body[0].value.value = body[0].value.value.replace("Async", "")
             body[0].value.value = body[0].value.value.replace("(async", "(sync")
 
+    def _fix_decorator(self, decorator_list: list[ast.AST]) -> None:
+        for dec in decorator_list:
+            match dec:
+                case ast.Call(
+                    func=ast.Attribute(value=ast.Name(id="pytest"), attr="fixture"),
+                    keywords=[ast.keyword(arg="params", value=ast.List())],
+                ):
+                    elts = dec.keywords[0].value.elts
+                    for i, elt in enumerate(elts):
+                        elts[i] = self._visit_str(elt)
+
     def visit_Call(self, node: ast.Call) -> ast.AST:
         if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
             node = self._visit_Call_TypeVar(node)
@@ -405,6 +418,13 @@ class RenameAsyncToSync(ast.NodeTransformer):
 
         return node
 
+    def _visit_str(self, node: ast.AST) -> ast.AST:
+        match node:
+            case ast.Constant(value=str()):
+                node.value = self._visit_type_string(node.value)
+
+        return node
+
     def _visit_type_string(self, source: str) -> str:
         # Convert the string to tree, visit, and convert it back to string
         tree = ast.parse(source, type_comments=False)