]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fixed connect parameters
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Apr 2020 06:37:13 +0000 (18:37 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 14 Apr 2020 08:26:13 +0000 (20:26 +1200)
Raise an error with no params, accept keywords only.

psycopg3/connection.py
tests/test_async_connection.py
tests/test_connection.py
tests/test_psycopg3_dbapi20.py

index bcf606450a6fb6ba0257348331c91c63e880608f..09791ac3a112682271d5aad5af64dfe75ce9a3c0 100644 (file)
@@ -124,11 +124,11 @@ class Connection(BaseConnection):
 
     @classmethod
     def connect(
-        cls, conninfo: str, connection_factory: Any = None, **kwargs: Any
+        cls, conninfo: Optional[str] = None, **kwargs: Any,
     ) -> "Connection":
-        if connection_factory is not None:
-            raise NotImplementedError()
-        conninfo = make_conninfo(conninfo, **kwargs)
+        if conninfo is None and not kwargs:
+            raise TypeError("missing conninfo and not parameters specified")
+        conninfo = make_conninfo(conninfo or "", **kwargs)
         gen = generators.connect(conninfo)
         pgconn = cls.wait(gen)
         return cls(pgconn)
@@ -194,8 +194,12 @@ class AsyncConnection(BaseConnection):
         self.cursor_factory = cursor.AsyncCursor
 
     @classmethod
-    async def connect(cls, conninfo: str = "", **kwargs: Any) -> "AsyncConnection":
-        conninfo = make_conninfo(conninfo, **kwargs)
+    async def connect(
+        cls, conninfo: Optional[str] = None, **kwargs: Any
+    ) -> "AsyncConnection":
+        if conninfo is None and not kwargs:
+            raise TypeError("missing conninfo and not parameters specified")
+        conninfo = make_conninfo(conninfo or "", **kwargs)
         gen = generators.connect(conninfo)
         pgconn = await cls.wait(gen)
         return cls(pgconn)
index 4f7dcdad8527a2fa143ca0053a50bce4ca299075..81fc057150a43d6f1280428878a9843a79b52dfe 100644 (file)
@@ -2,6 +2,7 @@ import pytest
 
 import psycopg3
 from psycopg3 import AsyncConnection
+from psycopg3.conninfo import conninfo_to_dict
 
 
 def test_connect(dsn, loop):
@@ -77,3 +78,48 @@ def test_set_encoding(aconn, loop):
 def test_set_encoding_bad(aconn, loop):
     with pytest.raises(psycopg3.DatabaseError):
         loop.run_until_complete(aconn.set_client_encoding("WAT"))
+
+
+@pytest.mark.parametrize(
+    "testdsn, kwargs, want",
+    [
+        ("", {}, ""),
+        ("host=foo user=bar", {}, "host=foo user=bar"),
+        ("host=foo", {"user": "baz"}, "host=foo user=baz"),
+        (
+            "host=foo port=5432",
+            {"host": "qux", "user": "joe"},
+            "host=qux user=joe port=5432",
+        ),
+        ("host=foo", {"user": None}, "host=foo"),
+    ],
+)
+def test_connect_args(monkeypatch, pgconn, loop, testdsn, kwargs, want):
+    the_conninfo = None
+
+    def fake_connect(conninfo):
+        nonlocal the_conninfo
+        the_conninfo = conninfo
+        return pgconn
+        yield
+
+    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    loop.run_until_complete(
+        psycopg3.AsyncConnection.connect(testdsn, **kwargs)
+    )
+    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+
+
+@pytest.mark.parametrize(
+    "args, kwargs", [((), {}), (("", ""), {}), ((), {"nosuchparam": 42})],
+)
+def test_connect_badargs(monkeypatch, pgconn, loop, args, kwargs):
+    def fake_connect(conninfo):
+        return pgconn
+        yield
+
+    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    with pytest.raises((TypeError, psycopg3.ProgrammingError)):
+        loop.run_until_complete(
+            psycopg3.AsyncConnection.connect(*args, **kwargs)
+        )
index 562dcfab77b2852a5a19e7e242afa7671e74a3bf..774a979bd243cc9b8c718571cb94632acf57c645 100644 (file)
@@ -2,6 +2,7 @@ import pytest
 
 import psycopg3
 from psycopg3 import Connection
+from psycopg3.conninfo import conninfo_to_dict
 
 
 def test_connect(dsn):
@@ -79,3 +80,44 @@ def test_set_encoding_unsupported(conn):
 def test_set_encoding_bad(conn):
     with pytest.raises(psycopg3.DatabaseError):
         conn.set_client_encoding("WAT")
+
+
+@pytest.mark.parametrize(
+    "testdsn, kwargs, want",
+    [
+        ("", {}, ""),
+        ("host=foo user=bar", {}, "host=foo user=bar"),
+        ("host=foo", {"user": "baz"}, "host=foo user=baz"),
+        (
+            "host=foo port=5432",
+            {"host": "qux", "user": "joe"},
+            "host=qux user=joe port=5432",
+        ),
+        ("host=foo", {"user": None}, "host=foo"),
+    ],
+)
+def test_connect_args(monkeypatch, pgconn, testdsn, kwargs, want):
+    the_conninfo = None
+
+    def fake_connect(conninfo):
+        nonlocal the_conninfo
+        the_conninfo = conninfo
+        return pgconn
+        yield
+
+    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    psycopg3.Connection.connect(testdsn, **kwargs)
+    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+
+
+@pytest.mark.parametrize(
+    "args, kwargs", [((), {}), (("", ""), {}), ((), {"nosuchparam": 42})],
+)
+def test_connect_badargs(monkeypatch, pgconn, args, kwargs):
+    def fake_connect(conninfo):
+        return pgconn
+        yield
+
+    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    with pytest.raises((TypeError, psycopg3.ProgrammingError)):
+        psycopg3.Connection.connect(*args, **kwargs)
index fffd5379dde3dd68d31705ab802d5ae6c684bf0f..c2e6727c8b8d1cd6e03fa7fa0703be987fb19dcb 100644 (file)
@@ -2,6 +2,7 @@ import pytest
 import datetime as dt
 
 import psycopg3
+from psycopg3.conninfo import conninfo_to_dict
 
 from . import dbapi20
 
@@ -101,3 +102,44 @@ def test_time_from_ticks(ticks, want):
     s = psycopg3.TimeFromTicks(ticks)
     want = dt.datetime.strptime(want, "%H:%M:%S.%f").time()
     assert s.replace(hour=0) == want
+
+
+@pytest.mark.parametrize(
+    "testdsn, kwargs, want",
+    [
+        ("", {}, ""),
+        ("host=foo user=bar", {}, "host=foo user=bar"),
+        ("host=foo", {"user": "baz"}, "host=foo user=baz"),
+        (
+            "host=foo port=5432",
+            {"host": "qux", "user": "joe"},
+            "host=qux user=joe port=5432",
+        ),
+        ("host=foo", {"user": None}, "host=foo"),
+    ],
+)
+def test_connect_args(monkeypatch, pgconn, testdsn, kwargs, want):
+    the_conninfo = None
+
+    def fake_connect(conninfo):
+        nonlocal the_conninfo
+        the_conninfo = conninfo
+        return pgconn
+        yield
+
+    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    psycopg3.connect(testdsn, **kwargs)
+    assert conninfo_to_dict(the_conninfo) == conninfo_to_dict(want)
+
+
+@pytest.mark.parametrize(
+    "args, kwargs", [((), {}), (("", ""), {}), ((), {"nosuchparam": 42})],
+)
+def test_connect_badargs(monkeypatch, pgconn, args, kwargs):
+    def fake_connect(conninfo):
+        return pgconn
+        yield
+
+    monkeypatch.setattr(psycopg3.generators, "connect", fake_connect)
+    with pytest.raises((TypeError, psycopg3.ProgrammingError)):
+        psycopg3.connect(*args, **kwargs)