else:
return "UTF8"
+ def cancel(self) -> None:
+ c = self.pgconn.get_cancel()
+ c.cancel()
+
def add_notice_handler(self, callback: NoticeHandler) -> None:
self._notice_handlers.append(callback)
]
+class PGcancel_struct(Structure):
+ _fields_: List[Tuple[str, type]] = []
+
+
PGconn_ptr = POINTER(PGconn_struct)
PGresult_ptr = POINTER(PGresult_struct)
PQconninfoOption_ptr = POINTER(PQconninfoOption_struct)
PGnotify_ptr = POINTER(PGnotify_struct)
+PGcancel_ptr = POINTER(PGcancel_struct)
# Function definitions as explained in PostgreSQL 12 documentation
PQflush = pq.PQflush
PQflush.argtypes = [PGconn_ptr]
-PQflush.restype == c_int
+PQflush.restype = c_int
+
+
+# 33.6. Canceling Queries in Progress
+
+PQgetCancel = pq.PQgetCancel
+PQgetCancel.argtypes = [PGconn_ptr]
+PQgetCancel.restype = PGcancel_ptr
+
+PQfreeCancel = pq.PQfreeCancel
+PQfreeCancel.argtypes = [PGcancel_ptr]
+PQfreeCancel.restype = None
+
+PQcancel = pq.PQcancel
+# TODO: raises "wrong type" error
+# PQcancel.argtypes = [PGcancel_ptr, POINTER(c_char), c_int]
+PQcancel.restype = c_int
# 33.8. Asynchronous Notification
else:
return "Optional[bytes]"
- elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",):
+ elif t.__name__ in (
+ "LP_PGconn_struct",
+ "LP_PGresult_struct",
+ "LP_PGcancel_struct",
+ ):
if narg is not None:
return f"Optional[{t.__name__[3:]}]"
else:
class PGconn_struct: ...
class PGresult_struct: ...
+class PGcancel_struct: ...
class PQconninfoOption_struct:
keyword: bytes
arg6: Optional[Array[c_int]],
arg7: int,
) -> int: ...
+def PQcancel(
+ arg1: Optional[PGcancel_struct], arg2: c_char_p, arg3: int
+) -> int: ...
def PQsetNoticeReceiver(
arg1: PGconn_struct, arg2: Callable[[Any], PGresult_struct], arg3: Any
) -> Callable[[Any], PGresult_struct]: ...
def PQsetnonblocking(arg1: Optional[PGconn_struct], arg2: int) -> int: ...
def PQisnonblocking(arg1: Optional[PGconn_struct]) -> int: ...
def PQflush(arg1: Optional[PGconn_struct]) -> int: ...
+def PQgetCancel(arg1: Optional[PGconn_struct]) -> PGcancel_struct: ...
+def PQfreeCancel(arg1: Optional[PGcancel_struct]) -> None: ...
def PQfreemem(arg1: Any) -> None: ...
def PQmakeEmptyPGresult(arg1: Optional[PGconn_struct], arg2: int) -> PGresult_struct: ...
# autogenerated: end
ctypedef struct PGresult:
pass
+ ctypedef struct PGcancel:
+ pass
+
ctypedef struct PGnotify:
char *relname
int be_pid
int PQisnonblocking(const PGconn *conn)
int PQflush(PGconn *conn)
+ # 33.6. Canceling Queries in Progress
+ PGcancel *PQgetCancel(PGconn *conn)
+ void PQfreeCancel(PGcancel *cancel)
+ int PQcancel(PGcancel *cancel, char *errbuf, int errbufsize)
+
# 33.8. Asynchronous Notification
PGnotify *PQnotifies(PGconn *conn)
from weakref import ref
from functools import partial
-from ctypes import Array, pointer, string_at
+from ctypes import Array, pointer, string_at, create_string_buffer
from ctypes import c_char_p, c_int, c_size_t, c_ulong
from typing import Any, Callable, List, Optional, Sequence
from typing import cast as t_cast, TYPE_CHECKING
raise PQerror(f"flushing failed: {error_message(self)}")
return rv
+ def get_cancel(self) -> "PGcancel":
+ rv = impl.PQgetCancel(self.pgconn_ptr)
+ if not rv:
+ raise PQerror("couldn't create cancel object")
+ return PGcancel(rv)
+
def notifies(self) -> Optional[PGnotify]:
ptr = impl.PQnotifies(self.pgconn_ptr)
if ptr:
return impl.PQoidValue(self.pgresult_ptr)
+class PGcancel:
+ __slots__ = ("pgcancel_ptr",)
+
+ def __init__(self, pgcancel_ptr: impl.PGcancel_struct):
+ self.pgcancel_ptr: Optional[impl.PGcancel_struct] = pgcancel_ptr
+
+ def __del__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ self.pgcancel_ptr, p = None, self.pgcancel_ptr
+ if p is not None:
+ impl.PQfreeCancel(p)
+
+ def cancel(self) -> None:
+ buf = create_string_buffer(256)
+ res = impl.PQcancel(
+ self.pgcancel_ptr, pointer(buf), len(buf) # type: ignore
+ )
+ if not res:
+ raise PQerror(
+ f"cancel failed: {buf.value.decode('utf8', 'ignore')}"
+ )
+
+
class Conninfo:
@classmethod
def get_defaults(cls) -> List[ConninfoOption]:
@staticmethod
cdef PGresult _from_ptr(impl.PGresult *ptr)
+
+
+cdef class PGcancel:
+ cdef impl.PGcancel* pgcancel_ptr
+
+ @staticmethod
+ cdef PGcancel _from_ptr(impl.PGcancel *ptr)
)
return rv
+ def get_cancel(self) -> PGcancel:
+ cdef impl.PGcancel *ptr = impl.PQgetCancel(self.pgconn_ptr)
+ if not ptr:
+ raise PQerror("couldn't create cancel object")
+ return PGcancel._from_ptr(ptr)
+
def notifies(self) -> Optional[PGnotify]:
cdef impl.PGnotify *ptr = impl.PQnotifies(self.pgconn_ptr)
if ptr:
return impl.PQoidValue(self.pgresult_ptr)
+cdef class PGcancel:
+ def __cinit__(self):
+ self.pgcancel_ptr = NULL
+
+ @staticmethod
+ cdef PGcancel _from_ptr(impl.PGcancel *ptr):
+ cdef PGcancel rv = PGcancel.__new__(PGcancel)
+ rv.pgcancel_ptr = ptr
+ return rv
+
+ def __dealloc__(self) -> None:
+ self.free()
+
+ def free(self) -> None:
+ if self.pgcancel_ptr is not NULL:
+ impl.PQfreeCancel(self.pgcancel_ptr)
+ self.pgcancel_ptr = NULL
+
+ def cancel(self) -> None:
+ cdef char buf[256]
+ cdef int res = impl.PQcancel(self.pgcancel_ptr, buf, sizeof(buf))
+ if not res:
+ raise PQerror(
+ f"cancel failed: {buf.decode('utf8', 'ignore')}"
+ )
+
+
class Conninfo:
@classmethod
def get_defaults(cls) -> List[ConninfoOption]:
def flush(self) -> int:
...
+ def get_cancel(self) -> "PGcancel":
+ ...
+
def notifies(self) -> Optional["PGnotify"]:
...
...
+class PGcancel(Protocol):
+ def free(self) -> None:
+ ...
+
+ def cancel(self) -> None:
+ ...
+
+
class Conninfo(Protocol):
@classmethod
def get_defaults(cls) -> List["ConninfoOption"]:
pgconn.ssl_in_use
+def test_cancel(pgconn):
+ cancel = pgconn.get_cancel()
+ cancel.cancel()
+ cancel.cancel()
+ pgconn.finish()
+ cancel.cancel()
+ with pytest.raises(pq.PQerror):
+ pgconn.get_cancel()
+
+
+def test_cancel_free(pgconn):
+ cancel = pgconn.get_cancel()
+ cancel.free()
+ with pytest.raises(pq.PQerror):
+ cancel.cancel()
+ cancel.free()
+
+
def test_make_empty_result(pgconn):
pgconn.exec_(b"wat")
res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
assert n.channel == "foo"
assert n.payload == "2"
assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+def test_cancel(conn):
+
+ errors = []
+
+ def canceller():
+ try:
+ time.sleep(0.5)
+ conn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+ cur = conn.cursor()
+ t = threading.Thread(target=canceller)
+ t0 = time.time()
+ t.start()
+
+ with pytest.raises(psycopg3.DatabaseError):
+ cur.execute("select pg_sleep(2)")
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ conn.rollback()
+ assert cur.execute("select 1").fetchone()[0] == 1
assert n.channel == "foo"
assert n.payload == "2"
assert t1 - t0 == pytest.approx(0.5, abs=0.05)
+
+
+@pytest.mark.slow
+async def test_cancel(aconn):
+
+ errors = []
+
+ async def canceller():
+ try:
+ await asyncio.sleep(0.5)
+ aconn.cancel()
+ except Exception as exc:
+ errors.append(exc)
+
+ async def worker():
+ cur = aconn.cursor()
+ with pytest.raises(psycopg3.DatabaseError):
+ await cur.execute("select pg_sleep(2)")
+
+ workers = [worker(), canceller()]
+
+ t0 = time.time()
+ await asyncio.wait(workers)
+
+ t1 = time.time()
+ assert not errors
+ assert 0.0 < t1 - t0 < 1.0
+
+ # still working
+ await aconn.rollback()
+ cur = aconn.cursor()
+ await cur.execute("select 1")
+ assert await cur.fetchone() == (1,)