pass
-@pytest.fixture(params=[pool.ConnectionPool, pool.NullConnectionPool])
+@pytest.fixture(params=["ConnectionPool", "NullConnectionPool"])
def pool_cls(request):
- return request.param
+ return getattr(pool, request.param)
def test_defaults(pool_cls, dsn):
pytestmark = [pytest.mark.anyio]
-@pytest.fixture(params=[pool.AsyncConnectionPool, pool.AsyncNullConnectionPool])
+@pytest.fixture(params=["AsyncConnectionPool", "AsyncNullConnectionPool"])
def pool_cls(request):
- return request.param
+ return getattr(pool, request.param)
async def test_defaults(pool_cls, dsn):
def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
self._fix_docstring(node.body)
+ if node.decorator_list:
+ self._fix_decorator(node.decorator_list)
self.generic_visit(node)
return node
body[0].value.value = body[0].value.value.replace("Async", "")
body[0].value.value = body[0].value.value.replace("(async", "(sync")
+ def _fix_decorator(self, decorator_list: list[ast.AST]) -> None:
+ for dec in decorator_list:
+ match dec:
+ case ast.Call(
+ func=ast.Attribute(value=ast.Name(id="pytest"), attr="fixture"),
+ keywords=[ast.keyword(arg="params", value=ast.List())],
+ ):
+ elts = dec.keywords[0].value.elts
+ for i, elt in enumerate(elts):
+ elts[i] = self._visit_str(elt)
+
def visit_Call(self, node: ast.Call) -> ast.AST:
if isinstance(node.func, ast.Name) and node.func.id == "TypeVar":
node = self._visit_Call_TypeVar(node)
return node
+ def _visit_str(self, node: ast.AST) -> ast.AST:
+ match node:
+ case ast.Constant(value=str()):
+ node.value = self._visit_type_string(node.value)
+
+ return node
+
def _visit_type_string(self, source: str) -> str:
# Convert the string to tree, visit, and convert it back to string
tree = ast.parse(source, type_comments=False)