from .proto import Query, RV
from .compat import asynccontextmanager
from .cursor import Cursor, AsyncCursor
-from .conninfo import make_conninfo, ConnectionInfo
+from .conninfo import _conninfo_connect_timeout, ConnectionInfo
from .generators import notifies
from ._preparing import PrepareManager
from .transaction import Transaction, AsyncTransaction
) -> "Connection[Any]":
"""
Connect to a database server and return a new `Connection` instance.
-
- TODO: connection_timeout to be implemented.
"""
- conninfo = make_conninfo(conninfo, **kwargs)
+ conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
return cls._wait_conn(
cls._connect_gen(
conninfo, autocommit=autocommit, row_factory=row_factory
- )
+ ),
+ timeout,
)
def __enter__(self) -> "Connection[Row]":
return waiting.wait(gen, self.pgconn.socket, timeout=timeout)
@classmethod
- def _wait_conn(
- cls, gen: PQGenConn[RV], timeout: Optional[float] = 0.1
- ) -> RV:
+ def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
"""Consume a connection generator."""
return waiting.wait_conn(gen, timeout=timeout)
row_factory: Optional[RowFactory[Row]] = None,
**kwargs: Any,
) -> "AsyncConnection[Any]":
- conninfo = make_conninfo(conninfo, **kwargs)
+ conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
return await cls._wait_conn(
cls._connect_gen(
conninfo, autocommit=autocommit, row_factory=row_factory
- )
+ ),
+ timeout,
)
async def __aenter__(self) -> "AsyncConnection[Row]":
return await waiting.wait_async(gen, self.pgconn.socket)
@classmethod
- async def _wait_conn(cls, gen: PQGenConn[RV]) -> RV:
- return await waiting.wait_conn_async(gen)
+ async def _wait_conn(
+ cls, gen: PQGenConn[RV], timeout: Optional[int]
+ ) -> RV:
+ return await waiting.wait_conn_async(gen, timeout)
def _set_client_encoding(self, name: str) -> None:
raise AttributeError(
# Copyright (C) 2020-2021 The Psycopg Team
import re
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Tuple
from pathlib import Path
from datetime import tzinfo
return s
+def _conninfo_connect_timeout(
+ conninfo: str, **kwargs: Any
+) -> Tuple[str, Optional[int]]:
+ """
+ Build 'conninfo' by combining input value with kwargs and extract
+ 'connect_timeout' parameter.
+ """
+ conninfo = make_conninfo(conninfo, **kwargs)
+ connect_timeout: Optional[int]
+ try:
+ connect_timeout = int(conninfo_to_dict(conninfo)["connect_timeout"])
+ except KeyError:
+ connect_timeout = None
+ return conninfo, connect_timeout
+
+
class ConnectionInfo:
"""Allow access to information about the connection."""
import selectors
from enum import IntEnum
from typing import Optional
-from asyncio import get_event_loop, Event
+from asyncio import get_event_loop, wait_for, Event, TimeoutError
from selectors import DefaultSelector, EVENT_READ, EVENT_WRITE
from . import errors as e
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
- to allow Ctrl-C.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
:type timeout: float
:return: whatever *gen* returns on completion.
Behave like in `wait()`, but take the fileno to wait from the generator
itself, which might change during processing.
"""
+ timeout = timeout or None
try:
fileno, s = next(gen)
sel = DefaultSelector()
while 1:
sel.register(fileno, s)
- ready = None
- while not ready:
- ready = sel.select(timeout=timeout)
+ ready = sel.select(timeout=timeout)
sel.unregister(fileno)
- fileno, s = gen.send(ready[0][1])
+ if not ready:
+ raise e.DatabaseError("timeout expired")
+ fileno, s = gen.send(ready[0][1]) # type: ignore[arg-type]
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
-async def wait_conn_async(gen: PQGenConn[RV]) -> RV:
+async def wait_conn_async(
+ gen: PQGenConn[RV], timeout: Optional[float] = None
+) -> RV:
"""
Coroutine waiting for a connection generator to complete.
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
- to allow Ctrl-C.
+ to allow Ctrl-C. If zero or None, wait indefinitely.
:return: whatever *gen* returns on completion.
Behave like in `wait()`, but take the fileno to wait from the generator
ready = state
ev.set()
+ timeout = timeout or None
try:
fileno, s = next(gen)
while 1:
ev.clear()
if s == Wait.R:
loop.add_reader(fileno, wakeup, Ready.R)
- await ev.wait()
+ await wait_for(ev.wait(), timeout)
loop.remove_reader(fileno)
elif s == Wait.W:
loop.add_writer(fileno, wakeup, Ready.W)
- await ev.wait()
+ await wait_for(ev.wait(), timeout)
loop.remove_writer(fileno)
elif s == Wait.RW:
loop.add_reader(fileno, wakeup, Ready.R)
loop.add_writer(fileno, wakeup, Ready.W)
- await ev.wait()
+ await wait_for(ev.wait(), timeout)
loop.remove_reader(fileno)
loop.remove_writer(fileno)
else:
raise e.InternalError("bad poll status: %s")
fileno, s = gen.send(ready)
+ except TimeoutError:
+ raise e.DatabaseError("timeout expired")
+
except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
@pytest.mark.slow
-@pytest.mark.xfail
@pytest.mark.skipif(sys.platform == "win32", reason="connect() hangs on Win32")
def test_connect_timeout():
s = socket.socket(socket.AF_INET)
@pytest.mark.slow
-@pytest.mark.xfail
async def test_connect_timeout():
s = socket.socket(socket.AF_INET)
s.bind(("", 0))
import psycopg3
from psycopg3 import ProgrammingError
-from psycopg3.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
+from psycopg3.conninfo import (
+ _conninfo_connect_timeout,
+ make_conninfo,
+ conninfo_to_dict,
+ ConnectionInfo,
+)
snowman = "\u2603"
assert dsnin == dsnout
+@pytest.mark.parametrize(
+ "dsn, kwargs, exp",
+ [
+ (
+ "",
+ {"host": "localhost", "connect_timeout": 1},
+ ({"host": "localhost", "connect_timeout": "1"}, 1),
+ ),
+ (
+ "dbname=postgres",
+ {},
+ ({"dbname": "postgres"}, None),
+ ),
+ (
+ "dbname=postgres connect_timeout=2",
+ {},
+ ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+ ),
+ (
+ "postgresql:///postgres?connect_timeout=2",
+ {"connect_timeout": 10},
+ ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+ ),
+ ],
+)
+def test__conninfo_connect_timeout(dsn, kwargs, exp):
+ conninfo, connect_timeout = _conninfo_connect_timeout(dsn, **kwargs)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert connect_timeout == exp[1]
+
+
class TestConnectionInfo:
@pytest.mark.parametrize(
"attr",
assert res.status == ExecStatus.TUPLES_OK
+@pytest.mark.parametrize("timeout", timeouts)
@pytest.mark.asyncio
-async def test_wait_conn_async(dsn):
+async def test_wait_conn_async(dsn, timeout):
gen = generators.connect(dsn)
- conn = await waiting.wait_conn_async(gen)
+ conn = await waiting.wait_conn_async(gen, **timeout)
assert conn.status == ConnStatus.OK