logger = logging.getLogger(__name__)
+if sys.platform != "win32":
+
+ def _check_fd_closed(fileno: int) -> None:
+ """
+ Raise OperationalError if the connection is lost.
+ """
+ try:
+ os.fstat(fileno)
+ except Exception as ex:
+ raise e.OperationalError("connection socket closed") from ex
+
+else:
+
+ # On windows we cannot use os.fstat() to check a socket.
+ def _check_fd_closed(fileno: int) -> None:
+ return
+
+
def wait_selector(gen: PQGen[RV], fileno: int, interval: float = 0.0) -> RV:
"""
Wait for a generator using the best strategy available.
sel.register(fileno, s)
while True:
if not (rlist := sel.select(timeout=interval)):
+ # Check if it was a timeout or we were disconnected
+ _check_fd_closed(fileno)
gen.send(READY_NONE)
continue
- sel.unregister(fileno)
ready = rlist[0][1]
s = gen.send(ready)
sel.register(fileno, s)
if writer:
loop.add_writer(fileno, wakeup, READY_W)
try:
- if interval is not None:
- try:
- await wait_for(ev.wait(), interval)
- except TimeoutError:
- pass
- else:
- await ev.wait()
+ try:
+ await wait_for(ev.wait(), interval)
+ except TimeoutError:
+ pass
finally:
if reader:
loop.remove_reader(fileno)
except OSError as ex:
# Assume the connection was closed
- raise e.OperationalError(str(ex))
+ raise e.OperationalError("connection socket closed") from ex
except StopIteration as ex:
rv: RV = ex.value
return rv
fnlist,
interval,
)
+ if xl:
+ _check_fd_closed(fileno)
+ # Unlikely: the exception should have been raised above
+ raise e.OperationalError("connection socket closed")
ready = 0
if rl:
ready = READY_R
if wl:
ready |= READY_W
- if not ready:
- gen.send(READY_NONE)
- continue
s = gen.send(ready)
+ except OSError as ex:
+ # This happens on macOS but not on Linux (the xl list is set)
+ raise e.OperationalError("connection socket closed") from ex
except StopIteration as ex:
rv: RV = ex.value
return rv
epoll.register(fileno, evmask)
while True:
if not (fileevs := epoll.poll(interval)):
+ _check_fd_closed(fileno)
gen.send(READY_NONE)
continue
ev = fileevs[0][1]
WAIT_W: select.POLLOUT,
WAIT_RW: select.POLLIN | select.POLLOUT,
}
+ POLL_BAD = ~(select.POLLIN | select.POLLOUT)
else:
_poll_evmasks = {}
continue
ev = fileevs[0][1]
+ if ev & POLL_BAD:
+ _check_fd_closed(fileno)
+ # Unlikely: the exception should have been raised above
+ raise e.OperationalError("connection socket closed")
+
ready = 0
if ev & select.POLLIN:
ready = READY_R
# Copyright (C) 2022 The Psycopg Team
from cpython.object cimport PyObject_CallFunctionObjArgs
+
+from os import fstat
from typing import TypeVar
+from psycopg import errors as e
+
RV = TypeVar("RV")
cdef extern from *:
- """
+ r"""
#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
#if defined(HAVE_POLL_H)
#define SELECT_EV_READ 1
#define SELECT_EV_WRITE 2
+#define CWAIT_SELECT_ERROR -1
+#define CWAIT_SOCKET_ERROR -2
#define SEC_TO_MS 1000
#define SEC_TO_US (1000 * 1000)
/* Use select to wait for readiness on fileno.
*
* - Return SELECT_EV_* if the file is ready
- * - Return 0 on timeout
- * - Return -1 (and set an exception) on error.
+ * - Return SELECT_EV_NONE on timeout
+ * - Return CWAIT_SELECT_ERROR (and set an exception) on error.
+ * - Return CWAIT_SOCKET_ERROR on poll success but fd error.
*
* The wisdom of this function comes from:
*
wait_c_impl(int fileno, int wait, float timeout)
{
int select_rv;
- int rv = -1;
+ int rv = CWAIT_SELECT_ERROR;
#if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
rv = 0; /* success, maybe with timeout */
if (select_rv >= 0) {
- if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
- if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+ if (input_fd.revents & ~(POLLIN | POLLOUT)) {
+ rv = CWAIT_SOCKET_ERROR;
+ }
+ else {
+ if (input_fd.revents & POLLIN) { rv |= SELECT_EV_READ; }
+ if (input_fd.revents & POLLOUT) { rv |= SELECT_EV_WRITE; }
+ }
}
#else
PyErr_SetString(
PyExc_ValueError, /* same exception of Python's 'select.select()' */
"connection file descriptor out of range for 'select()'");
- return -1;
+ return CWAIT_SELECT_ERROR;
}
#endif
rv = 0;
if (select_rv > 0) {
- if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
- if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+ if (!FD_ISSET(fileno, &efds)) {
+ if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+ if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+ }
+ else {
+ /* There is an error on the FD. Assume it means it is closed. We
+ * want to raise a chained exception, which is tricky in C, so
+ * return the special value CWAIT_SOCKET_ERROR to signal the Cython
+ * wrapper to check the fd and raise the appropriate exception.
+ */
+ rv = CWAIT_SOCKET_ERROR;
+ }
}
#endif /* HAVE_POLL */
error:
- rv = -1;
+ rv = CWAIT_SELECT_ERROR;
#ifdef MS_WINDOWS
if (select_rv == SOCKET_ERROR) {
}
"""
cdef int wait_c_impl(int fileno, int wait, float timeout) except -1
+ cdef int CWAIT_SOCKET_ERROR
def wait_c(gen: PQGen[RV], int fileno, interval = 0.0) -> RV:
pyready = <PyObject *>PY_READY_RW
elif ready == READY_W:
pyready = <PyObject *>PY_READY_W
+ elif ready == CWAIT_SOCKET_ERROR: # FD closed?
+ try:
+ fstat(fileno)
+ except Exception as ex:
+ raise e.OperationalError("connection socket closed") from ex
+ else:
+ raise e.OperationalError("connection socket closed")
else:
raise AssertionError(f"unexpected ready value: {ready}")
f.close()
+@pytest.mark.slow
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.skipif(
+ "sys.platform == 'win32'", reason="Windows doesn't see the fd closed."
+)
+def test_socket_closed(dsn, waitfn, pgconn):
+ waitfn = getattr(waiting, waitfn)
+
+ def closer():
+ sleep(0.5)
+ pgconn.finish()
+
+ t = spawn(closer)
+
+ pgconn.send_query(b"select pg_sleep(2)")
+ with pytest.raises(
+ psycopg.OperationalError, match="connection socket closed"
+ ) as ex:
+ t0 = time.time()
+ gen = generators.execute(pgconn)
+ waitfn(gen, pgconn.socket, 0.1)
+
+ assert pgconn.status == ConnStatus.BAD
+ assert isinstance(ex.value.__cause__, OSError)
+ dt = time.time() - t0
+ gather(t)
+
+ assert dt < 1.0
+
+
@pytest.mark.parametrize("waitfn", waitfns)
def test_wait_timeout_none_unsupported(waitfn):
waitfn = getattr(waiting, waitfn)
f.close()
+@pytest.mark.slow
+@pytest.mark.parametrize("waitfn", waitfns)
+@pytest.mark.skipif(
+ "sys.platform == 'win32'", reason="Windows doesn't see the fd closed."
+)
+async def test_socket_closed(dsn, waitfn, pgconn):
+ waitfn = getattr(waiting, waitfn)
+
+ async def closer():
+ await asleep(0.5)
+ pgconn.finish()
+
+ t = spawn(closer)
+
+ pgconn.send_query(b"select pg_sleep(2)")
+ with pytest.raises(
+ psycopg.OperationalError, match="connection socket closed"
+ ) as ex:
+ t0 = time.time()
+ gen = generators.execute(pgconn)
+ await waitfn(gen, pgconn.socket, 0.1)
+
+ assert pgconn.status == ConnStatus.BAD
+ assert isinstance(ex.value.__cause__, OSError)
+ dt = time.time() - t0
+ await gather(t)
+
+ assert dt < 1.0
+
+
@pytest.mark.parametrize("waitfn", waitfns)
async def test_wait_timeout_none_unsupported(waitfn):
waitfn = getattr(waiting, waitfn)