]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor(pool): add function wrapper to call either async/sync version
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 11 May 2025 19:09:06 +0000 (21:09 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 11 May 2025 23:49:48 +0000 (01:49 +0200)
psycopg_pool/psycopg_pool/_acompat.py
psycopg_pool/psycopg_pool/pool_async.py
tools/async_to_sync.py

index f5f3a022cb1d0adba482b3493669a24b25beb719..136c948dd7ff2bf6dc431694f70a3a8c46b56a12 100644 (file)
@@ -15,13 +15,15 @@ import queue
 import asyncio
 import logging
 import threading
-from typing import Any, TypeAlias
+from typing import Any, ParamSpec, TypeAlias, overload
+from inspect import isawaitable
 from collections.abc import Callable, Coroutine
 
 from ._compat import TypeVar
 
 logger = logging.getLogger("psycopg.pool")
 T = TypeVar("T")
+P = ParamSpec("P")
 
 # Re-exports
 Event = threading.Event
@@ -160,3 +162,24 @@ def asleep(seconds: float) -> Coroutine[Any, Any, None]:
     Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync.
     """
     return asyncio.sleep(seconds)
+
+
+@overload
+async def ensure_async(
+    f: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs
+) -> T: ...
+
+
+@overload
+async def ensure_async(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...
+
+
+async def ensure_async(
+    f: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
+    *args: P.args,
+    **kwargs: P.kwargs,
+) -> T:
+    rv = f(*args, **kwargs)
+    if isawaitable(rv):
+        rv = await rv
+    return rv
index 796955f14960a4994444a82cc46512bdd819d937..5f882d4eea436d06be8c8ed35587c268fea550d2 100644 (file)
@@ -26,7 +26,7 @@ from .base import AttemptWithBackoff, BasePool
 from .errors import PoolClosed, PoolTimeout, TooManyRequests
 from ._compat import Self
 from ._acompat import ACondition, AEvent, ALock, AQueue, AWorker, agather, asleep
-from ._acompat import aspawn, current_task_name
+from ._acompat import aspawn, current_task_name, ensure_async
 from .sched_async import AsyncScheduler
 
 if True:  # ASYNC
@@ -598,13 +598,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool):
         if not self._reconnect_failed:
             return
 
-        if True:  # ASYNC
-            if asyncio.iscoroutinefunction(self._reconnect_failed):
-                await self._reconnect_failed(self)
-            else:
-                self._reconnect_failed(self)
-        else:
-            self._reconnect_failed(self)
+        await ensure_async(self._reconnect_failed, self)
 
     def run_task(self, task: MaintenanceTask) -> None:
         """Run a maintenance task in a worker."""
index 4fa5258023eca979325da7ad74b6a3dede5a011f..e3c4b04cfef7d3dcaf8b0afc90954042036a10f4 100755 (executable)
@@ -332,6 +332,7 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
     }
     _skip_imports = {
         "acompat": {"alist", "anext"},
+        "_acompat": {"ensure_async"},
     }
 
     def visit_Module(self, node: ast.Module) -> ast.AST:
@@ -360,6 +361,9 @@ class RenameAsyncToSync(ast.NodeTransformer):  # type: ignore
             case ast.Call(func=ast.Name(id="cast")):
                 node.args[0] = self._convert_if_literal_string(node.args[0])
 
+            case ast.Call(func=ast.Name(id="ensure_async")):
+                node.func = node.args.pop(0)
+
         self.generic_visit(node)
         return node