From 3cebb042f8b23ffec54812e30d9116569c1148fc Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 11 May 2025 21:09:06 +0200 Subject: [PATCH] refactor(pool): add function wrapper to call either async/sync version --- psycopg_pool/psycopg_pool/_acompat.py | 25 ++++++++++++++++++++++++- psycopg_pool/psycopg_pool/pool_async.py | 10 ++-------- tools/async_to_sync.py | 4 ++++ 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/psycopg_pool/psycopg_pool/_acompat.py b/psycopg_pool/psycopg_pool/_acompat.py index f5f3a022c..136c948dd 100644 --- a/psycopg_pool/psycopg_pool/_acompat.py +++ b/psycopg_pool/psycopg_pool/_acompat.py @@ -15,13 +15,15 @@ import queue import asyncio import logging import threading -from typing import Any, TypeAlias +from typing import Any, ParamSpec, TypeAlias, overload +from inspect import isawaitable from collections.abc import Callable, Coroutine from ._compat import TypeVar logger = logging.getLogger("psycopg.pool") T = TypeVar("T") +P = ParamSpec("P") # Re-exports Event = threading.Event @@ -160,3 +162,24 @@ def asleep(seconds: float) -> Coroutine[Any, Any, None]: Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync. """ return asyncio.sleep(seconds) + + +@overload +async def ensure_async( + f: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs +) -> T: ... + + +@overload +async def ensure_async(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ... + + +async def ensure_async( + f: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]], + *args: P.args, + **kwargs: P.kwargs, +) -> T: + rv = f(*args, **kwargs) + if isawaitable(rv): + rv = await rv + return rv diff --git a/psycopg_pool/psycopg_pool/pool_async.py b/psycopg_pool/psycopg_pool/pool_async.py index 796955f14..5f882d4ee 100644 --- a/psycopg_pool/psycopg_pool/pool_async.py +++ b/psycopg_pool/psycopg_pool/pool_async.py @@ -26,7 +26,7 @@ from .base import AttemptWithBackoff, BasePool from .errors import PoolClosed, PoolTimeout, TooManyRequests from ._compat import Self from ._acompat import ACondition, AEvent, ALock, AQueue, AWorker, agather, asleep -from ._acompat import aspawn, current_task_name +from ._acompat import aspawn, current_task_name, ensure_async from .sched_async import AsyncScheduler if True: # ASYNC @@ -598,13 +598,7 @@ class AsyncConnectionPool(Generic[ACT], BasePool): if not self._reconnect_failed: return - if True: # ASYNC - if asyncio.iscoroutinefunction(self._reconnect_failed): - await self._reconnect_failed(self) - else: - self._reconnect_failed(self) - else: - self._reconnect_failed(self) + await ensure_async(self._reconnect_failed, self) def run_task(self, task: MaintenanceTask) -> None: """Run a maintenance task in a worker.""" diff --git a/tools/async_to_sync.py b/tools/async_to_sync.py index 4fa525802..e3c4b04cf 100755 --- a/tools/async_to_sync.py +++ b/tools/async_to_sync.py @@ -332,6 +332,7 @@ class RenameAsyncToSync(ast.NodeTransformer): # type: ignore } _skip_imports = { "acompat": {"alist", "anext"}, + "_acompat": {"ensure_async"}, } def visit_Module(self, node: ast.Module) -> ast.AST: @@ -360,6 +361,9 @@ class RenameAsyncToSync(ast.NodeTransformer): # type: ignore case ast.Call(func=ast.Name(id="cast")): node.args[0] = self._convert_if_literal_string(node.args[0]) + case ast.Call(func=ast.Name(id="ensure_async")): + node.func = node.args.pop(0) + self.generic_visit(node) return node -- 2.47.2