]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added PGconn.parameter_status
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Mar 2020 11:56:31 +0000 (00:56 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 14 Mar 2020 11:56:31 +0000 (00:56 +1300)
psycopg3/_pq_ctypes.py
psycopg3/pq_ctypes.py
tests/conftest.py
tests/fix_db.py
tests/fix_tempenv.py [new file with mode: 0644]
tests/test_pq.py

index aeb7747951216bdedde922d425638ca7725aadbc..1feb246ac33ef6b5241471e87dc9e5b9dd8daa92 100644 (file)
@@ -137,6 +137,10 @@ PQtransactionStatus = pq.PQtransactionStatus
 PQtransactionStatus.argtypes = [PGconn_ptr]
 PQtransactionStatus.restype = c_int
 
+PQparameterStatus = pq.PQparameterStatus
+PQparameterStatus.argtypes = [PGconn_ptr, c_char_p]
+PQparameterStatus.restype = c_char_p
+
 PQerrorMessage = pq.PQerrorMessage
 PQerrorMessage.argtypes = [PGconn_ptr]
 PQerrorMessage.restype = c_char_p
index 75a3b69676b73e30051117ad1e0116ea3a2af689..88c873f42a0d79390a3fb2878b8cbb1818cc4466 100644 (file)
@@ -1,3 +1,4 @@
+#!/usr/bin/env python3
 """
 libpq Python wrapper using ctypes bindings.
 
@@ -168,6 +169,10 @@ class PGconn:
         rv = impl.PQtransactionStatus(self.pgconn_ptr)
         return TransactionStatus(rv)
 
+    def parameter_status(self, name):
+        rv = impl.PQparameterStatus(self.pgconn_ptr, self._encode(name))
+        return self._decode(rv)
+
     @property
     def error_message(self):
         return self._decode(impl.PQerrorMessage(self.pgconn_ptr))
@@ -177,8 +182,13 @@ class PGconn:
         return impl.PQsocket(self.pgconn_ptr)
 
     def _encode(self, s):
-        # TODO: encode in client encoding
-        return s.encode("utf8")
+        if isinstance(s, bytes):
+            return s
+        elif isinstance(s, str):
+            # TODO: encode in client encoding
+            return s.encode("utf8")
+        else:
+            raise TypeError("expected bytes or str, got %r instead" % s)
 
     def _decode(self, b):
         if b is None:
index 16040b121e883a8c55381f13ded86ddb71018b10..bf263193beb2d1e2450d2e7db78c1a0cfb43e30f 100644 (file)
@@ -1 +1 @@
-pytest_plugins = ("tests.fix_db",)
+pytest_plugins = ("tests.fix_db", "tests.fix_tempenv")
index 6bfbdbf41d2349b6d93ccbd26eff419f7b9c462b..205877a030acb561684d3cc3c4ba3a37541811ef 100644 (file)
@@ -31,4 +31,5 @@ def dsn(request):
 
 @pytest.fixture
 def pgconn(pq, dsn):
+    """Return a PGconn connection open to `--test-dsn`."""
     return pq.PGconn.connect(dsn)
diff --git a/tests/fix_tempenv.py b/tests/fix_tempenv.py
new file mode 100644 (file)
index 0000000..2f1e089
--- /dev/null
@@ -0,0 +1,37 @@
+import os
+import pytest
+
+
+class TempEnv:
+    def __init__(self):
+        self._prev = {}
+
+    def get(self, item, default):
+        return os.environ.get(self, item, default)
+
+    def __getitem__(self, item):
+        return os.environ[item]
+
+    def __setitem__(self, item, value):
+        self._prev.setdefault(item, os.environ.get(item))
+        os.environ[item] = value
+
+    def __delitem__(self, item):
+        self._prev.setdefault(item, os.environ.get(item))
+        del os.environ[item]
+
+    def restore(self):
+        for k, v in self._prev.items():
+            if v is not None:
+                os.environ[k] = v
+            else:
+                if k in os.environ:
+                    del os.environ[k]
+
+
+@pytest.fixture
+def tempenv():
+    """Allow to change the env vars temporarily."""
+    env = TempEnv()
+    yield env
+    env.restore()
index c6a03ed7206932844146623e34cc07965f00e385..5f1c22a8ee1f14fe21ab7c2474cd023821976904 100644 (file)
@@ -1,4 +1,3 @@
-import os
 from select import select
 
 import pytest
@@ -61,17 +60,9 @@ def test_connect_async_bad(pq, dsn):
     assert conn.status == ConnStatus.CONNECTION_BAD
 
 
-def test_defaults(pq):
-    oldport = os.environ.get("PGPORT")
-    try:
-        os.environ["PGPORT"] = "15432"
-        defs = pq.PGconn.get_defaults()
-    finally:
-        if oldport is not None:
-            os.environ["PGPORT"] = oldport
-        else:
-            del os.environ["PGPORT"]
-
+def test_defaults(pq, tempenv):
+    tempenv["PGPORT"] = "15432"
+    defs = pq.PGconn.get_defaults()
     assert len(defs) > 20
     port = [d for d in defs if d.keyword == "port"][0]
     assert port.envvar == "PGPORT"
@@ -181,3 +172,10 @@ def test_transaction_status(pq, pgconn):
     # TODO: test other states
     pgconn.finish()
     assert pgconn.transaction_status == pq.TransactionStatus.PQTRANS_UNKNOWN
+
+
+def test_parameter_status(pq, dsn, tempenv):
+    tempenv["PGAPPNAME"] = "psycopg3 tests"
+    pgconn = pq.PGconn.connect(dsn)
+    assert pgconn.parameter_status('application_name') == "psycopg3 tests"
+    assert pgconn.parameter_status('wat') is None