]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: don't hang forever if async connection is closed while querying
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Sep 2023 16:29:47 +0000 (18:29 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 26 Sep 2023 20:47:35 +0000 (22:47 +0200)
Fix #608

docs/news.rst
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/waiting.py
tests/test_concurrency.py
tests/test_concurrency_async.py

index ec9c44fa7b9ab5afd1649878ef5805e07e79fe8d..2c5f81174123d0b7af76fb61e4c118ced0a93c62 100644 (file)
@@ -10,6 +10,7 @@
 Psycopg 3.1.12 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
+- Fix hanging if an async connection is closed while querying (:ticket:`#608`).
 - Fix memory leak when `~register_*()` functions are called repeatedly
   (:ticket:`#647`).
 
index 82d597cbb669e250755990f16185d1e1d94d7134..ca9305394cd7fb94e079fb1b15ea561b0368b08d 100644 (file)
@@ -800,6 +800,9 @@ class Connection(BaseConnection[Row]):
         if self.closed:
             return
         self._closed = True
+
+        # TODO: maybe send a cancel on close, if the connection is ACTIVE?
+
         self.pgconn.finish()
 
     @overload
index 34e7834883dce5f938d5d0bd08b964f7a7872b90..416d00cee9cf7da99e8433ffcf631739827c4ea0 100644 (file)
@@ -197,6 +197,9 @@ class AsyncConnection(BaseConnection[Row]):
         if self.closed:
             return
         self._closed = True
+
+        # TODO: maybe send a cancel on close, if the connection is ACTIVE?
+
         self.pgconn.finish()
 
     @overload
@@ -343,15 +346,15 @@ class AsyncConnection(BaseConnection[Row]):
                     assert pipeline is self._pipeline
                     self._pipeline = None
 
-    async def wait(self, gen: PQGen[RV]) -> RV:
+    async def wait(self, gen: PQGen[RV], timeout: Optional[float] = 0.1) -> RV:
         try:
-            return await waiting.wait_async(gen, self.pgconn.socket)
+            return await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
         except (asyncio.CancelledError, KeyboardInterrupt):
             # On Ctrl-C, try to cancel the query in the server, otherwise
             # the connection will remain stuck in ACTIVE state.
             self._try_cancel(self.pgconn)
             try:
-                await waiting.wait_async(gen, self.pgconn.socket)
+                await waiting.wait_async(gen, self.pgconn.socket, timeout=timeout)
             except e.QueryCanceled:
                 pass  # as expected
             raise
index 80827152c47363067cede2cf4e33c7b4cb716213..e31896c898ebe09639dc1cf27f8497fcbce9a64f 100644 (file)
@@ -97,7 +97,9 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
         return rv
 
 
-async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
+async def wait_async(
+    gen: PQGen[RV], fileno: int, timeout: Optional[float] = None
+) -> RV:
     """
     Coroutine waiting for a generator to complete.
 
@@ -134,7 +136,13 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
             if writer:
                 loop.add_writer(fileno, wakeup, READY_W)
             try:
-                await ev.wait()
+                if timeout is None:
+                    await ev.wait()
+                else:
+                    try:
+                        await wait_for(ev.wait(), timeout)
+                    except TimeoutError:
+                        pass
             finally:
                 if reader:
                     loop.remove_reader(fileno)
index 230007aeb756c848a7791e2d13d2d27cc65bc952..9dbe9ace9e642237f76a793d971aad45ddf51e2b 100644 (file)
@@ -392,3 +392,42 @@ if __name__ == '__main__':
     env["PYTHONFAULTHANDLER"] = "1"
     out = sp.check_output([sys.executable, "-s", "-c", script], env=env)
     assert out.decode().rstrip() == "[1, 1]"
+
+
+@pytest.mark.slow
+@pytest.mark.crdb("skip")
+@pytest.mark.skipif(
+    sys.platform == "win32",
+    reason="Fails with: An operation was attempted on something that is not a socket",
+)
+def test_concurrent_close(dsn, conn):
+    # Verify something similar to the problem in #608, which doesn't affect
+    # sync connections anyway.
+    pid = conn.info.backend_pid
+    conn.autocommit = True
+
+    def worker():
+        try:
+            conn.execute("select pg_sleep(3)")
+        except psycopg.OperationalError:
+            pass  # expected
+
+    t0 = time.time()
+    th = threading.Thread(target=worker)
+    th.start()
+    time.sleep(0.5)
+    with psycopg.connect(dsn, autocommit=True) as conn1:
+        cur = conn1.execute("select query from pg_stat_activity where pid = %s", [pid])
+        assert cur.fetchone()
+        conn.close()
+        th.join()
+        time.sleep(0.5)
+        t = time.time()
+        # TODO: this check can pass if we issue a cancel on close, which is
+        # a change in behaviour to be considered better.
+        # cur = conn1.execute(
+        #     "select query from pg_stat_activity where pid = %s",
+        #     [pid],
+        # )
+        # assert not cur.fetchone()
+        assert t - t0 < 2
index 1be029798273fb3d1ea0b055c92332d38d007bfe..017bbd79e419ff050f3fc3e38f09fd8d04c7e716 100644 (file)
@@ -313,3 +313,47 @@ asyncio.run(main())
 
     t1 = time.time()
     assert t1 - t0 < 1.0
+
+
+@pytest.mark.slow
+@pytest.mark.crdb("skip")
+@pytest.mark.skipif(
+    sys.platform == "win32",
+    reason="Fails with: An operation was attempted on something that is not a socket",
+)
+async def test_concurrent_close(dsn, aconn):
+    # Test issue #608: concurrent closing shouldn't hang the server
+    # (although, at the moment, it doesn't cancel a running query).
+    pid = aconn.info.backend_pid
+    await aconn.set_autocommit(True)
+
+    async def worker():
+        try:
+            await aconn.execute("select pg_sleep(3)")
+        except psycopg.OperationalError:
+            pass  # expected
+
+    t0 = time.time()
+    task = create_task(worker())
+    await asyncio.sleep(0.5)
+
+    async def test():
+        async with await psycopg.AsyncConnection.connect(dsn, autocommit=True) as conn1:
+            cur = await conn1.execute(
+                "select query from pg_stat_activity where pid = %s", [pid]
+            )
+            assert await cur.fetchone()
+            await aconn.close()
+            await asyncio.gather(task)
+            await asyncio.sleep(0.5)
+            t = time.time()
+            # TODO: this statement can pass only if we send cancel on close
+            # but because async cancelling is not available in the libpq,
+            # we'd rather not do it.
+            # cur = await conn1.execute(
+            #     "select query from pg_stat_activity where pid = %s", [pid]
+            # )
+            # assert not await cur.fetchone()
+            assert t - t0 < 2
+
+    await asyncio.wait_for(test(), 5.0)