]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Parsing conninfo and returning defaults moved to its own class
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Mar 2020 12:54:03 +0000 (01:54 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Mar 2020 12:57:14 +0000 (01:57 +1300)
psycopg3/pq.py
psycopg3/pq_ctypes.py
tests/test_pq_conninfo.py [new file with mode: 0644]
tests/test_pq_pgconn.py [moved from tests/test_pq.py with 83% similarity]

index f0a5dc74753ef56094da246ff0facaf379ee829c..87daa745b7f2634a96bbdd3ebc2fc4d4e57d2f0b 100644 (file)
@@ -20,6 +20,7 @@ from . import pq_ctypes as pq_module
 
 PGconn = pq_module.PGconn
 PQerror = pq_module.PQerror
+Conninfo = pq_module.Conninfo
 
 __all__ = (
     "ConnStatus",
@@ -27,5 +28,6 @@ __all__ = (
     "TransactionStatus",
     "Ping",
     "PGconn",
+    "Conninfo",
     "PQerror",
 )
index fba097c81c752159127e297b00f0e45c955e03bc..5469a3383850975cb576436ed63b41df0a60286e 100644 (file)
@@ -39,7 +39,7 @@ class PGconn:
         if isinstance(conninfo, str):
             conninfo = conninfo.encode("utf8")
         if not isinstance(conninfo, bytes):
-            raise TypeError("bytes expected, got %r instead" % conninfo)
+            raise TypeError(f"bytes expected, got {conninfo!r} instead")
 
         pgconn_ptr = impl.PQconnectdb(conninfo)
         return cls(pgconn_ptr)
@@ -49,7 +49,7 @@ class PGconn:
         if isinstance(conninfo, str):
             conninfo = conninfo.encode("utf8")
         if not isinstance(conninfo, bytes):
-            raise TypeError("bytes expected, got %r instead" % conninfo)
+            raise TypeError(f"bytes expected, got {conninfo!r} instead")
 
         pgconn_ptr = impl.PQconnectStart(conninfo)
         return cls(pgconn_ptr)
@@ -63,48 +63,16 @@ class PGconn:
         if p is not None:
             impl.PQfinish(p)
 
-    @classmethod
-    def get_defaults(cls):
-        opts = impl.PQconndefaults()
-        if not opts:
-            raise MemoryError("couldn't allocate connection defaults")
-        try:
-            return _conninfoopts_from_array(opts)
-        finally:
-            impl.PQconninfoFree(opts)
-
     @property
     def info(self):
         opts = impl.PQconninfo(self.pgconn_ptr)
         if not opts:
             raise MemoryError("couldn't allocate connection info")
         try:
-            return _conninfoopts_from_array(opts)
+            return Conninfo._options_from_array(opts)
         finally:
             impl.PQconninfoFree(opts)
 
-    @classmethod
-    def parse_conninfo(cls, conninfo):
-        if isinstance(conninfo, str):
-            conninfo = conninfo.encode("utf8")
-        if not isinstance(conninfo, bytes):
-            raise TypeError("bytes expected, got %r instead" % conninfo)
-
-        errmsg = c_char_p()
-        rv = impl.PQconninfoParse(conninfo, pointer(errmsg))
-        if not rv:
-            if not errmsg:
-                raise MemoryError("couldn't allocate on conninfo parse")
-            else:
-                exc = PQerror(errmsg.value.decode("utf8", "replace"))
-                impl.PQfreemem(errmsg)
-                raise exc
-
-        try:
-            return _conninfoopts_from_array(rv)
-        finally:
-            impl.PQconninfoFree(rv)
-
     def reset(self):
         impl.PQreset(self.pgconn_ptr)
 
@@ -122,7 +90,7 @@ class PGconn:
         if isinstance(conninfo, str):
             conninfo = conninfo.encode("utf8")
         if not isinstance(conninfo, bytes):
-            raise TypeError("bytes expected, got %r instead" % conninfo)
+            raise TypeError(f"bytes expected, got {conninfo!r} instead")
 
         rv = impl.PQping(conninfo)
         return Ping(rv)
@@ -212,7 +180,7 @@ class PGconn:
             # TODO: encode in client encoding
             return s.encode("utf8")
         else:
-            raise TypeError("expected bytes or str, got %r instead" % s)
+            raise TypeError(f"expected bytes or str, got {s!r} instead")
 
     def _decode(self, b):
         if b is None:
@@ -226,20 +194,54 @@ ConninfoOption = namedtuple(
 )
 
 
-def _conninfoopts_from_array(opts):
-    def gets(opt, kw):
-        rv = getattr(opt, kw)
-        if rv is not None:
-            rv = rv.decode("utf8", "replace")
-        return rv
+class Conninfo:
+    @classmethod
+    def get_defaults(cls):
+        opts = impl.PQconndefaults()
+        if not opts:
+            raise MemoryError("couldn't allocate connection defaults")
+        try:
+            return cls._options_from_array(opts)
+        finally:
+            impl.PQconninfoFree(opts)
 
-    rv = []
-    skws = "keyword envvar compiled val label dispatcher".split()
-    for opt in opts:
-        if not opt.keyword:
-            break
-        d = {kw: gets(opt, kw) for kw in skws}
-        d["dispsize"] = opt.dispsize
-        rv.append(ConninfoOption(**d))
+    @classmethod
+    def parse(cls, conninfo):
+        if isinstance(conninfo, str):
+            conninfo = conninfo.encode("utf8")
+        if not isinstance(conninfo, bytes):
+            raise TypeError(f"bytes expected, got {conninfo!r} instead")
 
-    return rv
+        errmsg = c_char_p()
+        rv = impl.PQconninfoParse(conninfo, pointer(errmsg))
+        if not rv:
+            if not errmsg:
+                raise MemoryError("couldn't allocate on conninfo parse")
+            else:
+                exc = PQerror(errmsg.value.decode("utf8", "replace"))
+                impl.PQfreemem(errmsg)
+                raise exc
+
+        try:
+            return cls._options_from_array(rv)
+        finally:
+            impl.PQconninfoFree(rv)
+
+    @classmethod
+    def _options_from_array(cls, opts):
+        def gets(opt, kw):
+            rv = getattr(opt, kw)
+            if rv is not None:
+                rv = rv.decode("utf8", "replace")
+            return rv
+
+        rv = []
+        skws = "keyword envvar compiled val label dispatcher".split()
+        for opt in opts:
+            if not opt.keyword:
+                break
+            d = {kw: gets(opt, kw) for kw in skws}
+            d["dispsize"] = opt.dispsize
+            rv.append(ConninfoOption(**d))
+
+        return rv
diff --git a/tests/test_pq_conninfo.py b/tests/test_pq_conninfo.py
new file mode 100644 (file)
index 0000000..b29dc48
--- /dev/null
@@ -0,0 +1,32 @@
+import pytest
+
+
+def test_defaults(pq, tempenv):
+    tempenv["PGPORT"] = "15432"
+    defs = pq.Conninfo.get_defaults()
+    assert len(defs) > 20
+    port = [d for d in defs if d.keyword == "port"][0]
+    assert port.envvar == "PGPORT"
+    assert port.compiled == "5432"
+    assert port.val == "15432"
+    assert port.label == "Database-Port"
+    assert port.dispatcher == ""
+    assert port.dispsize == 6
+
+
+def test_conninfo_parse(pq):
+    info = pq.Conninfo.parse(
+        "postgresql://host1:123,host2:456/somedb"
+        "?target_session_attrs=any&application_name=myapp"
+    )
+    info = {i.keyword: i.val for i in info if i.val is not None}
+    assert info["host"] == "host1,host2"
+    assert info["port"] == "123,456"
+    assert info["dbname"] == "somedb"
+    assert info["application_name"] == "myapp"
+
+
+def test_conninfo_parse_bad(pq):
+    with pytest.raises(pq.PQerror) as e:
+        pq.Conninfo.parse("bad_conninfo=")
+        assert "bad_conninfo" in str(e.value)
similarity index 83%
rename from tests/test_pq.py
rename to tests/test_pq_pgconn.py
index e77f1d5ae3ce8fbea8cde5e6166dd8f3dce2144c..53a1f5176d11222345ffbcfd6d67f1fd4c34299f 100644 (file)
@@ -60,20 +60,7 @@ def test_connect_async_bad(pq, dsn):
     assert conn.status == ConnStatus.CONNECTION_BAD
 
 
-def test_defaults(pq, tempenv):
-    tempenv["PGPORT"] = "15432"
-    defs = pq.PGconn.get_defaults()
-    assert len(defs) > 20
-    port = [d for d in defs if d.keyword == "port"][0]
-    assert port.envvar == "PGPORT"
-    assert port.compiled == "5432"
-    assert port.val == "15432"
-    assert port.label == "Database-Port"
-    assert port.dispatcher == ""
-    assert port.dispsize == 6
-
-
-def test_info(dsn, pgconn):
+def test_info(pq, dsn, pgconn):
     info = pgconn.info
     assert len(info) > 20
     dbname = [d for d in info if d.keyword == "dbname"][0]
@@ -82,29 +69,11 @@ def test_info(dsn, pgconn):
     assert dbname.dispatcher == ""
     assert dbname.dispsize == 20
 
-    parsed = pgconn.parse_conninfo(dsn)
+    parsed = pq.Conninfo.parse(dsn)
     name = [o.val for o in parsed if o.keyword == "dbname"][0]
     assert dbname.val == name
 
 
-def test_conninfo_parse(pq):
-    info = pq.PGconn.parse_conninfo(
-        "postgresql://host1:123,host2:456/somedb"
-        "?target_session_attrs=any&application_name=myapp"
-    )
-    info = {i.keyword: i.val for i in info if i.val is not None}
-    assert info["host"] == "host1,host2"
-    assert info["port"] == "123,456"
-    assert info["dbname"] == "somedb"
-    assert info["application_name"] == "myapp"
-
-
-def test_conninfo_parse_bad(pq):
-    with pytest.raises(pq.PQerror) as e:
-        pq.PGconn.parse_conninfo("bad_conninfo=")
-        assert "bad_conninfo" in str(e.value)
-
-
 def test_reset(pgconn):
     assert pgconn.status == ConnStatus.CONNECTION_OK
     # TODO: break it
@@ -198,14 +167,14 @@ def test_needs_password(pgconn):
     assert pgconn.needs_password is False
 
 
-def test_used_password(pgconn, tempenv, dsn):
+def test_used_password(pq, pgconn, tempenv, dsn):
     assert isinstance(pgconn.used_password, bool)
 
     # Assume that if a password was passed then it was needed.
     # Note that the server may still need a password passed via pgpass
     # so it may be that has_password is false but still a password was
     # requested by the server and passed by libpq.
-    info = pgconn.parse_conninfo(dsn)
+    info = pq.Conninfo.parse(dsn)
     has_password = (
         "PGPASSWORD" in tempenv
         or [i for i in info if i.keyword == "password"][0].val is not None