]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add support for connection timeout
authorDenis Laxalde <denis.laxalde@dalibo.com>
Tue, 4 May 2021 15:36:33 +0000 (17:36 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 8 Jun 2021 12:22:38 +0000 (13:22 +0100)
In *Connection.connect(), we replace call to make_conninfo() by the new
_conninfo_connect_timeout() utility function which builds the 'conninfo'
string (using make_conninfo()) and extracts the 'connect_timeout'
parameter.

For the synchronous API, this timeout value is then handled to
waiting.wait_conn(), to be used in the select() call. There, if select()
does not return within timeout, we raise a DatabaseError.

For the asynchronous API, it is passed to waiting.wait_conn_async()
where we use asyncio.wait_for() to wait for the event and also raise a
DatabaseError in case of timeout.

psycopg3/psycopg3/connection.py
psycopg3/psycopg3/conninfo.py
psycopg3/psycopg3/waiting.py
tests/test_connection.py
tests/test_connection_async.py
tests/test_conninfo.py
tests/test_waiting.py

index 22af61d433ed92cc1b6f88b22e3519df951a048d..821ed6b5d2fa35959e7b575509aba46921f55c90 100644 (file)
@@ -28,7 +28,7 @@ from .proto import AdaptContext, ConnectionType, Params, PQGen, PQGenConn
 from .proto import Query, RV
 from .compat import asynccontextmanager
 from .cursor import Cursor, AsyncCursor
-from .conninfo import make_conninfo, ConnectionInfo
+from .conninfo import _conninfo_connect_timeout, ConnectionInfo
 from .generators import notifies
 from ._preparing import PrepareManager
 from .transaction import Transaction, AsyncTransaction
@@ -483,14 +483,13 @@ class Connection(BaseConnection[Row]):
     ) -> "Connection[Any]":
         """
         Connect to a database server and return a new `Connection` instance.
-
-        TODO: connection_timeout to be implemented.
         """
-        conninfo = make_conninfo(conninfo, **kwargs)
+        conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
         return cls._wait_conn(
             cls._connect_gen(
                 conninfo, autocommit=autocommit, row_factory=row_factory
-            )
+            ),
+            timeout,
         )
 
     def __enter__(self) -> "Connection[Row]":
@@ -639,9 +638,7 @@ class Connection(BaseConnection[Row]):
         return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
 
     @classmethod
-    def _wait_conn(
-        cls, gen: PQGenConn[RV], timeout: Optional[float] = 0.1
-    ) -> RV:
+    def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
         """Consume a connection generator."""
         return waiting.wait_conn(gen, timeout=timeout)
 
@@ -697,11 +694,12 @@ class AsyncConnection(BaseConnection[Row]):
         row_factory: Optional[RowFactory[Row]] = None,
         **kwargs: Any,
     ) -> "AsyncConnection[Any]":
-        conninfo = make_conninfo(conninfo, **kwargs)
+        conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
         return await cls._wait_conn(
             cls._connect_gen(
                 conninfo, autocommit=autocommit, row_factory=row_factory
-            )
+            ),
+            timeout,
         )
 
     async def __aenter__(self) -> "AsyncConnection[Row]":
@@ -836,8 +834,10 @@ class AsyncConnection(BaseConnection[Row]):
         return await waiting.wait_async(gen, self.pgconn.socket)
 
     @classmethod
-    async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV:
-        return await waiting.wait_conn_async(gen)
+    async def _wait_conn(
+        cls, gen: PQGenConn[RV], timeout: Optional[int]
+    ) -> RV:
+        return await waiting.wait_conn_async(gen, timeout)
 
     def _set_client_encoding(self, name: str) -> None:
         raise AttributeError(
index e137f1dfe83c204c961cfed9aa23cb11b90d8a50..7a15bb4c70edd27bd79d2a071db01531d79ac709 100644 (file)
@@ -5,7 +5,7 @@ Functions to manipulate conninfo strings
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
 from pathlib import Path
 from datetime import tzinfo
 
@@ -95,6 +95,22 @@ def _param_escape(s: str) -> str:
     return s
 
 
+def _conninfo_connect_timeout(
+    conninfo: str, **kwargs: Any
+) -> Tuple[str, Optional[int]]:
+    """
+    Build 'conninfo' by combining input value with kwargs and extract
+    'connect_timeout' parameter.
+    """
+    conninfo = make_conninfo(conninfo, **kwargs)
+    connect_timeout: Optional[int]
+    try:
+        connect_timeout = int(conninfo_to_dict(conninfo)["connect_timeout"])
+    except KeyError:
+        connect_timeout = None
+    return conninfo, connect_timeout
+
+
 class ConnectionInfo:
     """Allow access to information about the connection."""
 
index a489feab34c93cacce54814041a5448621e80f14..afb9cd66b4ebd72e2134f9ea9a1f276fa6c972af 100644 (file)
@@ -13,7 +13,7 @@ import select
 import selectors
 from enum import IntEnum
 from typing import Optional
-from asyncio import get_event_loop, Event
+from asyncio import get_event_loop, wait_for, Event, TimeoutError
 from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
 
 from . import errors as e
@@ -71,23 +71,24 @@ def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
     :param gen: a generator performing database operations and yielding
         (fd, `Ready`) pairs when it would block.
     :param timeout: timeout (in seconds) to check for other interrupt, e.g.
-        to allow Ctrl-C.
+        to allow Ctrl-C. If zero or None, wait indefinitely.
     :type timeout: float
     :return: whatever *gen* returns on completion.
 
     Behave like in `wait()`, but take the fileno to wait from the generator
     itself, which might change during processing.
     """
+    timeout = timeout or None
     try:
         fileno, s = next(gen)
         sel = DefaultSelector()
         while 1:
             sel.register(fileno, s)
-            ready = None
-            while not ready:
-                ready = sel.select(timeout=timeout)
+            ready = sel.select(timeout=timeout)
             sel.unregister(fileno)
-            fileno, s = gen.send(ready[0][1])
+            if not ready:
+                raise e.DatabaseError("timeout expired")
+            fileno, s = gen.send(ready[0][1])  # type: ignore[arg-type]
 
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
@@ -144,14 +145,16 @@ async def wait_async(gen: PQGen[RV], fileno: int) -> RV:
         return rv
 
 
-async def wait_conn_async(gen: PQGenConn[RV]) -> RV:
+async def wait_conn_async(
+    gen: PQGenConn[RV], timeout: Optional[float] = None
+) -> RV:
     """
     Coroutine waiting for a connection generator to complete.
 
     :param gen: a generator performing database operations and yielding
         (fd, `Ready`) pairs when it would block.
     :param timeout: timeout (in seconds) to check for other interrupt, e.g.
-        to allow Ctrl-C.
+        to allow Ctrl-C. If zero or None, wait indefinitely.
     :return: whatever *gen* returns on completion.
 
     Behave like in `wait()`, but take the fileno to wait from the generator
@@ -169,28 +172,32 @@ async def wait_conn_async(gen: PQGenConn[RV]) -> RV:
         ready = state
         ev.set()
 
+    timeout = timeout or None
     try:
         fileno, s = next(gen)
         while 1:
             ev.clear()
             if s == Wait.R:
                 loop.add_reader(fileno, wakeup, Ready.R)
-                await ev.wait()
+                await wait_for(ev.wait(), timeout)
                 loop.remove_reader(fileno)
             elif s == Wait.W:
                 loop.add_writer(fileno, wakeup, Ready.W)
-                await ev.wait()
+                await wait_for(ev.wait(), timeout)
                 loop.remove_writer(fileno)
             elif s == Wait.RW:
                 loop.add_reader(fileno, wakeup, Ready.R)
                 loop.add_writer(fileno, wakeup, Ready.W)
-                await ev.wait()
+                await wait_for(ev.wait(), timeout)
                 loop.remove_reader(fileno)
                 loop.remove_writer(fileno)
             else:
                 raise e.InternalError("bad poll status: %s")
             fileno, s = gen.send(ready)
 
+    except TimeoutError:
+        raise e.DatabaseError("timeout expired")
+
     except StopIteration as ex:
         rv: RV = ex.args[0] if ex.args else None
         return rv
index b97570d49b828e38fc45adbd56ecc6a0b2661a2b..ea8aced2180b3a181c6bcc3bdba6755e3eb8c582 100644 (file)
@@ -37,7 +37,6 @@ def test_connect_bad():
 
 
 @pytest.mark.slow
-@pytest.mark.xfail
 @pytest.mark.skipif(sys.platform == "win32", reason="connect() hangs on Win32")
 def test_connect_timeout():
     s = socket.socket(socket.AF_INET)
index e4406d1a42ed89998e2426040a73024172e417b6..e81aa42b0fb05ee18822cf618723ee6bd1c2f69d 100644 (file)
@@ -38,7 +38,6 @@ async def test_connect_str_subclass(dsn):
 
 
 @pytest.mark.slow
-@pytest.mark.xfail
 async def test_connect_timeout():
     s = socket.socket(socket.AF_INET)
     s.bind(("", 0))
index 94f936a69e9f3f066ee9e807e55817732094e5a9..eb2e8c65d6c32221a5c41049692fea65cdab11e7 100644 (file)
@@ -4,7 +4,12 @@ import pytest
 
 import psycopg3
 from psycopg3 import ProgrammingError
-from psycopg3.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg3.conninfo import (
+    _conninfo_connect_timeout,
+    make_conninfo,
+    conninfo_to_dict,
+    ConnectionInfo,
+)
 
 snowman = "\u2603"
 
@@ -89,6 +94,37 @@ def test_no_munging():
     assert dsnin == dsnout
 
 
+@pytest.mark.parametrize(
+    "dsn, kwargs, exp",
+    [
+        (
+            "",
+            {"host": "localhost", "connect_timeout": 1},
+            ({"host": "localhost", "connect_timeout": "1"}, 1),
+        ),
+        (
+            "dbname=postgres",
+            {},
+            ({"dbname": "postgres"}, None),
+        ),
+        (
+            "dbname=postgres connect_timeout=2",
+            {},
+            ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+        ),
+        (
+            "postgresql:///postgres?connect_timeout=2",
+            {"connect_timeout": 10},
+            ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+        ),
+    ],
+)
+def test__conninfo_connect_timeout(dsn, kwargs, exp):
+    conninfo, connect_timeout = _conninfo_connect_timeout(dsn, **kwargs)
+    assert conninfo_to_dict(conninfo) == exp[0]
+    assert connect_timeout == exp[1]
+
+
 class TestConnectionInfo:
     @pytest.mark.parametrize(
         "attr",
index c5f3b71daf5816b422727b87bc7202872ffeb193..2b40d3c0f5bb57897f712fec6cd4619b39b47c5a 100644 (file)
@@ -75,10 +75,11 @@ def test_wait_epoll_bad(pgconn):
     assert res.status == ExecStatus.TUPLES_OK
 
 
+@pytest.mark.parametrize("timeout", timeouts)
 @pytest.mark.asyncio
-async def test_wait_conn_async(dsn):
+async def test_wait_conn_async(dsn, timeout):
     gen = generators.connect(dsn)
-    conn = await waiting.wait_conn_async(gen)
+    conn = await waiting.wait_conn_async(gen, **timeout)
     assert conn.status == ConnStatus.OK