]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
refactor: handle timeout in the connection generator, not the waiting function
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 26 Jan 2024 09:47:31 +0000 (09:47 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 29 Jan 2024 02:26:36 +0000 (02:26 +0000)
The waiting function is supposed to be generic, the timeout is a policy
decision of the generator.

This change aligns the semantics of the `timeout` parameter of `wait_conn()`
to the one of the other wait functions: the timeout is actually an
interval. It will be renamed to clarify that in a followup commit.

psycopg/psycopg/_connection_base.py
psycopg/psycopg/connection.py
psycopg/psycopg/connection_async.py
psycopg/psycopg/generators.py
psycopg/psycopg/waiting.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/generators.pyx
tests/test_connection.py
tests/test_connection_async.py
tests/test_module.py
tests/test_psycopg_dbapi20.py

index 39e00002b41eba5312a6826dbb516d878e32c48f..12ae7395ff370f518fcd490c8b8d29341e0998c2 100644 (file)
@@ -419,9 +419,11 @@ class BaseConnection(Generic[Row]):
     # should have a lock and hold it before calling and consuming them.
 
     @classmethod
-    def _connect_gen(cls, conninfo: str = "") -> PQGenConn[Self]:
+    def _connect_gen(
+        cls, conninfo: str = "", *, timeout: float = 0.0
+    ) -> PQGenConn[Self]:
         """Generator to connect to the database and create a new instance."""
-        pgconn = yield from generators.connect(conninfo)
+        pgconn = yield from generators.connect(conninfo, timeout=timeout)
         conn = cls(pgconn)
         return conn
 
index cb0244aa504d5c7dc617dc16a7bcfb057a6218ba..977b6a7adf858da4c0f7be99a819cc6af78d58b4 100644 (file)
@@ -19,7 +19,7 @@ from contextlib import contextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV
 from ._tpc import Xid
 from .rows import Row, RowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
@@ -100,8 +100,8 @@ class Connection(BaseConnection[Row]):
         for attempt in attempts:
             try:
                 conninfo = make_conninfo("", **attempt)
-                rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
-                break
+                gen = cls._connect_gen(conninfo, timeout=timeout)
+                rv = waiting.wait_conn(gen, timeout=_WAIT_INTERVAL)
             except e._NO_TRACEBACK as ex:
                 if len(attempts) > 1:
                     logger.debug(
@@ -112,6 +112,8 @@ class Connection(BaseConnection[Row]):
                         str(ex),
                     )
                 last_ex = ex
+            else:
+                break
 
         if not rv:
             assert last_ex
@@ -371,11 +373,6 @@ class Connection(BaseConnection[Row]):
                     pass  # as expected
             raise
 
-    @classmethod
-    def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
-        """Consume a connection generator."""
-        return waiting.wait_conn(gen, timeout)
-
     def _set_autocommit(self, value: bool) -> None:
         self.set_autocommit(value)
 
index d810d45b29850dcd8cead444a863f25e85a51adc..0e1a7799b49bc9feb68543f9517751e95b2add91 100644 (file)
@@ -16,7 +16,7 @@ from contextlib import asynccontextmanager
 from . import pq
 from . import errors as e
 from . import waiting
-from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV
+from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV
 from ._tpc import Xid
 from .rows import Row, AsyncRowFactory, tuple_row, args_row
 from .adapt import AdaptersMap
@@ -115,8 +115,8 @@ class AsyncConnection(BaseConnection[Row]):
         for attempt in attempts:
             try:
                 conninfo = make_conninfo("", **attempt)
-                rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout)
-                break
+                gen = cls._connect_gen(conninfo, timeout=timeout)
+                rv = await waiting.wait_conn_async(gen, timeout=_WAIT_INTERVAL)
             except e._NO_TRACEBACK as ex:
                 if len(attempts) > 1:
                     logger.debug(
@@ -127,6 +127,8 @@ class AsyncConnection(BaseConnection[Row]):
                         str(ex),
                     )
                 last_ex = ex
+            else:
+                break
 
         if not rv:
             assert last_ex
@@ -389,11 +391,6 @@ class AsyncConnection(BaseConnection[Row]):
                     pass  # as expected
             raise
 
-    @classmethod
-    async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
-        """Consume a connection generator."""
-        return await waiting.wait_conn_async(gen, timeout)
-
     def _set_autocommit(self, value: bool) -> None:
         if True:  # ASYNC
             self._no_set_async("autocommit")
index 2e463196e6e5eee462f3bdbde8839f6a3ef712a9..96143af939f66d00a13696e341ca05f6c7cfd7a3 100644 (file)
@@ -21,6 +21,7 @@ generator should probably yield the same value again in order to wait more.
 # Copyright (C) 2020 The Psycopg Team
 
 import logging
+from time import monotonic
 from typing import List, Optional, Union
 
 from . import pq
@@ -56,11 +57,12 @@ READY_RW = Ready.RW
 logger = logging.getLogger(__name__)
 
 
-def _connect(conninfo: str) -> PQGenConn[PGconn]:
+def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]:
     """
     Generator to create a database connection without blocking.
-
     """
+    deadline = monotonic() + timeout if timeout else 0.0
+
     conn = pq.PGconn.connect_start(conninfo.encode())
     while True:
         if conn.status == BAD:
@@ -71,12 +73,18 @@ def _connect(conninfo: str) -> PQGenConn[PGconn]:
             )
 
         status = conn.connect_poll()
-        if status == POLL_OK:
+
+        if status == POLL_READING or status == POLL_WRITING:
+            wait = WAIT_R if status == POLL_READING else WAIT_W
+            while True:
+                ready = yield conn.socket, wait
+                if deadline and monotonic() > deadline:
+                    raise e.ConnectionTimeout("connection timeout expired")
+                if ready:
+                    break
+
+        elif status == POLL_OK:
             break
-        elif status == POLL_READING:
-            yield conn.socket, WAIT_R
-        elif status == POLL_WRITING:
-            yield conn.socket, WAIT_W
         elif status == POLL_FAILED:
             encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
index 6315c0ad7c459224af7db1b742f35336e79de053..295133a0f7a0892fa488d570a7cac9abbdee2e12 100644 (file)
@@ -88,14 +88,17 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
         if not timeout:
             timeout = None
         with DefaultSelector() as sel:
+            sel.register(fileno, s)
             while True:
-                sel.register(fileno, s)
                 rlist = sel.select(timeout=timeout)
-                sel.unregister(fileno)
                 if not rlist:
-                    raise e.ConnectionTimeout("connection timeout expired")
+                    gen.send(READY_NONE)
+                    continue
+
+                sel.unregister(fileno)
                 ready = rlist[0][1]
                 fileno, s = gen.send(ready)
+                sel.register(fileno, s)
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
@@ -205,7 +208,10 @@ async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -
                 loop.add_writer(fileno, wakeup, READY_W)
             try:
                 if timeout:
-                    await wait_for(ev.wait(), timeout)
+                    try:
+                        await wait_for(ev.wait(), timeout)
+                    except TimeoutError:
+                        pass
                 else:
                     await ev.wait()
             finally:
@@ -215,9 +221,6 @@ async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) -
                     loop.remove_writer(fileno)
             fileno, s = gen.send(ready)
 
-    except TimeoutError:
-        raise e.ConnectionTimeout("connection timeout expired")
-
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
         return rv
index 7d456ba538d8cbcffcd86108d2ff70638e8255b6..10d230bfb9a614244f33977461e8b9242ecadf69 100644 (file)
@@ -51,7 +51,7 @@ class Transformer(abc.AdaptContext):
     def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ...
 
 # Generators
-def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ...
+def connect(conninfo: str, *, timeout: float = 0.0) -> abc.PQGenConn[PGconn]: ...
 def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
 def send(pgconn: PGconn) -> abc.PQGen[None]: ...
 def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ...
index 70335cf8995731ea0d156baa8f9026a17ce117dd..b29ad327a2eb8d2ecf76e0f5378a8a6509ad079a 100644 (file)
@@ -7,6 +7,7 @@ C implementation of generators for the communication protocols with the libpq
 from cpython.object cimport PyObject_CallFunctionObjArgs
 
 from typing import List
+from time import monotonic
 
 from psycopg import errors as e
 from psycopg.pq import abc, error_message
@@ -27,15 +28,17 @@ cdef int READY_R = Ready.R
 cdef int READY_W = Ready.W
 cdef int READY_RW = Ready.RW
 
-def connect(conninfo: str) -> PQGenConn[abc.PGconn]:
+def connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[abc.PGconn]:
     """
     Generator to create a database connection without blocking.
-
     """
+    cdef int deadline = monotonic() + timeout if timeout else 0.0
+
     cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode())
     cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr
     cdef int conn_status = libpq.PQstatus(pgconn_ptr)
     cdef int poll_status
+    cdef object wait, ready
 
     while True:
         if conn_status == libpq.CONNECTION_BAD:
@@ -48,12 +51,18 @@ def connect(conninfo: str) -> PQGenConn[abc.PGconn]:
         with nogil:
             poll_status = libpq.PQconnectPoll(pgconn_ptr)
 
-        if poll_status == libpq.PGRES_POLLING_OK:
+        if poll_status == libpq.PGRES_POLLING_READING \
+        or poll_status == libpq.PGRES_POLLING_WRITING:
+            wait = WAIT_R if poll_status == libpq.PGRES_POLLING_READING else WAIT_W
+            while True:
+                ready = yield (libpq.PQsocket(pgconn_ptr), wait)
+                if deadline and monotonic() > deadline:
+                    raise e.ConnectionTimeout("connection timeout expired")
+                if ready:
+                    break
+
+        elif poll_status == libpq.PGRES_POLLING_OK:
             break
-        elif poll_status == libpq.PGRES_POLLING_READING:
-            yield (libpq.PQsocket(pgconn_ptr), WAIT_R)
-        elif poll_status == libpq.PGRES_POLLING_WRITING:
-            yield (libpq.PQsocket(pgconn_ptr), WAIT_W)
         elif poll_status == libpq.PGRES_POLLING_FAILED:
             encoding = conninfo_encoding(conninfo)
             raise e.OperationalError(
index f17f8d2a06f85f3097dad4e685e740b52b9ba732..eb854123817b33357b11443cff1cf3d6f04583a0 100644 (file)
@@ -446,7 +446,7 @@ def test_connect_args(
 ):
     got_conninfo: str
 
-    def fake_connect(conninfo):
+    def fake_connect(conninfo, *, timeout=0.0):
         nonlocal got_conninfo
         got_conninfo = conninfo
         return pgconn
@@ -468,12 +468,6 @@ def test_connect_args(
     ],
 )
 def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype):
-
-    def fake_connect(conninfo):
-        return pgconn
-        yield
-
-    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     with pytest.raises(exctype):
         conn_cls.connect(*args, **kwargs)
 
index d7aa7ca8b41898e927a5e1c9dbe15fc30420d75c..cd761b330ede8f3b43f497c287ffcfda87189bc2 100644 (file)
@@ -443,7 +443,7 @@ async def test_connect_args(
 ):
     got_conninfo: str
 
-    def fake_connect(conninfo):
+    def fake_connect(conninfo, *, timeout=0.0):
         nonlocal got_conninfo
         got_conninfo = conninfo
         return pgconn
@@ -465,11 +465,6 @@ async def test_connect_args(
     ],
 )
 async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype):
-    def fake_connect(conninfo):
-        return pgconn
-        yield
-
-    monkeypatch.setattr(psycopg.generators, "connect", fake_connect)
     with pytest.raises(exctype):
         await aconn_cls.connect(*args, **kwargs)
 
index 9b144d7d6be31bc0edb32884db1c89e39b69737c..2b1869e945ac9423586ef380487bd16c9c7f157e 100644 (file)
@@ -22,10 +22,10 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv):
 
     got_conninfo: str
 
-    def mock_connect(conninfo):
+    def mock_connect(conninfo, *, timeout):
         nonlocal got_conninfo
         got_conninfo = conninfo
-        return orig_connect(dsn_env)
+        return orig_connect(dsn_env, timeout=timeout)
 
     setpgenv({})
     monkeypatch.setattr(psycopg.generators, "connect", mock_connect)
index b4feac792ab1783cb4429795b715a39f5d0d4a3d..df1c981a66d90aa70c976d93d978a37c4a019c28 100644 (file)
@@ -135,7 +135,7 @@ def test_time_from_ticks(ticks, want):
 def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve):
     got_conninfo: str
 
-    def fake_connect(conninfo):
+    def fake_connect(conninfo, *, timeout=0.0):
         nonlocal got_conninfo
         got_conninfo = conninfo
         return pgconn
@@ -157,9 +157,5 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_re
     ],
 )
 def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype):
-    def fake_connect(conninfo):
-        return pgconn
-        yield
-
     with pytest.raises(exctype):
         psycopg.connect(*args, **kwargs)