]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: add generators.cancel()
authorDenis Laxalde <denis.laxalde@dalibo.com>
Fri, 24 Mar 2023 13:08:20 +0000 (14:08 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 9 Apr 2024 10:07:43 +0000 (12:07 +0200)
This is a PQGenConn generator as the socket for the PGcancelConn needs
to be retrieved after (at least) the first poll() call.

psycopg/psycopg/generators.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/generators.pyx
tests/test_generators.py

index 96143af939f66d00a13696e341ca05f6c7cfd7a3..0c6098cf63ad12517136000b2760ce5ca7604199 100644 (file)
@@ -27,7 +27,7 @@ from typing import List, Optional, Union
 from . import pq
 from . import errors as e
 from .abc import Buffer, PipelineCommand, PQGen, PQGenConn
-from .pq.abc import PGconn, PGresult
+from .pq.abc import PGcancelConn, PGconn, PGresult
 from .waiting import Wait, Ready
 from ._compat import Deque
 from ._cmodule import _psycopg
@@ -100,6 +100,23 @@ def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]:
     return conn
 
 
+def _cancel(cancel_conn: PGcancelConn) -> PQGenConn[None]:
+    while True:
+        status = cancel_conn.poll()
+        if status == POLL_OK:
+            break
+        elif status == POLL_READING:
+            yield cancel_conn.socket, WAIT_R
+        elif status == POLL_WRITING:
+            yield cancel_conn.socket, WAIT_W
+        elif status == POLL_FAILED:
+            raise e.OperationalError(
+                f"cancellation failed: {cancel_conn.error_message}"
+            )
+        else:
+            raise e.InternalError(f"unexpected poll status: {status}")
+
+
 def _execute(pgconn: PGconn) -> PQGen[List[PGresult]]:
     """
     Generator sending a query and returning results without blocking.
@@ -357,6 +374,7 @@ def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
 # Override functions with fast versions if available
 if _psycopg:
     connect = _psycopg.connect
+    cancel = _psycopg.cancel
     execute = _psycopg.execute
     send = _psycopg.send
     fetch_many = _psycopg.fetch_many
@@ -365,6 +383,7 @@ if _psycopg:
 
 else:
     connect = _connect
+    cancel = _cancel
     execute = _execute
     send = _send
     fetch_many = _fetch_many
index ec976eb5c9275abfe5786e3a396bb102e7e486e2..3881501918e4f459ba4988639e8a7c10785dae45 100644 (file)
@@ -7,12 +7,12 @@ information. Will submit a bug.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from typing import Any, Iterable, List, Optional, Sequence, Tuple
+from typing import Any, List, Optional, Sequence, Tuple
 
 from psycopg import pq, abc, BaseConnection
 from psycopg.rows import Row, RowMaker
 from psycopg.adapt import AdaptersMap, PyFormat
-from psycopg.pq.abc import PGconn, PGresult
+from psycopg.pq.abc import PGcancelConn, PGconn, PGresult
 from psycopg._compat import Deque
 
 class Transformer(abc.AdaptContext):
@@ -52,6 +52,7 @@ class Transformer(abc.AdaptContext):
 
 # Generators
 def connect(conninfo: str, *, timeout: float = 0.0) -> abc.PQGenConn[PGconn]: ...
+def cancel(cancel_conn: PGcancelConn) -> abc.PQGenConn[None]: ...
 def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
 def send(pgconn: PGconn) -> abc.PQGen[None]: ...
 def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
index 1b2be5f60eeb96ddc432cf6540452aaf32857f7c..7198fddf152bbc199ca9b91d8ae5ce6e403f393b 100644 (file)
@@ -81,6 +81,26 @@ def connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[abc.PGconn]:
     return conn
 
 
+def cancel(pq.PGcancelConn cancel_conn) -> PQGenConn[None]:
+    cdef libpq.PGcancelConn *pgcancelconn_ptr = cancel_conn.pgcancelconn_ptr
+    cdef int status
+    while True:
+        with nogil:
+            status = libpq.PQcancelPoll(pgcancelconn_ptr)
+        if status == libpq.PGRES_POLLING_OK:
+            break
+        elif status == libpq.PGRES_POLLING_READING:
+            yield libpq.PQcancelSocket(pgcancelconn_ptr), WAIT_R
+        elif status == libpq.PGRES_POLLING_WRITING:
+            yield libpq.PQcancelSocket(pgcancelconn_ptr), WAIT_W
+        elif status == libpq.PGRES_POLLING_FAILED:
+            raise e.OperationalError(
+                f"cancellation failed: {cancel_conn.error_message}"
+            )
+        else:
+            raise e.InternalError(f"unexpected poll status: {status}")
+
+
 def execute(pq.PGconn pgconn) -> PQGen[List[abc.PGresult]]:
     """
     Generator sending a query and returning results without blocking.
index 2df55e3e08eede59d72df371e28a4e3cd741ffa2..8397c203b649212efb16b88bf3cb88707282ef4f 100644 (file)
@@ -1,3 +1,4 @@
+import time
 from collections import deque
 from functools import partial
 from typing import List
@@ -44,6 +45,30 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch):
         pgconn.exec_(b"select 1")
 
 
+@pytest.mark.libpq(">= 17")
+def test_cancel(pgconn, conn, generators):
+    pgconn.send_query_params(b"SELECT pg_sleep($1)", [b"180"])
+    while not conn.execute(
+        "SELECT count(*) FROM pg_stat_activity"
+        " WHERE query = 'SELECT pg_sleep($1)'"
+        " AND state = 'active'"
+    ).fetchone():
+        time.sleep(0.01)
+    cancel_conn = pgconn.cancel_conn()
+    assert cancel_conn.status != pq.ConnStatus.BAD
+    cancel_conn.start()
+    gen = generators.cancel(cancel_conn)
+    waiting.wait_conn(gen)
+    assert cancel_conn.status == pq.ConnStatus.OK
+
+    res = pgconn.get_result()
+    assert res is not None
+    assert res.status == pq.ExecStatus.FATAL_ERROR
+    assert res.error_field(pq.DiagnosticField.SQLSTATE) == b"57014"
+    while pgconn.is_busy():
+        pgconn.consume_input()
+
+
 @pytest.fixture
 def pipeline(pgconn):
     nb, pgconn.nonblocking = pgconn.nonblocking, True