From: Denis Laxalde Date: Tue, 4 May 2021 15:36:33 +0000 (+0200) Subject: Add support for connection timeout X-Git-Tag: 3.0.dev0~36^2~1 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=102ec02871d37a93b0d8f08fc615d693b9464722;p=thirdparty%2Fpsycopg.git Add support for connection timeout In *Connection.connect(), we replace call to make_conninfo() by the new _conninfo_connect_timeout() utility function which builds the 'conninfo' string (using make_conninfo()) and extracts the 'connect_timeout' parameter. For the synchronous API, this timeout value is then handled to waiting.wait_conn(), to be used in the select() call. There, if select() does not return within timeout, we raise a DatabaseError. For the asynchronous API, it is passed to waiting.wait_conn_async() where we use asyncio.wait_for() to wait for the event and also raise a DatabaseError in case of timeout. --- diff --git a/psycopg3/psycopg3/connection.py b/psycopg3/psycopg3/connection.py index 22af61d43..821ed6b5d 100644 --- a/psycopg3/psycopg3/connection.py +++ b/psycopg3/psycopg3/connection.py @@ -28,7 +28,7 @@ from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn from .proto import Query, RV from .compat import asynccontextmanager from .cursor import Cursor, AsyncCursor -from .conninfo import make_conninfo, ConnectionInfo +from .conninfo import _conninfo_connect_timeout, ConnectionInfo from .generators import notifies from ._preparing import PrepareManager from .transaction import Transaction, AsyncTransaction @@ -483,14 +483,13 @@ class Connection(BaseConnection[Row]): ) -> "Connection[Any]": """ Connect to a database server and return a new `Connection` instance. - - TODO: connection_timeout to be implemented. """ - conninfo = make_conninfo(conninfo, **kwargs) + conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs) return cls._wait_conn( cls._connect_gen( conninfo, autocommit=autocommit, row_factory=row_factory - ) + ), + timeout, ) def __enter__(self) -> "Connection[Row]": @@ -639,9 +638,7 @@ class Connection(BaseConnection[Row]): return waiting.wait(gen, self.pgconn.socket, timeout=timeout) @classmethod - def _wait_conn( - cls, gen: PQGenConn[RV], timeout: Optional[float] = 0.1 - ) -> RV: + def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV: """Consume a connection generator.""" return waiting.wait_conn(gen, timeout=timeout) @@ -697,11 +694,12 @@ class AsyncConnection(BaseConnection[Row]): row_factory: Optional[RowFactory[Row]] = None, **kwargs: Any, ) -> "AsyncConnection[Any]": - conninfo = make_conninfo(conninfo, **kwargs) + conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs) return await cls._wait_conn( cls._connect_gen( conninfo, autocommit=autocommit, row_factory=row_factory - ) + ), + timeout, ) async def __aenter__(self) -> "AsyncConnection[Row]": @@ -836,8 +834,10 @@ class AsyncConnection(BaseConnection[Row]): return await waiting.wait_async(gen, self.pgconn.socket) @classmethod - async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV: - return await waiting.wait_conn_async(gen) + async def _wait_conn( + cls, gen: PQGenConn[RV], timeout: Optional[int] + ) -> RV: + return await waiting.wait_conn_async(gen, timeout) def _set_client_encoding(self, name: str) -> None: raise AttributeError( diff --git a/psycopg3/psycopg3/conninfo.py b/psycopg3/psycopg3/conninfo.py index e137f1dfe..7a15bb4c7 100644 --- a/psycopg3/psycopg3/conninfo.py +++ b/psycopg3/psycopg3/conninfo.py @@ -5,7 +5,7 @@ Functions to manipulate conninfo strings # Copyright (C) 2020-2021 The Psycopg Team import re -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from pathlib import Path from datetime import tzinfo @@ -95,6 +95,22 @@ def _param_escape(s: str) -> str: return s +def _conninfo_connect_timeout( + conninfo: str, **kwargs: Any +) -> Tuple[str, Optional[int]]: + """ + Build 'conninfo' by combining input value with kwargs and extract + 'connect_timeout' parameter. + """ + conninfo = make_conninfo(conninfo, **kwargs) + connect_timeout: Optional[int] + try: + connect_timeout = int(conninfo_to_dict(conninfo)["connect_timeout"]) + except KeyError: + connect_timeout = None + return conninfo, connect_timeout + + class ConnectionInfo: """Allow access to information about the connection.""" diff --git a/psycopg3/psycopg3/waiting.py b/psycopg3/psycopg3/waiting.py index a489feab3..afb9cd66b 100644 --- a/psycopg3/psycopg3/waiting.py +++ b/psycopg3/psycopg3/waiting.py @@ -13,7 +13,7 @@ import select import selectors from enum import IntEnum from typing import Optional -from asyncio import get_event_loop, Event +from asyncio import get_event_loop, wait_for, Event, TimeoutError from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE from . import errors as e @@ -71,23 +71,24 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV: :param gen: a generator performing database operations and yielding (fd, `Ready`) pairs when it would block. :param timeout: timeout (in seconds) to check for other interrupt, e.g. - to allow Ctrl-C. + to allow Ctrl-C. If zero or None, wait indefinitely. :type timeout: float :return: whatever *gen* returns on completion. Behave like in `wait()`, but take the fileno to wait from the generator itself, which might change during processing. """ + timeout = timeout or None try: fileno, s = next(gen) sel = DefaultSelector() while 1: sel.register(fileno, s) - ready = None - while not ready: - ready = sel.select(timeout=timeout) + ready = sel.select(timeout=timeout) sel.unregister(fileno) - fileno, s = gen.send(ready[0][1]) + if not ready: + raise e.DatabaseError("timeout expired") + fileno, s = gen.send(ready[0][1]) # type: ignore[arg-type] except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None @@ -144,14 +145,16 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV: return rv -async def wait_conn_async(gen: PQGenConn[RV]) -> RV: +async def wait_conn_async( + gen: PQGenConn[RV], timeout: Optional[float] = None +) -> RV: """ Coroutine waiting for a connection generator to complete. :param gen: a generator performing database operations and yielding (fd, `Ready`) pairs when it would block. :param timeout: timeout (in seconds) to check for other interrupt, e.g. - to allow Ctrl-C. + to allow Ctrl-C. If zero or None, wait indefinitely. :return: whatever *gen* returns on completion. Behave like in `wait()`, but take the fileno to wait from the generator @@ -169,28 +172,32 @@ async def wait_conn_async(gen: PQGenConn[RV]) -> RV: ready = state ev.set() + timeout = timeout or None try: fileno, s = next(gen) while 1: ev.clear() if s == Wait.R: loop.add_reader(fileno, wakeup, Ready.R) - await ev.wait() + await wait_for(ev.wait(), timeout) loop.remove_reader(fileno) elif s == Wait.W: loop.add_writer(fileno, wakeup, Ready.W) - await ev.wait() + await wait_for(ev.wait(), timeout) loop.remove_writer(fileno) elif s == Wait.RW: loop.add_reader(fileno, wakeup, Ready.R) loop.add_writer(fileno, wakeup, Ready.W) - await ev.wait() + await wait_for(ev.wait(), timeout) loop.remove_reader(fileno) loop.remove_writer(fileno) else: raise e.InternalError("bad poll status: %s") fileno, s = gen.send(ready) + except TimeoutError: + raise e.DatabaseError("timeout expired") + except StopIteration as ex: rv: RV = ex.args[0] if ex.args else None return rv diff --git a/tests/test_connection.py b/tests/test_connection.py index b97570d49..ea8aced21 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -37,7 +37,6 @@ def test_connect_bad(): @pytest.mark.slow -@pytest.mark.xfail @pytest.mark.skipif(sys.platform == "win32", reason="connect() hangs on Win32") def test_connect_timeout(): s = socket.socket(socket.AF_INET) diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index e4406d1a4..e81aa42b0 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -38,7 +38,6 @@ async def test_connect_str_subclass(dsn): @pytest.mark.slow -@pytest.mark.xfail async def test_connect_timeout(): s = socket.socket(socket.AF_INET) s.bind(("", 0)) diff --git a/tests/test_conninfo.py b/tests/test_conninfo.py index 94f936a69..eb2e8c65d 100644 --- a/tests/test_conninfo.py +++ b/tests/test_conninfo.py @@ -4,7 +4,12 @@ import pytest import psycopg3 from psycopg3 import ProgrammingError -from psycopg3.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo +from psycopg3.conninfo import ( + _conninfo_connect_timeout, + make_conninfo, + conninfo_to_dict, + ConnectionInfo, +) snowman = "\u2603" @@ -89,6 +94,37 @@ def test_no_munging(): assert dsnin == dsnout +@pytest.mark.parametrize( + "dsn, kwargs, exp", + [ + ( + "", + {"host": "localhost", "connect_timeout": 1}, + ({"host": "localhost", "connect_timeout": "1"}, 1), + ), + ( + "dbname=postgres", + {}, + ({"dbname": "postgres"}, None), + ), + ( + "dbname=postgres connect_timeout=2", + {}, + ({"dbname": "postgres", "connect_timeout": "2"}, 2), + ), + ( + "postgresql:///postgres?connect_timeout=2", + {"connect_timeout": 10}, + ({"dbname": "postgres", "connect_timeout": "10"}, 10), + ), + ], +) +def test__conninfo_connect_timeout(dsn, kwargs, exp): + conninfo, connect_timeout = _conninfo_connect_timeout(dsn, **kwargs) + assert conninfo_to_dict(conninfo) == exp[0] + assert connect_timeout == exp[1] + + class TestConnectionInfo: @pytest.mark.parametrize( "attr", diff --git a/tests/test_waiting.py b/tests/test_waiting.py index c5f3b71da..2b40d3c0f 100644 --- a/tests/test_waiting.py +++ b/tests/test_waiting.py @@ -75,10 +75,11 @@ def test_wait_epoll_bad(pgconn): assert res.status == ExecStatus.TUPLES_OK +@pytest.mark.parametrize("timeout", timeouts) @pytest.mark.asyncio -async def test_wait_conn_async(dsn): +async def test_wait_conn_async(dsn, timeout): gen = generators.connect(dsn) - conn = await waiting.wait_conn_async(gen) + conn = await waiting.wait_conn_async(gen, **timeout) assert conn.status == ConnStatus.OK