]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix(async): cancel query upon receiving CanceledError in async wait
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 9 Apr 2023 18:29:44 +0000 (20:29 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 10 Apr 2023 08:45:44 +0000 (10:45 +0200)
This interrupt a running query upon Ctrl-C for example, which wasn't
working as it was on sync connection.

Close #543.

docs/news.rst
psycopg/psycopg/connection_async.py
tests/test_concurrency.py
tests/test_concurrency_async.py

index f0f547f3ef4fb558fc363816eda95147430638c5..deff90a0f7f4c8889b7bc5edae90674485f2f101 100644 (file)
@@ -15,6 +15,8 @@ Psycopg 3.1.9 (unreleased)
 
 - Fix `TypeInfo.fetch()` using a connection in `!sql_ascii` encoding
   (:ticket:`#503`).
+- Fix canceling running queries on process interruption in async connections
+  (:ticket:`#543`).
 
 
 Current release
index 6d95d7641218fe3705bf99b401c6bc79105e94f4..f6842c6ed47ebbecfcc049dec7813fc0ee388ca7 100644 (file)
@@ -346,13 +346,9 @@ class AsyncConnection(BaseConnection[Row]):
     async def wait(self, gen: PQGen[RV]) -> RV:
         try:
             return await waiting.wait_async(gen, self.pgconn.socket)
-        except KeyboardInterrupt:
-            # TODO: this doesn't seem to work as it does for sync connections
-            # see tests/test_concurrency_async.py::test_ctrl_c
-            # In the test, the code doesn't reach this branch.
-
+        except (asyncio.CancelledError, KeyboardInterrupt):
             # On Ctrl-C, try to cancel the query in the server, otherwise
-            # otherwise the connection will be stuck in ACTIVE state
+            # the connection will remain stuck in ACTIVE state.
             c = self.pgconn.get_cancel()
             c.cancel()
             try:
index eec24f1df1995f13aed1779d3a1e24ebe48adf92..7a5119c88f0389c82243c129ceda860e73964a54 100644 (file)
@@ -242,7 +242,7 @@ def test_identify_closure(conn_cls, dsn):
     sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
 )
 @pytest.mark.crdb_skip("cancel")
-def test_ctrl_c(dsn):
+def test_ctrl_c_handler(dsn):
     if sys.platform == "win32":
         sig = int(signal.CTRL_C_EVENT)
         # Or pytest will receive the Ctrl-C too
@@ -293,6 +293,72 @@ with psycopg.connect({dsn!r}) as conn:
     assert 1 < t < 2
 
 
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+    sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb("skip")
+def test_ctrl_c(conn, dsn):
+    conn.autocommit = True
+
+    APPNAME = "test_ctrl_c"
+    script = f"""\
+import psycopg
+
+with psycopg.connect({dsn!r}, application_name={APPNAME!r}) as conn:
+    conn.execute("select pg_sleep(60)")
+"""
+
+    if sys.platform == "win32":
+        creationflags = sp.CREATE_NEW_PROCESS_GROUP
+        sig = signal.CTRL_C_EVENT
+    else:
+        creationflags = 0
+        sig = signal.SIGINT
+
+    proc = None
+
+    def run_process():
+        nonlocal proc
+        proc = sp.Popen(
+            [sys.executable, "-s", "-c", script],
+            creationflags=creationflags,
+        )
+        proc.communicate()
+
+    t = threading.Thread(target=run_process)
+    t.start()
+
+    for i in range(20):
+        cur = conn.execute(
+            "select pid from pg_stat_activity where application_name = %s", (APPNAME,)
+        )
+        rec = cur.fetchone()
+        if rec:
+            pid = rec[0]
+            break
+        time.sleep(0.1)
+    else:
+        assert False, "process didn't start?"
+
+    t0 = time.time()
+    assert proc
+    proc.send_signal(sig)
+    proc.wait()
+
+    for i in range(20):
+        cur = conn.execute("select 1 from pg_stat_activity where pid = %s", (pid,))
+        if not cur.fetchone():
+            break
+        time.sleep(0.1)
+    else:
+        assert False, "process didn't stop?"
+
+    t1 = time.time()
+    assert t1 - t0 < 1.0
+
+
 @pytest.mark.slow
 @pytest.mark.subprocess
 @pytest.mark.skipif(
index 67bb6afbd5fdc75bd2b0672f223b1d1985c2d394..80d3fc5cc347cff465ca5a6b866a416edd7e56ab 100644 (file)
@@ -2,6 +2,7 @@ import sys
 import time
 import signal
 import asyncio
+import threading
 import subprocess as sp
 from asyncio.queues import Queue
 from typing import List, Tuple
@@ -192,7 +193,7 @@ async def test_identify_closure(aconn_cls, dsn):
     sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
 )
 @pytest.mark.crdb_skip("cancel")
-def test_ctrl_c(dsn):
+def test_ctrl_c_handler(dsn):
     script = f"""\
 import signal
 import asyncio
@@ -238,3 +239,76 @@ asyncio.run(main())
     proc.send_signal(sig)
     proc.communicate()
     assert proc.returncode == 0
+
+
+@pytest.mark.slow
+@pytest.mark.subprocess
+@pytest.mark.skipif(
+    sys.platform == "win32", reason="don't know how to Ctrl-C on Windows"
+)
+@pytest.mark.crdb("skip")
+def test_ctrl_c(conn, dsn):
+    # https://github.com/psycopg/psycopg/issues/543
+    conn.autocommit = True
+
+    APPNAME = "test_ctrl_c"
+    script = f"""\
+import asyncio
+import psycopg
+
+async def main():
+    async with await psycopg.AsyncConnection.connect(
+        {dsn!r}, application_name={APPNAME!r}
+    ) as conn:
+        await conn.execute("select pg_sleep(5)")
+
+asyncio.run(main())
+"""
+    if sys.platform == "win32":
+        creationflags = sp.CREATE_NEW_PROCESS_GROUP
+        sig = signal.CTRL_C_EVENT
+    else:
+        creationflags = 0
+        sig = signal.SIGINT
+
+    proc = None
+
+    def run_process():
+        nonlocal proc
+        proc = sp.Popen(
+            [sys.executable, "-s", "-c", script],
+            creationflags=creationflags,
+            stderr=sp.PIPE,
+        )
+        proc.communicate()
+
+    t = threading.Thread(target=run_process)
+    t.start()
+
+    for i in range(20):
+        cur = conn.execute(
+            "select pid from pg_stat_activity where application_name = %s", (APPNAME,)
+        )
+        rec = cur.fetchone()
+        if rec:
+            pid = rec[0]
+            break
+        time.sleep(0.1)
+    else:
+        assert False, "process didn't start?"
+
+    t0 = time.time()
+    assert proc
+    proc.send_signal(sig)
+    proc.wait()
+
+    for i in range(20):
+        cur = conn.execute("select 1 from pg_stat_activity where pid = %s", (pid,))
+        if not cur.fetchone():
+            break
+        time.sleep(0.1)
+    else:
+        assert False, "process didn't stop?"
+
+    t1 = time.time()
+    assert t1 - t0 < 1.0