From 6dc2e8d464e195d8fcf06b39b1b85fe05c00bf9d Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 31 May 2020 06:00:27 +1200 Subject: [PATCH] Added query canceling --- psycopg3/connection.py | 4 ++++ psycopg3/pq/_pq_ctypes.py | 29 +++++++++++++++++++++++++++-- psycopg3/pq/_pq_ctypes.pyi | 6 ++++++ psycopg3/pq/libpq.pxd | 8 ++++++++ psycopg3/pq/pq_ctypes.py | 33 ++++++++++++++++++++++++++++++++- psycopg3/pq/pq_cython.pxd | 7 +++++++ psycopg3/pq/pq_cython.pyx | 33 +++++++++++++++++++++++++++++++++ psycopg3/pq/proto.py | 11 +++++++++++ tests/pq/test_pgconn.py | 18 ++++++++++++++++++ tests/test_concurrency.py | 29 +++++++++++++++++++++++++++++ tests/test_concurrency_async.py | 33 +++++++++++++++++++++++++++++++++ 11 files changed, 208 insertions(+), 3 deletions(-) diff --git a/psycopg3/connection.py b/psycopg3/connection.py index 917870111..90992c32d 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -149,6 +149,10 @@ class BaseConnection: else: return "UTF8" + def cancel(self) -> None: + c = self.pgconn.get_cancel() + c.cancel() + def add_notice_handler(self, callback: NoticeHandler) -> None: self._notice_handlers.append(callback) diff --git a/psycopg3/pq/_pq_ctypes.py b/psycopg3/pq/_pq_ctypes.py index 3623c80a9..ec7ec9e31 100644 --- a/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/pq/_pq_ctypes.py @@ -61,10 +61,15 @@ class PGnotify_struct(Structure): ] +class PGcancel_struct(Structure): + _fields_: List[Tuple[str, type]] = [] + + PGconn_ptr = POINTER(PGconn_struct) PGresult_ptr = POINTER(PGresult_struct) PQconninfoOption_ptr = POINTER(PQconninfoOption_struct) PGnotify_ptr = POINTER(PGnotify_struct) +PGcancel_ptr = POINTER(PGcancel_struct) # Function definitions as explained in PostgreSQL 12 documentation @@ -456,7 +461,23 @@ PQisnonblocking.restype = c_int PQflush = pq.PQflush PQflush.argtypes = [PGconn_ptr] -PQflush.restype == c_int +PQflush.restype = c_int + + +# 33.6. Canceling Queries in Progress + +PQgetCancel = pq.PQgetCancel +PQgetCancel.argtypes = [PGconn_ptr] +PQgetCancel.restype = PGcancel_ptr + +PQfreeCancel = pq.PQfreeCancel +PQfreeCancel.argtypes = [PGcancel_ptr] +PQfreeCancel.restype = None + +PQcancel = pq.PQcancel +# TODO: raises "wrong type" error +# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int] +PQcancel.restype = c_int # 33.8. Asynchronous Notification @@ -503,7 +524,11 @@ def generate_stub() -> None: else: return "Optional[bytes]" - elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",): + elif t.__name__ in ( + "LP_PGconn_struct", + "LP_PGresult_struct", + "LP_PGcancel_struct", + ): if narg is not None: return f"Optional[{t.__name__[3:]}]" else: diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi index 35c6c054b..a4ebefe4d 100644 --- a/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/pq/_pq_ctypes.pyi @@ -12,6 +12,7 @@ Oid = c_uint class PGconn_struct: ... class PGresult_struct: ... +class PGcancel_struct: ... class PQconninfoOption_struct: keyword: bytes @@ -66,6 +67,9 @@ def PQsendQueryPrepared( arg6: Optional[Array[c_int]], arg7: int, ) -> int: ... +def PQcancel( + arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int +) -> int: ... def PQsetNoticeReceiver( arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any ) -> Callable[[Any], PGresult_struct]: ... @@ -144,6 +148,8 @@ def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ... def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ... def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ... def PQflush(arg1: Optional[PGconn_struct]) -> int: ... +def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ... +def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ... def PQfreemem(arg1: Any) -> None: ... def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ... # autogenerated: end diff --git a/psycopg3/pq/libpq.pxd b/psycopg3/pq/libpq.pxd index bea3204c4..f0db30693 100644 --- a/psycopg3/pq/libpq.pxd +++ b/psycopg3/pq/libpq.pxd @@ -17,6 +17,9 @@ cdef extern from "libpq-fe.h": ctypedef struct PGresult: pass + ctypedef struct PGcancel: + pass + ctypedef struct PGnotify: char *relname int be_pid @@ -213,6 +216,11 @@ cdef extern from "libpq-fe.h": int PQisnonblocking(const PGconn *conn) int PQflush(PGconn *conn) + # 33.6. Canceling Queries in Progress + PGcancel *PQgetCancel(PGconn *conn) + void PQfreeCancel(PGcancel *cancel) + int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize) + # 33.8. Asynchronous Notification PGnotify *PQnotifies(PGconn *conn) diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index 25cd3d43f..f0ddef9b0 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -13,7 +13,7 @@ import logging from weakref import ref from functools import partial -from ctypes import Array, pointer, string_at +from ctypes import Array, pointer, string_at, create_string_buffer from ctypes import c_char_p, c_int, c_size_t, c_ulong from typing import Any, Callable, List, Optional, Sequence from typing import cast as t_cast, TYPE_CHECKING @@ -484,6 +484,12 @@ class PGconn: raise PQerror(f"flushing failed: {error_message(self)}") return rv + def get_cancel(self) -> "PGcancel": + rv = impl.PQgetCancel(self.pgconn_ptr) + if not rv: + raise PQerror("couldn't create cancel object") + return PGcancel(rv) + def notifies(self) -> Optional[PGnotify]: ptr = impl.PQnotifies(self.pgconn_ptr) if ptr: @@ -627,6 +633,31 @@ class PGresult: return impl.PQoidValue(self.pgresult_ptr) +class PGcancel: + __slots__ = ("pgcancel_ptr",) + + def __init__(self, pgcancel_ptr: impl.PGcancel_struct): + self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr + + def __del__(self) -> None: + self.free() + + def free(self) -> None: + self.pgcancel_ptr, p = None, self.pgcancel_ptr + if p is not None: + impl.PQfreeCancel(p) + + def cancel(self) -> None: + buf = create_string_buffer(256) + res = impl.PQcancel( + self.pgcancel_ptr, pointer(buf), len(buf) # type: ignore + ) + if not res: + raise PQerror( + f"cancel failed: {buf.value.decode('utf8', 'ignore')}" + ) + + class Conninfo: @classmethod def get_defaults(cls) -> List[ConninfoOption]: diff --git a/psycopg3/pq/pq_cython.pxd b/psycopg3/pq/pq_cython.pxd index 8c41348db..289aea2e0 100644 --- a/psycopg3/pq/pq_cython.pxd +++ b/psycopg3/pq/pq_cython.pxd @@ -25,3 +25,10 @@ cdef class PGresult: @staticmethod cdef PGresult _from_ptr(impl.PGresult *ptr) + + +cdef class PGcancel: + cdef impl.PGcancel* pgcancel_ptr + + @staticmethod + cdef PGcancel _from_ptr(impl.PGcancel *ptr) diff --git a/psycopg3/pq/pq_cython.pyx b/psycopg3/pq/pq_cython.pyx index 505e7ce26..566b63e6a 100644 --- a/psycopg3/pq/pq_cython.pyx +++ b/psycopg3/pq/pq_cython.pyx @@ -426,6 +426,12 @@ cdef class PGconn: ) return rv + def get_cancel(self) -> PGcancel: + cdef impl.PGcancel *ptr = impl.PQgetCancel(self.pgconn_ptr) + if not ptr: + raise PQerror("couldn't create cancel object") + return PGcancel._from_ptr(ptr) + def notifies(self) -> Optional[PGnotify]: cdef impl.PGnotify *ptr = impl.PQnotifies(self.pgconn_ptr) if ptr: @@ -686,6 +692,33 @@ cdef class PGresult: return impl.PQoidValue(self.pgresult_ptr) +cdef class PGcancel: + def __cinit__(self): + self.pgcancel_ptr = NULL + + @staticmethod + cdef PGcancel _from_ptr(impl.PGcancel *ptr): + cdef PGcancel rv = PGcancel.__new__(PGcancel) + rv.pgcancel_ptr = ptr + return rv + + def __dealloc__(self) -> None: + self.free() + + def free(self) -> None: + if self.pgcancel_ptr is not NULL: + impl.PQfreeCancel(self.pgcancel_ptr) + self.pgcancel_ptr = NULL + + def cancel(self) -> None: + cdef char buf[256] + cdef int res = impl.PQcancel(self.pgcancel_ptr, buf, sizeof(buf)) + if not res: + raise PQerror( + f"cancel failed: {buf.decode('utf8', 'ignore')}" + ) + + class Conninfo: @classmethod def get_defaults(cls) -> List[ConninfoOption]: diff --git a/psycopg3/pq/proto.py b/psycopg3/pq/proto.py index 2e684c5c2..45a3d3dc4 100644 --- a/psycopg3/pq/proto.py +++ b/psycopg3/pq/proto.py @@ -218,6 +218,9 @@ class PGconn(Protocol): def flush(self) -> int: ... + def get_cancel(self) -> "PGcancel": + ... + def notifies(self) -> Optional["PGnotify"]: ... @@ -298,6 +301,14 @@ class PGresult(Protocol): ... +class PGcancel(Protocol): + def free(self) -> None: + ... + + def cancel(self) -> None: + ... + + class Conninfo(Protocol): @classmethod def get_defaults(cls) -> List["ConninfoOption"]: diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py index 53b1e64eb..819b1f871 100644 --- a/tests/pq/test_pgconn.py +++ b/tests/pq/test_pgconn.py @@ -345,6 +345,24 @@ def test_ssl_in_use(pgconn): pgconn.ssl_in_use +def test_cancel(pgconn): + cancel = pgconn.get_cancel() + cancel.cancel() + cancel.cancel() + pgconn.finish() + cancel.cancel() + with pytest.raises(pq.PQerror): + pgconn.get_cancel() + + +def test_cancel_free(pgconn): + cancel = pgconn.get_cancel() + cancel.free() + with pytest.raises(pq.PQerror): + cancel.cancel() + cancel.free() + + def test_make_empty_result(pgconn): pgconn.exec_(b"wat") res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index 6f923737f..509c9962c 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -142,3 +142,32 @@ def test_notifies(conn, dsn): assert n.channel == "foo" assert n.payload == "2" assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +@pytest.mark.slow +def test_cancel(conn): + + errors = [] + + def canceller(): + try: + time.sleep(0.5) + conn.cancel() + except Exception as exc: + errors.append(exc) + + cur = conn.cursor() + t = threading.Thread(target=canceller) + t0 = time.time() + t.start() + + with pytest.raises(psycopg3.DatabaseError): + cur.execute("select pg_sleep(2)") + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + conn.rollback() + assert cur.execute("select 1").fetchone()[0] == 1 diff --git a/tests/test_concurrency_async.py b/tests/test_concurrency_async.py index 30910c294..5477e858e 100644 --- a/tests/test_concurrency_async.py +++ b/tests/test_concurrency_async.py @@ -91,3 +91,36 @@ async def test_notifies(aconn, dsn): assert n.channel == "foo" assert n.payload == "2" assert t1 - t0 == pytest.approx(0.5, abs=0.05) + + +@pytest.mark.slow +async def test_cancel(aconn): + + errors = [] + + async def canceller(): + try: + await asyncio.sleep(0.5) + aconn.cancel() + except Exception as exc: + errors.append(exc) + + async def worker(): + cur = aconn.cursor() + with pytest.raises(psycopg3.DatabaseError): + await cur.execute("select pg_sleep(2)") + + workers = [worker(), canceller()] + + t0 = time.time() + await asyncio.wait(workers) + + t1 = time.time() + assert not errors + assert 0.0 < t1 - t0 < 1.0 + + # still working + await aconn.rollback() + cur = aconn.cursor() + await cur.execute("select 1") + assert await cur.fetchone() == (1,) -- 2.47.2