import asyncio
import logging
import threading
-from typing import Any, TypeAlias
+from typing import Any, ParamSpec, TypeAlias, overload
+from inspect import isawaitable
from collections.abc import Callable, Coroutine
from ._compat import TypeVar
logger = logging.getLogger("psycopg.pool")
T = TypeVar("T")
+P = ParamSpec("P")
# Re-exports
Event = threading.Event
Equivalent to asyncio.sleep(), converted to time.sleep() by async_to_sync.
"""
return asyncio.sleep(seconds)
+
+
+@overload
+async def ensure_async(
+ f: Callable[P, Coroutine[Any, Any, T]], *args: P.args, **kwargs: P.kwargs
+) -> T: ...
+
+
+@overload
+async def ensure_async(f: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T: ...
+
+
+async def ensure_async(
+ f: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
+ *args: P.args,
+ **kwargs: P.kwargs,
+) -> T:
+ rv = f(*args, **kwargs)
+ if isawaitable(rv):
+ rv = await rv
+ return rv
from .errors import PoolClosed, PoolTimeout, TooManyRequests
from ._compat import Self
from ._acompat import ACondition, AEvent, ALock, AQueue, AWorker, agather, asleep
-from ._acompat import aspawn, current_task_name
+from ._acompat import aspawn, current_task_name, ensure_async
from .sched_async import AsyncScheduler
if True: # ASYNC
if not self._reconnect_failed:
return
- if True: # ASYNC
- if asyncio.iscoroutinefunction(self._reconnect_failed):
- await self._reconnect_failed(self)
- else:
- self._reconnect_failed(self)
- else:
- self._reconnect_failed(self)
+ await ensure_async(self._reconnect_failed, self)
def run_task(self, task: MaintenanceTask) -> None:
"""Run a maintenance task in a worker."""
}
_skip_imports = {
"acompat": {"alist", "anext"},
+ "_acompat": {"ensure_async"},
}
def visit_Module(self, node: ast.Module) -> ast.AST:
case ast.Call(func=ast.Name(id="cast")):
node.args[0] = self._convert_if_literal_string(node.args[0])
+ case ast.Call(func=ast.Name(id="ensure_async")):
+ node.func = node.args.pop(0)
+
self.generic_visit(node)
return node