]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added query canceling
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 30 May 2020 18:00:27 +0000 (06:00 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 30 May 2020 18:00:48 +0000 (06:00 +1200)
psycopg3/connection.py
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/libpq.pxd
psycopg3/pq/pq_ctypes.py
psycopg3/pq/pq_cython.pxd
psycopg3/pq/pq_cython.pyx
psycopg3/pq/proto.py
tests/pq/test_pgconn.py
tests/test_concurrency.py
tests/test_concurrency_async.py

index 917870111aa759e68309809596ececf8dba7e645..90992c32d64559a4e27993068ad5e5c3d0f79093 100644 (file)
@@ -149,6 +149,10 @@ class BaseConnection:
         else:
             return "UTF8"
 
+    def cancel(self) -> None:
+        c = self.pgconn.get_cancel()
+        c.cancel()
+
     def add_notice_handler(self, callback: NoticeHandler) -> None:
         self._notice_handlers.append(callback)
 
index 3623c80a92e798fc4b4a8c6312bb327bf410cc5f..ec7ec9e31f1697d2f8fb7282cece16aa026e7e4f 100644 (file)
@@ -61,10 +61,15 @@ class PGnotify_struct(Structure):
     ]
 
 
+class PGcancel_struct(Structure):
+    _fields_: List[Tuple[str, type]] = []
+
+
 PGconn_ptr = POINTER(PGconn_struct)
 PGresult_ptr = POINTER(PGresult_struct)
 PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
 PGnotify_ptr = POINTER(PGnotify_struct)
+PGcancel_ptr = POINTER(PGcancel_struct)
 
 
 # Function definitions as explained in PostgreSQL 12 documentation
@@ -456,7 +461,23 @@ PQisnonblocking.restype = c_int
 
 PQflush = pq.PQflush
 PQflush.argtypes = [PGconn_ptr]
-PQflush.restype == c_int
+PQflush.restype = c_int
+
+
+# 33.6. Canceling Queries in Progress
+
+PQgetCancel = pq.PQgetCancel
+PQgetCancel.argtypes = [PGconn_ptr]
+PQgetCancel.restype = PGcancel_ptr
+
+PQfreeCancel = pq.PQfreeCancel
+PQfreeCancel.argtypes = [PGcancel_ptr]
+PQfreeCancel.restype = None
+
+PQcancel = pq.PQcancel
+# TODO: raises "wrong type" error
+# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
+PQcancel.restype = c_int
 
 
 # 33.8. Asynchronous Notification
@@ -503,7 +524,11 @@ def generate_stub() -> None:
             else:
                 return "Optional[bytes]"
 
-        elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",):
+        elif t.__name__ in (
+            "LP_PGconn_struct",
+            "LP_PGresult_struct",
+            "LP_PGcancel_struct",
+        ):
             if narg is not None:
                 return f"Optional[{t.__name__[3:]}]"
             else:
index 35c6c054bea8d248335c211ccaee2ad0a5951944..a4ebefe4df1ed25d36acec7e18258e4ce0621f47 100644 (file)
@@ -12,6 +12,7 @@ Oid = c_uint
 
 class PGconn_struct: ...
 class PGresult_struct: ...
+class PGcancel_struct: ...
 
 class PQconninfoOption_struct:
     keyword: bytes
@@ -66,6 +67,9 @@ def PQsendQueryPrepared(
     arg6: Optional[Array[c_int]],
     arg7: int,
 ) -> int: ...
+def PQcancel(
+    arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int
+) -> int: ...
 def PQsetNoticeReceiver(
     arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
 ) -> Callable[[Any], PGresult_struct]: ...
@@ -144,6 +148,8 @@ def PQisBusy(arg1: Optional[PGconn_struct]) -> int: ...
 def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
 def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
 def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
+def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
 def PQfreemem(arg1: Any) -> None: ...
 def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
 # autogenerated: end
index bea3204c4c190f975d0909e25dbe7f01ccdce6bb..f0db306939f9c281abec95c5b6093063e20cd147 100644 (file)
@@ -17,6 +17,9 @@ cdef extern from "libpq-fe.h":
     ctypedef struct PGresult:
         pass
 
+    ctypedef struct PGcancel:
+        pass
+
     ctypedef struct PGnotify:
         char   *relname
         int     be_pid
@@ -213,6 +216,11 @@ cdef extern from "libpq-fe.h":
     int PQisnonblocking(const PGconn *conn)
     int PQflush(PGconn *conn)
 
+    # 33.6. Canceling Queries in Progress
+    PGcancel *PQgetCancel(PGconn *conn)
+    void PQfreeCancel(PGcancel *cancel)
+    int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize)
+
     # 33.8. Asynchronous Notification
     PGnotify *PQnotifies(PGconn *conn)
 
index 25cd3d43f976f713b496a7a3403326241106b061..f0ddef9b055dc7a84a8a98da691a599f3b7123a1 100644 (file)
@@ -13,7 +13,7 @@ import logging
 from weakref import ref
 from functools import partial
 
-from ctypes import Array, pointer, string_at
+from ctypes import Array, pointer, string_at, create_string_buffer
 from ctypes import c_char_p, c_int, c_size_t, c_ulong
 from typing import Any, Callable, List, Optional, Sequence
 from typing import cast as t_cast, TYPE_CHECKING
@@ -484,6 +484,12 @@ class PGconn:
             raise PQerror(f"flushing failed: {error_message(self)}")
         return rv
 
+    def get_cancel(self) -> "PGcancel":
+        rv = impl.PQgetCancel(self.pgconn_ptr)
+        if not rv:
+            raise PQerror("couldn't create cancel object")
+        return PGcancel(rv)
+
     def notifies(self) -> Optional[PGnotify]:
         ptr = impl.PQnotifies(self.pgconn_ptr)
         if ptr:
@@ -627,6 +633,31 @@ class PGresult:
         return impl.PQoidValue(self.pgresult_ptr)
 
 
+class PGcancel:
+    __slots__ = ("pgcancel_ptr",)
+
+    def __init__(self, pgcancel_ptr: impl.PGcancel_struct):
+        self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr
+
+    def __del__(self) -> None:
+        self.free()
+
+    def free(self) -> None:
+        self.pgcancel_ptr, p = None, self.pgcancel_ptr
+        if p is not None:
+            impl.PQfreeCancel(p)
+
+    def cancel(self) -> None:
+        buf = create_string_buffer(256)
+        res = impl.PQcancel(
+            self.pgcancel_ptr, pointer(buf), len(buf)  # type: ignore
+        )
+        if not res:
+            raise PQerror(
+                f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
+            )
+
+
 class Conninfo:
     @classmethod
     def get_defaults(cls) -> List[ConninfoOption]:
index 8c41348db1bb1742635097ea6f1e898743664ada..289aea2e022cdfc2068c386dbab9c2171ab9d5b9 100644 (file)
@@ -25,3 +25,10 @@ cdef class PGresult:
 
     @staticmethod
     cdef PGresult _from_ptr(impl.PGresult *ptr)
+
+
+cdef class PGcancel:
+    cdef impl.PGcancel* pgcancel_ptr
+
+    @staticmethod
+    cdef PGcancel _from_ptr(impl.PGcancel *ptr)
index 505e7ce26e238d1b95620d0298aa7e3b1bfbd364..566b63e6ad69860eaebf30694113053b6e5fbee0 100644 (file)
@@ -426,6 +426,12 @@ cdef class PGconn:
             )
         return rv
 
+    def get_cancel(self) -> PGcancel:
+        cdef impl.PGcancel *ptr = impl.PQgetCancel(self.pgconn_ptr)
+        if not ptr:
+            raise PQerror("couldn't create cancel object")
+        return PGcancel._from_ptr(ptr)
+
     def notifies(self) -> Optional[PGnotify]:
         cdef impl.PGnotify *ptr = impl.PQnotifies(self.pgconn_ptr)
         if ptr:
@@ -686,6 +692,33 @@ cdef class PGresult:
         return impl.PQoidValue(self.pgresult_ptr)
 
 
+cdef class PGcancel:
+    def __cinit__(self):
+        self.pgcancel_ptr = NULL
+
+    @staticmethod
+    cdef PGcancel _from_ptr(impl.PGcancel *ptr):
+        cdef PGcancel rv = PGcancel.__new__(PGcancel)
+        rv.pgcancel_ptr = ptr
+        return rv
+
+    def __dealloc__(self) -> None:
+        self.free()
+
+    def free(self) -> None:
+        if self.pgcancel_ptr is not NULL:
+            impl.PQfreeCancel(self.pgcancel_ptr)
+            self.pgcancel_ptr = NULL
+
+    def cancel(self) -> None:
+        cdef char buf[256]
+        cdef int res = impl.PQcancel(self.pgcancel_ptr, buf, sizeof(buf))
+        if not res:
+            raise PQerror(
+                f"cancel failed: {buf.decode('utf8', 'ignore')}"
+            )
+
+
 class Conninfo:
     @classmethod
     def get_defaults(cls) -> List[ConninfoOption]:
index 2e684c5c2393576e2d556a0cde7ed6305e42eab2..45a3d3dc48ab96ac0d9d2d030c86e2d474c4c6bd 100644 (file)
@@ -218,6 +218,9 @@ class PGconn(Protocol):
     def flush(self) -> int:
         ...
 
+    def get_cancel(self) -> "PGcancel":
+        ...
+
     def notifies(self) -> Optional["PGnotify"]:
         ...
 
@@ -298,6 +301,14 @@ class PGresult(Protocol):
         ...
 
 
+class PGcancel(Protocol):
+    def free(self) -> None:
+        ...
+
+    def cancel(self) -> None:
+        ...
+
+
 class Conninfo(Protocol):
     @classmethod
     def get_defaults(cls) -> List["ConninfoOption"]:
index 53b1e64eb941c52f06b59fd9ba5f455005235438..819b1f87157f11d29882db3effb4796db1a4437c 100644 (file)
@@ -345,6 +345,24 @@ def test_ssl_in_use(pgconn):
         pgconn.ssl_in_use
 
 
+def test_cancel(pgconn):
+    cancel = pgconn.get_cancel()
+    cancel.cancel()
+    cancel.cancel()
+    pgconn.finish()
+    cancel.cancel()
+    with pytest.raises(pq.PQerror):
+        pgconn.get_cancel()
+
+
+def test_cancel_free(pgconn):
+    cancel = pgconn.get_cancel()
+    cancel.free()
+    with pytest.raises(pq.PQerror):
+        cancel.cancel()
+    cancel.free()
+
+
 def test_make_empty_result(pgconn):
     pgconn.exec_(b"wat")
     res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
index 6f923737f49df797a9a05e99ad0fa34553d4097b..509c9962ca47c254843d570684fbfeee2136b9fc 100644 (file)
@@ -142,3 +142,32 @@ def test_notifies(conn, dsn):
     assert n.channel == "foo"
     assert n.payload == "2"
     assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+def test_cancel(conn):
+
+    errors = []
+
+    def canceller():
+        try:
+            time.sleep(0.5)
+            conn.cancel()
+        except Exception as exc:
+            errors.append(exc)
+
+    cur = conn.cursor()
+    t = threading.Thread(target=canceller)
+    t0 = time.time()
+    t.start()
+
+    with pytest.raises(psycopg3.DatabaseError):
+        cur.execute("select pg_sleep(2)")
+
+    t1 = time.time()
+    assert not errors
+    assert 0.0 < t1 - t0 < 1.0
+
+    # still working
+    conn.rollback()
+    assert cur.execute("select 1").fetchone()[0] == 1
index 30910c29494830544e03c9cb8593636f42266204..5477e858e8720d7f108a4a099b2ea2fe50512d5d 100644 (file)
@@ -91,3 +91,36 @@ async def test_notifies(aconn, dsn):
     assert n.channel == "foo"
     assert n.payload == "2"
     assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+async def test_cancel(aconn):
+
+    errors = []
+
+    async def canceller():
+        try:
+            await asyncio.sleep(0.5)
+            aconn.cancel()
+        except Exception as exc:
+            errors.append(exc)
+
+    async def worker():
+        cur = aconn.cursor()
+        with pytest.raises(psycopg3.DatabaseError):
+            await cur.execute("select pg_sleep(2)")
+
+    workers = [worker(), canceller()]
+
+    t0 = time.time()
+    await asyncio.wait(workers)
+
+    t1 = time.time()
+    assert not errors
+    assert 0.0 < t1 - t0 < 1.0
+
+    # still working
+    await aconn.rollback()
+    cur = aconn.cursor()
+    await cur.execute("select 1")
+    assert await cur.fetchone() == (1,)