From: Daniele Varrazzo Date: Fri, 26 Jan 2024 09:47:31 +0000 (+0000) Subject: refactor: handle timeout in the connection generator, not the waiting function X-Git-Tag: 3.2.0~86^2~1 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=07da60bedb4aa398100ece1c193f592c750d94f0;p=thirdparty%2Fpsycopg.git refactor: handle timeout in the connection generator, not the waiting function The waiting function is supposed to be generic, the timeout is a policy decision of the generator. This change aligns the semantics of the `timeout` parameter of `wait_conn()` to the one of the other wait functions: the timeout is actually an interval. It will be renamed to clarify that in a followup commit. --- diff --git a/psycopg/psycopg/_connection_base.py b/psycopg/psycopg/_connection_base.py index 39e00002b..12ae7395f 100644 --- a/psycopg/psycopg/_connection_base.py +++ b/psycopg/psycopg/_connection_base.py @@ -419,9 +419,11 @@ class BaseConnection(Generic[Row]): # should have a lock and hold it before calling and consuming them. @classmethod - def _connect_gen(cls, conninfo: str = "") -> PQGenConn[Self]: + def _connect_gen( + cls, conninfo: str = "", *, timeout: float = 0.0 + ) -> PQGenConn[Self]: """Generator to connect to the database and create a new instance.""" - pgconn = yield from generators.connect(conninfo) + pgconn = yield from generators.connect(conninfo, timeout=timeout) conn = cls(pgconn) return conn diff --git a/psycopg/psycopg/connection.py b/psycopg/psycopg/connection.py index cb0244aa5..977b6a7ad 100644 --- a/psycopg/psycopg/connection.py +++ b/psycopg/psycopg/connection.py @@ -19,7 +19,7 @@ from contextlib import contextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV from ._tpc import Xid from .rows import Row, RowFactory, tuple_row, args_row from .adapt import AdaptersMap @@ -100,8 +100,8 @@ class Connection(BaseConnection[Row]): for attempt in attempts: try: conninfo = make_conninfo("", **attempt) - rv = cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) - break + gen = cls._connect_gen(conninfo, timeout=timeout) + rv = waiting.wait_conn(gen, timeout=_WAIT_INTERVAL) except e._NO_TRACEBACK as ex: if len(attempts) > 1: logger.debug( @@ -112,6 +112,8 @@ class Connection(BaseConnection[Row]): str(ex), ) last_ex = ex + else: + break if not rv: assert last_ex @@ -371,11 +373,6 @@ class Connection(BaseConnection[Row]): pass # as expected raise - @classmethod - def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: - """Consume a connection generator.""" - return waiting.wait_conn(gen, timeout) - def _set_autocommit(self, value: bool) -> None: self.set_autocommit(value) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index d810d45b2..0e1a7799b 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -16,7 +16,7 @@ from contextlib import asynccontextmanager from . import pq from . import errors as e from . import waiting -from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, PQGenConn, Query, RV +from .abc import AdaptContext, ConnDict, ConnParam, Params, PQGen, Query, RV from ._tpc import Xid from .rows import Row, AsyncRowFactory, tuple_row, args_row from .adapt import AdaptersMap @@ -115,8 +115,8 @@ class AsyncConnection(BaseConnection[Row]): for attempt in attempts: try: conninfo = make_conninfo("", **attempt) - rv = await cls._wait_conn(cls._connect_gen(conninfo), timeout=timeout) - break + gen = cls._connect_gen(conninfo, timeout=timeout) + rv = await waiting.wait_conn_async(gen, timeout=_WAIT_INTERVAL) except e._NO_TRACEBACK as ex: if len(attempts) > 1: logger.debug( @@ -127,6 +127,8 @@ class AsyncConnection(BaseConnection[Row]): str(ex), ) last_ex = ex + else: + break if not rv: assert last_ex @@ -389,11 +391,6 @@ class AsyncConnection(BaseConnection[Row]): pass # as expected raise - @classmethod - async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: - """Consume a connection generator.""" - return await waiting.wait_conn_async(gen, timeout) - def _set_autocommit(self, value: bool) -> None: if True: # ASYNC self._no_set_async("autocommit") diff --git a/psycopg/psycopg/generators.py b/psycopg/psycopg/generators.py index 2e463196e..96143af93 100644 --- a/psycopg/psycopg/generators.py +++ b/psycopg/psycopg/generators.py @@ -21,6 +21,7 @@ generator should probably yield the same value again in order to wait more. # Copyright (C) 2020 The Psycopg Team import logging +from time import monotonic from typing import List, Optional, Union from . import pq @@ -56,11 +57,12 @@ READY_RW = Ready.RW logger = logging.getLogger(__name__) -def _connect(conninfo: str) -> PQGenConn[PGconn]: +def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]: """ Generator to create a database connection without blocking. - """ + deadline = monotonic() + timeout if timeout else 0.0 + conn = pq.PGconn.connect_start(conninfo.encode()) while True: if conn.status == BAD: @@ -71,12 +73,18 @@ def _connect(conninfo: str) -> PQGenConn[PGconn]: ) status = conn.connect_poll() - if status == POLL_OK: + + if status == POLL_READING or status == POLL_WRITING: + wait = WAIT_R if status == POLL_READING else WAIT_W + while True: + ready = yield conn.socket, wait + if deadline and monotonic() > deadline: + raise e.ConnectionTimeout("connection timeout expired") + if ready: + break + + elif status == POLL_OK: break - elif status == POLL_READING: - yield conn.socket, WAIT_R - elif status == POLL_WRITING: - yield conn.socket, WAIT_W elif status == POLL_FAILED: encoding = conninfo_encoding(conninfo) raise e.OperationalError( diff --git a/psycopg/psycopg/waiting.py b/psycopg/psycopg/waiting.py index 6315c0ad7..295133a0f 100644 --- a/psycopg/psycopg/waiting.py +++ b/psycopg/psycopg/waiting.py @@ -88,14 +88,17 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: if not timeout: timeout = None with DefaultSelector() as sel: + sel.register(fileno, s) while True: - sel.register(fileno, s) rlist = sel.select(timeout=timeout) - sel.unregister(fileno) if not rlist: - raise e.ConnectionTimeout("connection timeout expired") + gen.send(READY_NONE) + continue + + sel.unregister(fileno) ready = rlist[0][1] fileno, s = gen.send(ready) + sel.register(fileno, s) except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None @@ -205,7 +208,10 @@ async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) - loop.add_writer(fileno, wakeup, READY_W) try: if timeout: - await wait_for(ev.wait(), timeout) + try: + await wait_for(ev.wait(), timeout) + except TimeoutError: + pass else: await ev.wait() finally: @@ -215,9 +221,6 @@ async def wait_conn_async(gen: PQGenConn[RV], timeout: Optional[float] = None) - loop.remove_writer(fileno) fileno, s = gen.send(ready) - except TimeoutError: - raise e.ConnectionTimeout("connection timeout expired") - except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None return rv diff --git a/psycopg_c/psycopg_c/_psycopg.pyi b/psycopg_c/psycopg_c/_psycopg.pyi index 7d456ba53..10d230bfb 100644 --- a/psycopg_c/psycopg_c/_psycopg.pyi +++ b/psycopg_c/psycopg_c/_psycopg.pyi @@ -51,7 +51,7 @@ class Transformer(abc.AdaptContext): def get_loader(self, oid: int, format: pq.Format) -> abc.Loader: ... # Generators -def connect(conninfo: str) -> abc.PQGenConn[PGconn]: ... +def connect(conninfo: str, *, timeout: float = 0.0) -> abc.PQGenConn[PGconn]: ... def execute(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... def send(pgconn: PGconn) -> abc.PQGen[None]: ... def fetch_many(pgconn: PGconn) -> abc.PQGen[List[PGresult]]: ... diff --git a/psycopg_c/psycopg_c/_psycopg/generators.pyx b/psycopg_c/psycopg_c/_psycopg/generators.pyx index 70335cf89..b29ad327a 100644 --- a/psycopg_c/psycopg_c/_psycopg/generators.pyx +++ b/psycopg_c/psycopg_c/_psycopg/generators.pyx @@ -7,6 +7,7 @@ C implementation of generators for the communication protocols with the libpq from cpython.object cimport PyObject_CallFunctionObjArgs from typing import List +from time import monotonic from psycopg import errors as e from psycopg.pq import abc, error_message @@ -27,15 +28,17 @@ cdef int READY_R = Ready.R cdef int READY_W = Ready.W cdef int READY_RW = Ready.RW -def connect(conninfo: str) -> PQGenConn[abc.PGconn]: +def connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[abc.PGconn]: """ Generator to create a database connection without blocking. - """ + cdef int deadline = monotonic() + timeout if timeout else 0.0 + cdef pq.PGconn conn = pq.PGconn.connect_start(conninfo.encode()) cdef libpq.PGconn *pgconn_ptr = conn._pgconn_ptr cdef int conn_status = libpq.PQstatus(pgconn_ptr) cdef int poll_status + cdef object wait, ready while True: if conn_status == libpq.CONNECTION_BAD: @@ -48,12 +51,18 @@ def connect(conninfo: str) -> PQGenConn[abc.PGconn]: with nogil: poll_status = libpq.PQconnectPoll(pgconn_ptr) - if poll_status == libpq.PGRES_POLLING_OK: + if poll_status == libpq.PGRES_POLLING_READING \ + or poll_status == libpq.PGRES_POLLING_WRITING: + wait = WAIT_R if poll_status == libpq.PGRES_POLLING_READING else WAIT_W + while True: + ready = yield (libpq.PQsocket(pgconn_ptr), wait) + if deadline and monotonic() > deadline: + raise e.ConnectionTimeout("connection timeout expired") + if ready: + break + + elif poll_status == libpq.PGRES_POLLING_OK: break - elif poll_status == libpq.PGRES_POLLING_READING: - yield (libpq.PQsocket(pgconn_ptr), WAIT_R) - elif poll_status == libpq.PGRES_POLLING_WRITING: - yield (libpq.PQsocket(pgconn_ptr), WAIT_W) elif poll_status == libpq.PGRES_POLLING_FAILED: encoding = conninfo_encoding(conninfo) raise e.OperationalError( diff --git a/tests/test_connection.py b/tests/test_connection.py index f17f8d2a0..eb8541238 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -446,7 +446,7 @@ def test_connect_args( ): got_conninfo: str - def fake_connect(conninfo): + def fake_connect(conninfo, *, timeout=0.0): nonlocal got_conninfo got_conninfo = conninfo return pgconn @@ -468,12 +468,6 @@ def test_connect_args( ], ) def test_connect_badargs(conn_cls, monkeypatch, pgconn, args, kwargs, exctype): - - def fake_connect(conninfo): - return pgconn - yield - - monkeypatch.setattr(psycopg.generators, "connect", fake_connect) with pytest.raises(exctype): conn_cls.connect(*args, **kwargs) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index d7aa7ca8b..cd761b330 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -443,7 +443,7 @@ async def test_connect_args( ): got_conninfo: str - def fake_connect(conninfo): + def fake_connect(conninfo, *, timeout=0.0): nonlocal got_conninfo got_conninfo = conninfo return pgconn @@ -465,11 +465,6 @@ async def test_connect_args( ], ) async def test_connect_badargs(aconn_cls, monkeypatch, pgconn, args, kwargs, exctype): - def fake_connect(conninfo): - return pgconn - yield - - monkeypatch.setattr(psycopg.generators, "connect", fake_connect) with pytest.raises(exctype): await aconn_cls.connect(*args, **kwargs) diff --git a/tests/test_module.py b/tests/test_module.py index 9b144d7d6..2b1869e94 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -22,10 +22,10 @@ def test_connect(monkeypatch, dsn_env, args, kwargs, want, setpgenv): got_conninfo: str - def mock_connect(conninfo): + def mock_connect(conninfo, *, timeout): nonlocal got_conninfo got_conninfo = conninfo - return orig_connect(dsn_env) + return orig_connect(dsn_env, timeout=timeout) setpgenv({}) monkeypatch.setattr(psycopg.generators, "connect", mock_connect) diff --git a/tests/test_psycopg_dbapi20.py b/tests/test_psycopg_dbapi20.py index b4feac792..df1c981a6 100644 --- a/tests/test_psycopg_dbapi20.py +++ b/tests/test_psycopg_dbapi20.py @@ -135,7 +135,7 @@ def test_time_from_ticks(ticks, want): def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_resolve): got_conninfo: str - def fake_connect(conninfo): + def fake_connect(conninfo, *, timeout=0.0): nonlocal got_conninfo got_conninfo = conninfo return pgconn @@ -157,9 +157,5 @@ def test_connect_args(monkeypatch, pgconn, args, kwargs, want, setpgenv, fake_re ], ) def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype): - def fake_connect(conninfo): - return pgconn - yield - with pytest.raises(exctype): psycopg.connect(*args, **kwargs)