]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
fix: include timeout as part of the generators/wait conversation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 24 Jan 2024 19:51:30 +0000 (19:51 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 29 Jan 2024 02:25:32 +0000 (02:25 +0000)
So far, the wait functions would have shielded the generators from a
wait timeout. However this behaviour makes impossible to make a
generator interruptible.

Note that the `wait_c` generator was interruptible, but probably it
wasn't doing the right thing. In the `poll` branch, I understand that
the returned ready value, in case of timeout, would have been the same
of the input wait value, because of the input/output nature of the
pollfd struct; I haven't analyzed more deeply the select() case.

psycopg/psycopg/_enums.py
psycopg/psycopg/abc.py
psycopg/psycopg/generators.py
psycopg/psycopg/waiting.py
psycopg_c/psycopg_c/_psycopg/generators.pyx
psycopg_c/psycopg_c/_psycopg/waiting.pyx
tests/test_waiting.py

index a7cb78df4c2123c008f81fc13b849dbf47a332ad..1975650c6654138862ec4f8542c7d8e88994befd 100644 (file)
@@ -20,6 +20,7 @@ class Wait(IntEnum):
 
 
 class Ready(IntEnum):
+    NONE = 0
     R = EVENT_READ
     W = EVENT_WRITE
     RW = EVENT_READ | EVENT_WRITE
index 58111ff23510e9aae8c49e7896b477f6d7e5236b..ad4a96646c13643139220adab2a561305312d78a 100644 (file)
@@ -39,13 +39,13 @@ ConnMapping: TypeAlias = Mapping[str, ConnParam]
 
 RV = TypeVar("RV")
 
-PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], "Ready", RV]
+PQGenConn: TypeAlias = Generator[Tuple[int, "Wait"], Union["Ready", int], RV]
 """Generator for processes where the connection file number can change.
 
 This can happen in connection and reset, but not in normal querying.
 """
 
-PQGen: TypeAlias = Generator["Wait", "Ready", RV]
+PQGen: TypeAlias = Generator["Wait", Union["Ready", int], RV]
 """Generator for processes where the connection file number won't change.
 """
 
index 4f2ec878bb9cc70d06c392f884cba07962b9669e..2e463196e6e5eee462f3bdbde8839f6a3ef712a9 100644 (file)
@@ -7,10 +7,15 @@ the operations, yielding a polling state whenever there is to wait. The
 functions in the `waiting` module are the ones who wait more or less
 cooperatively for the socket to be ready and make these generators continue.
 
-All these generators yield pairs (fileno, `Wait`) whenever an operation would
-block. The generator can be restarted sending the appropriate `Ready` state
-when the file descriptor is ready.
-
+These generators yield `Wait` objects whenever an operation would block. These
+generators assume the connection fileno will not change. In case of the
+connection function, where the fileno may change, the generators yield pairs
+(fileno, `Wait`).
+
+The generator can be restarted sending the appropriate `Ready` state when the
+file descriptor is ready. If a None value is sent, it means that the wait
+function timed out without any file descriptor becoming ready; in this case the
+generator should probably yield the same value again in order to wait more.
 """
 
 # Copyright (C) 2020 The Psycopg Team
@@ -119,7 +124,11 @@ def _send(pgconn: PGconn) -> PQGen[None]:
         if f == 0:
             break
 
-        ready = yield WAIT_RW
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
+
         if ready & READY_R:
             # This call may read notifies: they will be saved in the
             # PGconn buffer and passed to Python later, in `fetch()`.
@@ -168,12 +177,19 @@ def _fetch(pgconn: PGconn) -> PQGen[Optional[PGresult]]:
     Return a result from the database (whether success or error).
     """
     if pgconn.is_busy():
-        yield WAIT_R
+        while True:
+            ready = yield WAIT_R
+            if ready:
+                break
+
         while True:
             pgconn.consume_input()
             if not pgconn.is_busy():
                 break
-            yield WAIT_R
+            while True:
+                ready = yield WAIT_R
+                if ready:
+                    break
 
     _consume_notifies(pgconn)
 
@@ -191,7 +207,10 @@ def _pipeline_communicate(
     results = []
 
     while True:
-        ready = yield WAIT_RW
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
 
         if ready & READY_R:
             pgconn.consume_input()
@@ -263,7 +282,10 @@ def copy_from(pgconn: PGconn) -> PQGen[Union[memoryview, PGresult]]:
             break
 
         # would block
-        yield WAIT_R
+        while True:
+            ready = yield WAIT_R
+            if ready:
+                break
         pgconn.consume_input()
 
     if nbytes > 0:
@@ -291,17 +313,26 @@ def copy_to(pgconn: PGconn, buffer: Buffer) -> PQGen[None]:
     # into smaller ones. We prefer to do it there instead of here in order to
     # do it upstream the queue decoupling the writer task from the producer one.
     while pgconn.put_copy_data(buffer) == 0:
-        yield WAIT_W
+        while True:
+            ready = yield WAIT_W
+            if ready:
+                break
 
 
 def copy_end(pgconn: PGconn, error: Optional[bytes]) -> PQGen[PGresult]:
     # Retry enqueuing end copy message until successful
     while pgconn.put_copy_end(error) == 0:
-        yield WAIT_W
+        while True:
+            ready = yield WAIT_W
+            if ready:
+                break
 
     # Repeat until it the message is flushed to the server
     while True:
-        yield WAIT_W
+        while True:
+            ready = yield WAIT_W
+            if ready:
+                break
         f = pgconn.flush()
         if f == 0:
             break
index d6db0d922e899f53c130d538069fd26e7697243e..6315c0ad7c459224af7db1b742f35336e79de053 100644 (file)
@@ -26,6 +26,7 @@ from ._cmodule import _psycopg
 WAIT_R = Wait.R
 WAIT_W = Wait.W
 WAIT_RW = Wait.RW
+READY_NONE = Ready.NONE
 READY_R = Ready.R
 READY_W = Ready.W
 READY_RW = Ready.RW
@@ -51,16 +52,17 @@ def wait_selector(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None)
     try:
         s = next(gen)
         with DefaultSelector() as sel:
+            sel.register(fileno, s)
             while True:
-                sel.register(fileno, s)
-                rlist = None
-                while not rlist:
-                    rlist = sel.select(timeout=timeout)
+                rlist = sel.select(timeout=timeout)
+                if not rlist:
+                    gen.send(READY_NONE)
+                    continue
+
                 sel.unregister(fileno)
-                # note: this line should require a cast, but mypy doesn't complain
-                ready: Ready = rlist[0][1]
-                assert s & ready
+                ready = rlist[0][1]
                 s = gen.send(ready)
+                sel.register(fileno, s)
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
@@ -92,7 +94,7 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
                 sel.unregister(fileno)
                 if not rlist:
                     raise e.ConnectionTimeout("connection timeout expired")
-                ready: Ready = rlist[0][1]  # type: ignore[assignment]
+                ready = rlist[0][1]
                 fileno, s = gen.send(ready)
 
     except StopIteration as ex:
@@ -119,12 +121,12 @@ async def wait_async(
     # Not sure this is the best implementation but it's a start.
     ev = Event()
     loop = get_event_loop()
-    ready: Ready
+    ready: int
     s: Wait
 
     def wakeup(state: Ready) -> None:
         nonlocal ready
-        ready |= state  # type: ignore[assignment]
+        ready |= state
         ev.set()
 
     try:
@@ -135,19 +137,19 @@ async def wait_async(
             if not reader and not writer:
                 raise e.InternalError(f"bad poll status: {s}")
             ev.clear()
-            ready = 0  # type: ignore[assignment]
+            ready = 0
             if reader:
                 loop.add_reader(fileno, wakeup, READY_R)
             if writer:
                 loop.add_writer(fileno, wakeup, READY_W)
             try:
-                if timeout is None:
-                    await ev.wait()
-                else:
+                if timeout:
                     try:
                         await wait_for(ev.wait(), timeout)
                     except TimeoutError:
                         pass
+                else:
+                    await ev.wait()
             finally:
                 if reader:
                     loop.remove_reader(fileno)
@@ -155,6 +157,9 @@ async def wait_async(
                     loop.remove_writer(fileno)
             s = gen.send(ready)
 
+    except OSError as ex:
+        # Assume the connection was closed
+        raise e.OperationalError(str(ex))
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
         return rv
@@ -245,9 +250,10 @@ def wait_select(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) ->
             if wl:
                 ready |= READY_W
             if not ready:
+                gen.send(READY_NONE)
                 continue
-            # assert s & ready
-            s = gen.send(ready)  # type: ignore
+
+            s = gen.send(ready)
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
@@ -285,24 +291,22 @@ def wait_epoll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) ->
         s = next(gen)
 
         if timeout is None or timeout < 0:
-            timeout = 0
-        else:
-            timeout = int(timeout * 1000.0)
+            timeout = 0.0
 
         with select.epoll() as epoll:
             evmask = _epoll_evmasks[s]
             epoll.register(fileno, evmask)
             while True:
-                fileevs = None
-                while not fileevs:
-                    fileevs = epoll.poll(timeout)
+                fileevs = epoll.poll(timeout)
+                if not fileevs:
+                    gen.send(READY_NONE)
+                    continue
                 ev = fileevs[0][1]
                 ready = 0
                 if ev & ~select.EPOLLOUT:
                     ready = READY_R
                 if ev & ~select.EPOLLIN:
                     ready |= READY_W
-                # assert s & ready
                 s = gen.send(ready)
                 evmask = _epoll_evmasks[s]
                 epoll.modify(fileno, evmask)
@@ -340,16 +344,17 @@ def wait_poll(gen: PQGen[RV], fileno: int, timeout: Optional[float] = None) -> R
         evmask = _poll_evmasks[s]
         poll.register(fileno, evmask)
         while True:
-            fileevs = None
-            while not fileevs:
-                fileevs = poll.poll(timeout)
+            fileevs = poll.poll(timeout)
+            if not fileevs:
+                gen.send(READY_NONE)
+                continue
+
             ev = fileevs[0][1]
             ready = 0
             if ev & ~select.POLLOUT:
                 ready = READY_R
             if ev & ~select.POLLIN:
                 ready |= READY_W
-            # assert s & ready
             s = gen.send(ready)
             evmask = _poll_evmasks[s]
             poll.modify(fileno, evmask)
index a51fce5e28d5f2117e1a161fa6c3c2a8663ede9c..70335cf8995731ea0d156baa8f9026a17ce117dd 100644 (file)
@@ -18,9 +18,11 @@ from psycopg._encodings import conninfo_encoding
 cdef object WAIT_W = Wait.W
 cdef object WAIT_R = Wait.R
 cdef object WAIT_RW = Wait.RW
+cdef object PY_READY_NONE = Ready.NONE
 cdef object PY_READY_R = Ready.R
 cdef object PY_READY_W = Ready.W
 cdef object PY_READY_RW = Ready.RW
+cdef int READY_NONE = Ready.NONE
 cdef int READY_R = Ready.R
 cdef int READY_W = Ready.W
 cdef int READY_RW = Ready.RW
@@ -96,15 +98,19 @@ def send(pq.PGconn pgconn) -> PQGen[None]:
     to retrieve the results available.
     """
     cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
-    cdef int status
+    cdef int ready
     cdef int cires
 
     while True:
         if pgconn.flush() == 0:
             break
 
-        status = yield WAIT_RW
-        if status & READY_R:
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
+
+        if ready & READY_R:
             with nogil:
                 # This call may read notifies which will be saved in the
                 # PGconn buffer and passed to Python later.
@@ -166,11 +172,16 @@ def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]:
     cdef libpq.PGconn *pgconn_ptr = pgconn._pgconn_ptr
     cdef int cires, ibres
     cdef libpq.PGresult *pgres
+    cdef object ready
 
     with nogil:
         ibres = libpq.PQisBusy(pgconn_ptr)
     if ibres:
-        yield WAIT_R
+        while True:
+            ready = yield WAIT_R
+            if ready:
+                break
+
         while True:
             with nogil:
                 cires = libpq.PQconsumeInput(pgconn_ptr)
@@ -182,7 +193,10 @@ def fetch(pq.PGconn pgconn) -> PQGen[Optional[PGresult]]:
                     f"consuming input failed: {error_message(pgconn)}")
             if not ibres:
                 break
-            yield WAIT_R
+            while True:
+                ready = yield WAIT_R
+                if ready:
+                    break
 
     _consume_notifies(pgconn)
 
@@ -211,7 +225,10 @@ def pipeline_communicate(
     cdef pq.PGresult r
 
     while True:
-        ready = yield WAIT_RW
+        while True:
+            ready = yield WAIT_RW
+            if ready:
+                break
 
         if ready & READY_R:
             with nogil:
index 33c54c513b8d0e4ed5afbd4db7cb9b5e8c792dd8..3a6cc6e255eb0c4a41cfe38d5030f684ea8ef898 100644 (file)
@@ -51,7 +51,7 @@ static int
 wait_c_impl(int fileno, int wait, float timeout)
 {
     int select_rv;
-    int rv = 0;
+    int rv = -1;
 
 #if defined(HAVE_POLL) && !defined(HAVE_BROKEN_POLL)
 
@@ -83,11 +83,14 @@ retry_eintr:
         goto retry_eintr;
     }
 
-    if (select_rv < 0) { goto error; }
     if (PyErr_CheckSignals()) { goto finally; }
+    if (select_rv < 0) { goto finally; }  /* poll error */
 
-    if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; }
-    if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; }
+    rv = 0;  /* success, maybe with timeout */
+    if (select_rv >= 0) {
+        if (input_fd.events & POLLIN) { rv |= SELECT_EV_READ; }
+        if (input_fd.events & POLLOUT) { rv |= SELECT_EV_WRITE; }
+    }
 
 #else
 
@@ -135,11 +138,14 @@ retry_eintr:
         goto retry_eintr;
     }
 
-    if (select_rv < 0) { goto error; }
     if (PyErr_CheckSignals()) { goto finally; }
+    if (select_rv < 0) { goto error; }  /* select error */
 
-    if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
-    if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+    rv = 0;
+    if (select_rv > 0) {
+        if (FD_ISSET(fileno, &ifds)) { rv |= SELECT_EV_READ; }
+        if (FD_ISSET(fileno, &ofds)) { rv |= SELECT_EV_WRITE; }
+    }
 
 #endif  /* HAVE_POLL */
 
@@ -147,6 +153,8 @@ retry_eintr:
 
 error:
 
+    rv = -1;
+
 #ifdef MS_WINDOWS
     if (select_rv == SOCKET_ERROR) {
         PyErr_SetExcFromWindowsErr(PyExc_OSError, WSAGetLastError());
@@ -162,7 +170,7 @@ error:
 
 finally:
 
-    return -1;
+    return rv;
 
 }
     """
@@ -191,8 +199,8 @@ def wait_c(gen: PQGen[RV], int fileno, timeout = None) -> RV:
 
         while True:
             ready = wait_c_impl(fileno, wait, ctimeout)
-            if ready == 0:
-                continue
+            if ready == READY_NONE:
+                pyready = <PyObject *>PY_READY_NONE
             elif ready == READY_R:
                 pyready = <PyObject *>PY_READY_R
             elif ready == READY_RW:
index 6a9ad88f376155388db44433fc66736302e6d96b..c4d8915e8a801209fcc416203b77fa20e7df4ac1 100644 (file)
@@ -1,6 +1,7 @@
+import sys
+import time
 import select  # noqa: used in pytest.mark.skipif
 import socket
-import sys
 
 import pytest
 
@@ -26,6 +27,7 @@ waitfns = [
     pytest.param("wait_c", marks=pytest.mark.skipif("not psycopg._cmodule._psycopg")),
 ]
 
+events = ["R", "W", "RW"]
 timeouts = [pytest.param({}, id="blank")]
 timeouts += [pytest.param({"timeout": x}, id=str(x)) for x in [None, 0, 0.2, 10]]
 
@@ -44,9 +46,11 @@ def test_wait_conn_bad(dsn):
 
 
 @pytest.mark.parametrize("waitfn", waitfns)
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
 @skip_if_not_linux
-def test_wait_ready(waitfn, wait, ready):
+def test_wait_ready(waitfn, event):
+    wait = getattr(waiting.Wait, event)
+    ready = getattr(waiting.Ready, event)
     waitfn = getattr(waiting, waitfn)
 
     def gen():
@@ -80,6 +84,34 @@ def test_wait_bad(pgconn, waitfn):
         waitfn(gen, pgconn.socket)
 
 
+@pytest.mark.slow
+@pytest.mark.timing
+@pytest.mark.parametrize("waitfn", waitfns)
+def test_wait_timeout(pgconn, waitfn):
+    waitfn = getattr(waiting, waitfn)
+
+    pgconn.send_query(b"select pg_sleep(0.5)")
+    gen = generators.execute(pgconn)
+
+    ts = [time.time()]
+
+    def gen_wrapper():
+        try:
+            for x in gen:
+                res = yield x
+                ts.append(time.time())
+                gen.send(res)
+        except StopIteration as ex:
+            return ex.value
+
+    (res,) = waitfn(gen_wrapper(), pgconn.socket, timeout=0.1)
+    assert res.status == ExecStatus.TUPLES_OK
+    ds = [t1 - t0 for t0, t1 in zip(ts[:-1], ts[1:])]
+    assert len(ds) >= 5
+    for d in ds[:5]:
+        assert d == pytest.approx(0.1, 0.05)
+
+
 @pytest.mark.slow
 @pytest.mark.skipif(
     "sys.platform == 'win32'", reason="win32 works ok, but FDs are mysterious"
@@ -130,9 +162,12 @@ async def test_wait_conn_async_bad(dsn):
 
 
 @pytest.mark.anyio
-@pytest.mark.parametrize("wait, ready", zip(waiting.Wait, waiting.Ready))
+@pytest.mark.parametrize("event", events)
 @skip_if_not_linux
-async def test_wait_ready_async(wait, ready):
+async def test_wait_ready_async(event):
+    wait = getattr(waiting.Wait, event)
+    ready = getattr(waiting.Ready, event)
+
     def gen():
         r = yield wait
         return r