import warnings
import threading
from types import TracebackType
-from typing import Any, Callable, cast, Generic, Iterator, List
+from typing import Any, Callable, cast, Dict, Generic, Iterator, List
from typing import NamedTuple, Optional, Type, TypeVar, Union
from typing import overload, TYPE_CHECKING
from weakref import ref, ReferenceType
from ._enums import IsolationLevel
from .cursor import Cursor
from ._cmodule import _psycopg
-from .conninfo import _conninfo_connect_timeout, ConnectionInfo
+from .conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
from .generators import notifies
from ._preparing import PrepareManager
from .transaction import Transaction
"""
Connect to a database server and return a new `Connection` instance.
"""
- conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
+ params = cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
rv = cls._wait_conn(
cls._connect_gen(conninfo, autocommit=autocommit),
- timeout,
+ timeout=params["connect_timeout"],
)
if row_factory:
rv.row_factory = row_factory
if not getattr(self, "_pool", None):
self.close()
+ @classmethod
+ def _get_connection_params(
+ cls, conninfo: str, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Adjust connection parameters before conecting."""
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is an usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ # TODO: SRV lookup (RFC 2782)
+
+ return params
+
def close(self) -> None:
"""Close the database connection."""
if self.closed:
import logging
import warnings
from types import TracebackType
-from typing import Any, AsyncIterator, cast
-from typing import Optional, Type, Union
-from typing import overload, TYPE_CHECKING
+from typing import Any, AsyncIterator, Dict, Optional, Type, Union
+from typing import cast, overload, TYPE_CHECKING
from . import errors as e
from . import waiting
from .rows import Row, AsyncRowFactory, tuple_row, TupleRow
from ._enums import IsolationLevel
from .compat import asynccontextmanager
-from .conninfo import _conninfo_connect_timeout
+from .conninfo import make_conninfo, conninfo_to_dict
from .connection import BaseConnection, CursorRow, Notify
from .generators import notifies
from .transaction import AsyncTransaction
row_factory: Optional[AsyncRowFactory[Row]] = None,
**kwargs: Any,
) -> "AsyncConnection[Any]":
- conninfo, timeout = _conninfo_connect_timeout(conninfo, **kwargs)
+ params = await cls._get_connection_params(conninfo, **kwargs)
+ conninfo = make_conninfo(**params)
+
rv = await cls._wait_conn(
cls._connect_gen(conninfo, autocommit=autocommit),
- timeout,
+ timeout=params["connect_timeout"],
)
if row_factory:
rv.row_factory = row_factory
if not getattr(self, "_pool", None):
await self.close()
+ @classmethod
+ async def _get_connection_params(
+ cls, conninfo: str, **kwargs: Any
+ ) -> Dict[str, Any]:
+ """Adjust connection parameters before conecting."""
+ params = conninfo_to_dict(conninfo, **kwargs)
+
+ # Make sure there is an usable connect_timeout
+ if "connect_timeout" in params:
+ params["connect_timeout"] = int(params["connect_timeout"])
+ else:
+ params["connect_timeout"] = None
+
+ # TODO: resolve host names to hostaddr asynchronously
+ # TODO: SRV lookup (RFC 2782)
+
+ return params
+
async def close(self) -> None:
if self.closed:
return
# Copyright (C) 2020-2021 The Psycopg Team
import re
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional
from pathlib import Path
from datetime import tzinfo
return conninfo
-def conninfo_to_dict(conninfo: str) -> Dict[str, str]:
+def conninfo_to_dict(conninfo: str = "", **kwargs: Any) -> Dict[str, Any]:
"""
Convert the *conninfo* string into a dictionary of parameters.
Raise ProgrammingError if the string is not valid.
"""
opts = _parse_conninfo(conninfo)
- return {
+ rv = {
opt.keyword.decode("utf8"): opt.val.decode("utf8")
for opt in opts
if opt.val is not None
}
+ for k, v in kwargs.items():
+ if v is not None:
+ rv[k] = v
+ return rv
def _parse_conninfo(conninfo: str) -> List[pq.ConninfoOption]:
return s
-def _conninfo_connect_timeout(
- conninfo: str, **kwargs: Any
-) -> Tuple[str, Optional[int]]:
- """
- Build 'conninfo' by combining input value with kwargs and extract
- 'connect_timeout' parameter.
- """
- conninfo = make_conninfo(conninfo, **kwargs)
- connect_timeout: Optional[int]
- try:
- connect_timeout = int(conninfo_to_dict(conninfo)["connect_timeout"])
- except KeyError:
- connect_timeout = None
- return conninfo, connect_timeout
-
-
class ConnectionInfo:
"""Allow access to information about the connection."""
from psycopg import Connection, Notify
from psycopg.rows import tuple_row
from psycopg.errors import UndefinedTable
-from psycopg.conninfo import conninfo_to_dict
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
from .utils import gc_collect
from .test_cursor import my_row_factory
conn.deferrable = 0
assert conn.deferrable is False
+
+
+conninfo_params_timeout = [
+ (
+ "",
+ {"host": "localhost", "connect_timeout": None},
+ ({"host": "localhost"}, None),
+ ),
+ (
+ "",
+ {"host": "localhost", "connect_timeout": 1},
+ ({"host": "localhost", "connect_timeout": "1"}, 1),
+ ),
+ (
+ "dbname=postgres",
+ {},
+ ({"dbname": "postgres"}, None),
+ ),
+ (
+ "dbname=postgres connect_timeout=2",
+ {},
+ ({"dbname": "postgres", "connect_timeout": "2"}, 2),
+ ),
+ (
+ "postgresql:///postgres?connect_timeout=2",
+ {"connect_timeout": 10},
+ ({"dbname": "postgres", "connect_timeout": "10"}, 10),
+ ),
+]
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+def test_get_connection_params(dsn, kwargs, exp):
+ params = Connection._get_connection_params(dsn, **kwargs)
+ conninfo = make_conninfo(**params)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert params.get("connect_timeout") == exp[1]
from psycopg import AsyncConnection, Notify
from psycopg.rows import tuple_row
from psycopg.errors import UndefinedTable
-from psycopg.conninfo import conninfo_to_dict
+from psycopg.conninfo import conninfo_to_dict, make_conninfo
from .utils import gc_collect
from .test_cursor import my_row_factory
-from .test_connection import tx_params, tx_values_map
+from .test_connection import tx_params, tx_values_map, conninfo_params_timeout
pytestmark = pytest.mark.asyncio
await aconn.set_deferrable(0)
assert aconn.deferrable is False
+
+
+@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
+async def test_get_connection_params(dsn, kwargs, exp):
+ params = await AsyncConnection._get_connection_params(dsn, **kwargs)
+ conninfo = make_conninfo(**params)
+ assert conninfo_to_dict(conninfo) == exp[0]
+ assert params["connect_timeout"] == exp[1]
import psycopg
from psycopg import ProgrammingError
-from psycopg.conninfo import (
- _conninfo_connect_timeout,
- make_conninfo,
- conninfo_to_dict,
- ConnectionInfo,
-)
+from psycopg.conninfo import make_conninfo, conninfo_to_dict, ConnectionInfo
snowman = "\u2603"
assert dsnin == dsnout
-@pytest.mark.parametrize(
- "dsn, kwargs, exp",
- [
- (
- "",
- {"host": "localhost", "connect_timeout": 1},
- ({"host": "localhost", "connect_timeout": "1"}, 1),
- ),
- (
- "dbname=postgres",
- {},
- ({"dbname": "postgres"}, None),
- ),
- (
- "dbname=postgres connect_timeout=2",
- {},
- ({"dbname": "postgres", "connect_timeout": "2"}, 2),
- ),
- (
- "postgresql:///postgres?connect_timeout=2",
- {"connect_timeout": 10},
- ({"dbname": "postgres", "connect_timeout": "10"}, 10),
- ),
- ],
-)
-def test__conninfo_connect_timeout(dsn, kwargs, exp):
- conninfo, connect_timeout = _conninfo_connect_timeout(dsn, **kwargs)
- assert conninfo_to_dict(conninfo) == exp[0]
- assert connect_timeout == exp[1]
-
-
class TestConnectionInfo:
@pytest.mark.parametrize(
"attr",