From bc7b72baf67e460ad05ea1ca8863daddb0c48ce8 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Tue, 14 Apr 2020 18:37:13 +1200 Subject: [PATCH] Fixed connect parameters Raise an error with no params, accept keywords only. --- psycopg3/connection.py | 16 +++++++----- tests/test_async_connection.py | 46 ++++++++++++++++++++++++++++++++++ tests/test_connection.py | 42 +++++++++++++++++++++++++++++++ tests/test_psycopg3_dbapi20.py | 42 +++++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 6 deletions(-) diff --git a/psycopg3/connection.py b/psycopg3/connection.py index bcf606450..09791ac3a 100644 --- a/psycopg3/connection.py +++ b/psycopg3/connection.py @@ -124,11 +124,11 @@ class Connection(BaseConnection): @classmethod def connect( - cls, conninfo: str, connection_factory: Any = None, **kwargs: Any + cls, conninfo: Optional[str] = None, **kwargs: Any, ) -> "Connection": - if connection_factory is not None: - raise NotImplementedError() - conninfo = make_conninfo(conninfo, **kwargs) + if conninfo is None and not kwargs: + raise TypeError("missing conninfo and not parameters specified") + conninfo = make_conninfo(conninfo or "", **kwargs) gen = generators.connect(conninfo) pgconn = cls.wait(gen) return cls(pgconn) @@ -194,8 +194,12 @@ class AsyncConnection(BaseConnection): self.cursor_factory = cursor.AsyncCursor @classmethod - async def connect(cls, conninfo: str = "", **kwargs: Any) -> "AsyncConnection": - conninfo = make_conninfo(conninfo, **kwargs) + async def connect( + cls, conninfo: Optional[str] = None, **kwargs: Any + ) -> "AsyncConnection": + if conninfo is None and not kwargs: + raise TypeError("missing conninfo and not parameters specified") + conninfo = make_conninfo(conninfo or "", **kwargs) gen = generators.connect(conninfo) pgconn = await cls.wait(gen) return cls(pgconn) diff --git a/tests/test_async_connection.py b/tests/test_async_connection.py index 4f7dcdad8..81fc05715 100644 --- a/tests/test_async_connection.py +++ b/tests/test_async_connection.py @@ -2,6 +2,7 @@ import pytest import psycopg3 from psycopg3 import AsyncConnection +from psycopg3.conninfo import conninfo_to_dict def test_connect(dsn, loop): @@ -77,3 +78,48 @@ def test_set_encoding(aconn, loop): def test_set_encoding_bad(aconn, loop): with pytest.raises(psycopg3.DatabaseError): loop.run_until_complete(aconn.set_client_encoding("WAT")) + + +@pytest.mark.parametrize( + "testdsn, kwargs, want", + [ + ("", {}, ""), + ("host=foo user=bar", {}, "host=foo user=bar"), + ("host=foo", {"user": "baz"}, "host=foo user=baz"), + ( + "host=foo port=5432", + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + ("host=foo", {"user": None}, "host=foo"), + ], +) +def test_connect_args(monkeypatch, pgconn, loop, testdsn, kwargs, want): + the_conninfo = None + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg3.generators, "connect", fake_connect) + loop.run_until_complete( + psycopg3.AsyncConnection.connect(testdsn, **kwargs) + ) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + + +@pytest.mark.parametrize( + "args, kwargs", [((), {}), (("", ""), {}), ((), {"nosuchparam": 42})], +) +def test_connect_badargs(monkeypatch, pgconn, loop, args, kwargs): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg3.generators, "connect", fake_connect) + with pytest.raises((TypeError, psycopg3.ProgrammingError)): + loop.run_until_complete( + psycopg3.AsyncConnection.connect(*args, **kwargs) + ) diff --git a/tests/test_connection.py b/tests/test_connection.py index 562dcfab7..774a979bd 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -2,6 +2,7 @@ import pytest import psycopg3 from psycopg3 import Connection +from psycopg3.conninfo import conninfo_to_dict def test_connect(dsn): @@ -79,3 +80,44 @@ def test_set_encoding_unsupported(conn): def test_set_encoding_bad(conn): with pytest.raises(psycopg3.DatabaseError): conn.set_client_encoding("WAT") + + +@pytest.mark.parametrize( + "testdsn, kwargs, want", + [ + ("", {}, ""), + ("host=foo user=bar", {}, "host=foo user=bar"), + ("host=foo", {"user": "baz"}, "host=foo user=baz"), + ( + "host=foo port=5432", + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + ("host=foo", {"user": None}, "host=foo"), + ], +) +def test_connect_args(monkeypatch, pgconn, testdsn, kwargs, want): + the_conninfo = None + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg3.generators, "connect", fake_connect) + psycopg3.Connection.connect(testdsn, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + + +@pytest.mark.parametrize( + "args, kwargs", [((), {}), (("", ""), {}), ((), {"nosuchparam": 42})], +) +def test_connect_badargs(monkeypatch, pgconn, args, kwargs): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg3.generators, "connect", fake_connect) + with pytest.raises((TypeError, psycopg3.ProgrammingError)): + psycopg3.Connection.connect(*args, **kwargs) diff --git a/tests/test_psycopg3_dbapi20.py b/tests/test_psycopg3_dbapi20.py index fffd5379d..c2e6727c8 100644 --- a/tests/test_psycopg3_dbapi20.py +++ b/tests/test_psycopg3_dbapi20.py @@ -2,6 +2,7 @@ import pytest import datetime as dt import psycopg3 +from psycopg3.conninfo import conninfo_to_dict from . import dbapi20 @@ -101,3 +102,44 @@ def test_time_from_ticks(ticks, want): s = psycopg3.TimeFromTicks(ticks) want = dt.datetime.strptime(want, "%H:%M:%S.%f").time() assert s.replace(hour=0) == want + + +@pytest.mark.parametrize( + "testdsn, kwargs, want", + [ + ("", {}, ""), + ("host=foo user=bar", {}, "host=foo user=bar"), + ("host=foo", {"user": "baz"}, "host=foo user=baz"), + ( + "host=foo port=5432", + {"host": "qux", "user": "joe"}, + "host=qux user=joe port=5432", + ), + ("host=foo", {"user": None}, "host=foo"), + ], +) +def test_connect_args(monkeypatch, pgconn, testdsn, kwargs, want): + the_conninfo = None + + def fake_connect(conninfo): + nonlocal the_conninfo + the_conninfo = conninfo + return pgconn + yield + + monkeypatch.setattr(psycopg3.generators, "connect", fake_connect) + psycopg3.connect(testdsn, **kwargs) + assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want) + + +@pytest.mark.parametrize( + "args, kwargs", [((), {}), (("", ""), {}), ((), {"nosuchparam": 42})], +) +def test_connect_badargs(monkeypatch, pgconn, args, kwargs): + def fake_connect(conninfo): + return pgconn + yield + + monkeypatch.setattr(psycopg3.generators, "connect", fake_connect) + with pytest.raises((TypeError, psycopg3.ProgrammingError)): + psycopg3.connect(*args, **kwargs) -- 2.47.2