]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: fall back to cancel() in cancel_safe() for libpq < 17
authorDenis Laxalde <denis.laxalde@dalibo.com>
Mon, 8 Apr 2024 08:00:04 +0000 (10:00 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Apr 2024 10:07:43 +0000 (12:07 +0200)
We run this in a thread executor in the AsyncConnection.
As asyncio's to_thread() is not available in Python 3.8, so we add a
compat layer.

psycopg/psycopg/_compat.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/cursor.py
psycopg/psycopg/cursor_async.py
tests/test_concurrency_async.py

index 68d689a2d413ae2c6d32d3eeffb643f7b04632d3..8d5d48aa95225b709c31f21bb31079e75ab35b91 100644 (file)
@@ -5,18 +5,29 @@ compatibility functions for different Python versions
 # Copyright (C) 2021 The Psycopg Team
 
 import sys
+from functools import partial
+from typing import Any
 
 if sys.version_info >= (3, 9):
+    from asyncio import to_thread
     from zoneinfo import ZoneInfo
     from functools import cache
     from collections import Counter, deque as Deque
+    from collections.abc import Callable
 else:
-    from typing import Counter, Deque
+    import asyncio
+    from typing import Callable, Counter, Deque
     from functools import lru_cache
     from backports.zoneinfo import ZoneInfo
 
     cache = lru_cache(maxsize=None)
 
+    async def to_thread(func: Callable[..., Any], /, *args: Any, **kwargs: Any) -> None:
+        loop = asyncio.get_running_loop()
+        func_call = partial(func, *args, **kwargs)
+        await loop.run_in_executor(None, func_call)
+
+
 if sys.version_info >= (3, 10):
     from typing import TypeGuard, TypeAlias
 else:
@@ -42,4 +53,5 @@ __all__ = [
     "TypeVar",
     "ZoneInfo",
     "cache",
+    "to_thread",
 ]
index b5430122d01602c915af1acacb9c3cdff2c0f8e6..f7e90d92af89ed302be694e581fd3bc8dc8c5e5a 100644 (file)
@@ -271,11 +271,14 @@ class Connection(BaseConnection[Row]):
         In contrast with `cancel()`, it is not appropriate for use in a signal
         handler.
 
-        :raises ~psycopg.NotSupportedError: if the underlying libpq is older
-            than version 17.
+        If the underlying libpq is older than version 17, fall back to
+        `cancel()`.
         """
         if self._should_cancel():
-            waiting.wait_conn(self._cancel_gen(), interval=_WAIT_INTERVAL)
+            try:
+                waiting.wait_conn(self._cancel_gen(), interval=_WAIT_INTERVAL)
+            except e.NotSupportedError:
+                self.cancel()
 
     @contextmanager
     def transaction(
@@ -382,10 +385,7 @@ class Connection(BaseConnection[Row]):
             if self.pgconn.transaction_status == ACTIVE:
                 # On Ctrl-C, try to cancel the query in the server, otherwise
                 # the connection will remain stuck in ACTIVE state.
-                try:
-                    self.cancel_safe()
-                except e.NotSupportedError:
-                    self.cancel()
+                self.cancel_safe()
                 try:
                     waiting.wait(gen, self.pgconn.socket, interval=interval)
                 except e.QueryCanceled:
index 3b2022a69dae68455f254dbd5a36357d71f6ba0c..ab370db13981c312be0ef48a28d5914ec88dff8e 100644 (file)
@@ -36,6 +36,7 @@ if True:  # ASYNC
     import sys
     import asyncio
     from asyncio import Lock
+    from ._compat import to_thread
 else:
     from threading import Lock
 
@@ -287,11 +288,19 @@ class AsyncConnection(BaseConnection[Row]):
         In contrast with `cancel()`, it is not appropriate for use in a signal
         handler.
 
-        :raises ~psycopg.NotSupportedError: if the underlying libpq is older
-            than version 17.
+        If the underlying libpq is older than version 17, fall back to
+        `cancel()`.
         """
         if self._should_cancel():
-            await waiting.wait_conn_async(self._cancel_gen(), interval=_WAIT_INTERVAL)
+            try:
+                await waiting.wait_conn_async(
+                    self._cancel_gen(), interval=_WAIT_INTERVAL
+                )
+            except e.NotSupportedError:
+                if True:  # ASYNC
+                    await to_thread(self.cancel)
+                else:
+                    self.cancel()
 
     @asynccontextmanager
     async def transaction(
@@ -400,10 +409,7 @@ class AsyncConnection(BaseConnection[Row]):
             if self.pgconn.transaction_status == ACTIVE:
                 # On Ctrl-C, try to cancel the query in the server, otherwise
                 # the connection will remain stuck in ACTIVE state.
-                try:
-                    await self.cancel_safe()
-                except e.NotSupportedError:
-                    self.cancel()
+                await self.cancel_safe()
                 try:
                     await waiting.wait_async(gen, self.pgconn.socket, interval=interval)
                 except e.QueryCanceled:
index 530bbd364d2e0b10f9319eb0b126e7443017c2bf..099076b516e51fac92e5b1f0b4953eb57fdcfd3c 100644 (file)
@@ -159,10 +159,7 @@ class Cursor(BaseCursor["Connection[Any]", Row]):
                 if self._pgconn.transaction_status == ACTIVE:
                     # Try to cancel the query, then consume the results
                     # already received.
-                    try:
-                        self._conn.cancel_safe()
-                    except e.NotSupportedError:
-                        self._conn.cancel()
+                    self._conn.cancel_safe()
                     try:
                         while self._conn.wait(self._stream_fetchone_gen(first=False)):
                             pass
index 2ea18ee54ac93086153e2728775e702ee22acd6e..485cf0e9358a7d154dd5ad88e3ffd410835d99b7 100644 (file)
@@ -164,10 +164,7 @@ class AsyncCursor(BaseCursor["AsyncConnection[Any]", Row]):
                 if self._pgconn.transaction_status == ACTIVE:
                     # Try to cancel the query, then consume the results
                     # already received.
-                    try:
-                        await self._conn.cancel_safe()
-                    except e.NotSupportedError:
-                        self._conn.cancel()
+                    await self._conn.cancel_safe()
                     try:
                         while await self._conn.wait(
                             self._stream_fetchone_gen(first=False)
index 98d659d689300726ddfc183e73e0ef5d8453b066..bd6247c7d4350b97f5e6c89a4556fd0b462b37ca 100644 (file)
@@ -61,10 +61,7 @@ async def test_concurrent_execution(aconn_cls, dsn):
 async def canceller(aconn, errors):
     try:
         await asyncio.sleep(0.5)
-        try:
-            await aconn.cancel_safe()
-        except e.NotSupportedError:
-            aconn.cancel()
+        await aconn.cancel_safe()
     except Exception as exc:
         errors.append(exc)