]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: check for closed connection in wait functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Oct 2025 23:49:52 +0000 (01:49 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 13 Oct 2025 01:25:16 +0000 (03:25 +0200)
Catch and display errors in a more homogeneous way: if a wait function
finds a connection closed it will raise an OperationalError chained to
the OSError obtained from stat'ing the socket. Previously control would
have gone back to the generator with a read-ready state and it would
have failed on whatever libpq function would have touched the socket.

Fix the problem reported in #608 affecting the epoll wait function, for
which we opted to use the poll function instead (which, more simply,
made the closed fd easier to spot).

note that on Windows it's not possible to use os.fstat() to ckeck a
socket state, therefore make it no-op.

psycopg/psycopg/waiting.py
psycopg_c/psycopg_c/_psycopg/waiting.pyx
tests/test_waiting.py
tests/test_waiting_async.py

index 08b3e2f90f8814402639f29a53d4a9b0f03078f9..0f3ef92e58c1e6f9cea9cd0267b4befbcf582282 100644 (file)
@@ -35,6 +35,24 @@ READY_RW = Ready.RW
 logger = logging.getLogger(__name__)
 
 
+if sys.platform != "win32":
+
+    def _check_fd_closed(fileno: int) -> None:
+        """
+        Raise OperationalError if the connection is lost.
+        """
+        try:
+            os.fstat(fileno)
+        except Exception as ex:
+            raise e.OperationalError("connection socket closed") from ex
+
+else:
+
+    # On windows we cannot use os.fstat() to check a socket.
+    def _check_fd_closed(fileno: int) -> None:
+        return
+
+
 def wait_selector(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
     """
     Wait for a generator using the best strategy available.
@@ -57,10 +75,11 @@ def wait_selector(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
             sel.register(fileno, s)
             while True:
                 if not (rlist := sel.select(timeout=interval)):
+                    # Check if it was a timeout or we were disconnected
+                    _check_fd_closed(fileno)
                     gen.send(READY_NONE)
                     continue
 
-                sel.unregister(fileno)
                 ready = rlist[0][1]
                 s = gen.send(ready)
                 sel.register(fileno, s)
@@ -146,13 +165,10 @@ async def wait_async(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
             if writer:
                 loop.add_writer(fileno, wakeup, READY_W)
             try:
-                if interval is not None:
-                    try:
-                        await wait_for(ev.wait(), interval)
-                    except TimeoutError:
-                        pass
-                else:
-                    await ev.wait()
+                try:
+                    await wait_for(ev.wait(), interval)
+                except TimeoutError:
+                    pass
             finally:
                 if reader:
                     loop.remove_reader(fileno)
@@ -162,7 +178,7 @@ async def wait_async(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
 
     except OSError as ex:
         # Assume the connection was closed
-        raise e.OperationalError(str(ex))
+        raise e.OperationalError("connection socket closed") from ex
     except StopIteration as ex:
         rv: RV = ex.value
         return rv
@@ -252,17 +268,21 @@ def wait_select(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
                 fnlist,
                 interval,
             )
+            if xl:
+                _check_fd_closed(fileno)
+                # Unlikely: the exception should have been raised above
+                raise e.OperationalError("connection socket closed")
             ready = 0
             if rl:
                 ready = READY_R
             if wl:
                 ready |= READY_W
-            if not ready:
-                gen.send(READY_NONE)
-                continue
 
             s = gen.send(ready)
 
+    except OSError as ex:
+        # This happens on macOS but not on Linux (the xl list is set)
+        raise e.OperationalError("connection socket closed") from ex
     except StopIteration as ex:
         rv: RV = ex.value
         return rv
@@ -307,6 +327,7 @@ def wait_epoll(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
             epoll.register(fileno, evmask)
             while True:
                 if not (fileevs := epoll.poll(interval)):
+                    _check_fd_closed(fileno)
                     gen.send(READY_NONE)
                     continue
                 ev = fileevs[0][1]
@@ -330,6 +351,7 @@ if hasattr(selectors, "PollSelector"):
         WAIT_W: select.POLLOUT,
         WAIT_RW: select.POLLIN | select.POLLOUT,
     }
+    POLL_BAD = ~(select.POLLIN | select.POLLOUT)
 else:
     _poll_evmasks = {}
 
@@ -359,6 +381,11 @@ def wait_poll(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
                 continue
 
             ev = fileevs[0][1]
+            if ev & POLL_BAD:
+                _check_fd_closed(fileno)
+                # Unlikely: the exception should have been raised above
+                raise e.OperationalError("connection socket closed")
+
             ready = 0
             if ev & select.POLLIN:
                 ready = READY_R
index 03e91bb693f50aad1761c91aa9e333c98fcc320b..3c22a6aca598d376da815270d5e281ca39f79cc9 100644 (file)
@@ -5,13 +5,17 @@ C implementation of waiting functions
 # Copyright (C) 2022 The Psycopg Team
 
 from cpython.object cimport PyObject_CallFunctionObjArgs
+
+from os import fstat
 from typing import TypeVar
 
+from psycopg import errors as e
+
 RV = TypeVar("RV")
 
 
 cdef extern from *:
-    """
+    r"""
 #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
 
 #if defined(HAVE_POLL_H)
@@ -32,6 +36,8 @@ cdef extern from *:
 
 #define SELECT_EV_READ 1
 #define SELECT_EV_WRITE 2
+#define CWAIT_SELECT_ERROR -1
+#define CWAIT_SOCKET_ERROR -2
 
 #define SEC_TO_MS 1000
 #define SEC_TO_US (1000 * 1000)
@@ -39,8 +45,9 @@ cdef extern from *:
 /* Use select to wait for readiness on fileno.
  *
  * - Return SELECT_EV_* if the file is ready
- * - Return 0 on timeout
- * - Return -1 (and set an exception) on error.
+ * - Return SELECT_EV_NONE on timeout
+ * - Return CWAIT_SELECT_ERROR (and set an exception) on error.
+ * - Return CWAIT_SOCKET_ERROR on poll success but fd error.
  *
  * The wisdom of this function comes from:
  *
@@ -51,7 +58,7 @@ static int
 wait_c_impl(int fileno, int wait, float timeout)
 {
     int select_rv;
-    int rv = -1;
+    int rv = CWAIT_SELECT_ERROR;
 
 #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
 
@@ -88,8 +95,13 @@ retry_eintr:
 
     rv = 0;  /* success, maybe with timeout */
     if (select_rv >= 0) {
-        if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
-        if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+        if (input_fd.revents & ~(POLLIN | POLLOUT)) {
+            rv = CWAIT_SOCKET_ERROR;
+        }
+        else {
+            if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
+            if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+        }
     }
 
 #else
@@ -104,7 +116,7 @@ retry_eintr:
         PyErr_SetString(
             PyExc_ValueError,  /* same exception of Python's 'select.select()' */
             "connection file descriptor out of range for 'select()'");
-        return -1;
+        return CWAIT_SELECT_ERROR;
     }
 #endif
 
@@ -143,8 +155,18 @@ retry_eintr:
 
     rv = 0;
     if (select_rv > 0) {
-        if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
-        if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+        if (!FD_ISSET(fileno, &efds)) {
+            if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+            if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+        }
+        else {
+            /* There is an error on the FD. Assume it means it is closed. We
+             * want to raise a chained exception, which is tricky in C, so
+             * return the special value CWAIT_SOCKET_ERROR to signal the Cython
+             * wrapper to check the fd and raise the appropriate exception.
+             */
+            rv = CWAIT_SOCKET_ERROR;
+        }
     }
 
 #endif  /* HAVE_POLL */
@@ -153,7 +175,7 @@ retry_eintr:
 
 error:
 
-    rv = -1;
+    rv = CWAIT_SELECT_ERROR;
 
 #ifdef MS_WINDOWS
     if (select_rv == SOCKET_ERROR) {
@@ -175,6 +197,7 @@ finally:
 }
     """
     cdef int wait_c_impl(int fileno, int wait, float timeout) except -1
+    cdef int CWAIT_SOCKET_ERROR
 
 
 def wait_c(gen: PQGen[RV], int fileno, interval = 0.0) -> RV:
@@ -207,6 +230,13 @@ def wait_c(gen: PQGen[RV], int fileno, interval = 0.0) -> RV:
                 pyready = <PyObject *>PY_READY_RW
             elif ready == READY_W:
                 pyready = <PyObject *>PY_READY_W
+            elif ready == CWAIT_SOCKET_ERROR:  # FD closed?
+                try:
+                    fstat(fileno)
+                except Exception as ex:
+                    raise e.OperationalError("connection socket closed") from ex
+                else:
+                    raise e.OperationalError("connection socket closed")
             else:
                 raise AssertionError(f"unexpected ready value: {ready}")
 
index 78eae8652fd57bdf3cfe80dbac4163f413bce6c4..a3fa0f710216c404e7990972b365cbbce8aa03a5 100644 (file)
@@ -321,6 +321,36 @@ def test_wait_large_fd(dsn, waitfn):
             f.close()
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.skipif(
+    "sys.platform == 'win32'", reason="Windows doesn't see the fd closed."
+)
+def test_socket_closed(dsn, waitfn, pgconn):
+    waitfn = getattr(waiting, waitfn)
+
+    def closer():
+        sleep(0.5)
+        pgconn.finish()
+
+    t = spawn(closer)
+
+    pgconn.send_query(b"select pg_sleep(2)")
+    with pytest.raises(
+        psycopg.OperationalError, match="connection socket closed"
+    ) as ex:
+        t0 = time.time()
+        gen = generators.execute(pgconn)
+        waitfn(gen, pgconn.socket, 0.1)
+
+    assert pgconn.status == ConnStatus.BAD
+    assert isinstance(ex.value.__cause__, OSError)
+    dt = time.time() - t0
+    gather(t)
+
+    assert dt < 1.0
+
+
 @pytest.mark.parametrize("waitfn", waitfns)
 def test_wait_timeout_none_unsupported(waitfn):
     waitfn = getattr(waiting, waitfn)
index 212c9e12ddd6f5a7fd0c219c59e990eafebba56f..e274fc1f52444039d4ce7ebe061f2a8c24c07ce8 100644 (file)
@@ -329,6 +329,36 @@ async def test_wait_large_fd(dsn, waitfn):
             f.close()
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.skipif(
+    "sys.platform == 'win32'", reason="Windows doesn't see the fd closed."
+)
+async def test_socket_closed(dsn, waitfn, pgconn):
+    waitfn = getattr(waiting, waitfn)
+
+    async def closer():
+        await asleep(0.5)
+        pgconn.finish()
+
+    t = spawn(closer)
+
+    pgconn.send_query(b"select pg_sleep(2)")
+    with pytest.raises(
+        psycopg.OperationalError, match="connection socket closed"
+    ) as ex:
+        t0 = time.time()
+        gen = generators.execute(pgconn)
+        await waitfn(gen, pgconn.socket, 0.1)
+
+    assert pgconn.status == ConnStatus.BAD
+    assert isinstance(ex.value.__cause__, OSError)
+    dt = time.time() - t0
+    await gather(t)
+
+    assert dt < 1.0
+
+
 @pytest.mark.parametrize("waitfn", waitfns)
 async def test_wait_timeout_none_unsupported(waitfn):
     waitfn = getattr(waiting, waitfn)