]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Ensure a valid connection with escaping functions
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 08:23:28 +0000 (20:23 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 11 Apr 2020 08:23:28 +0000 (20:23 +1200)
psycopg3/pq/pq_ctypes.py
tests/pq/test_escaping.py

index 7cb91c0035403c0ce57bf322b82a8c08647b57b6..b40e8f31cb63ac77d11b53225c755c8408f62a83 100644 (file)
@@ -585,6 +585,7 @@ class Escaping:
     def escape_bytea(self, data: bytes) -> bytes:
         len_out = c_size_t()
         if self.conn is not None:
+            self.conn._ensure_pgconn()
             out = impl.PQescapeByteaConn(
                 self.conn.pgconn_ptr,
                 data,
@@ -605,6 +606,11 @@ class Escaping:
         return rv
 
     def unescape_bytea(self, data: bytes) -> bytes:
+        # not needed, but let's keep it symmetric with the escaping:
+        # if a connection is passed in, it must be valid.
+        if self.conn is not None:
+            self.conn._ensure_pgconn()
+
         len_out = c_size_t()
         out = impl.PQunescapeBytea(data, pointer(t_cast(c_ulong, len_out)))
         if not out:
index 01ec91722573e7f9d3a0e04480fa4e67972c59b0..86062175f9a4bd37aefc163378880fecceeaffe2 100644 (file)
@@ -1,14 +1,21 @@
 import pytest
 
+import psycopg3
+
 
 @pytest.mark.parametrize(
     "data", [(b"hello\00world"), (b"\00\00\00\00")],
 )
 def test_escape_bytea(pq, pgconn, data):
-    rv = pq.Escaping(pgconn).escape_bytea(data)
     exp = br"\x" + b"".join(b"%02x" % c for c in data)
+    esc = pq.Escaping(pgconn)
+    rv = esc.escape_bytea(data)
     assert rv == exp
 
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_bytea(data)
+
 
 def test_escape_noconn(pq, pgconn):
     data = bytes(range(256))
@@ -34,5 +41,10 @@ def test_escape_1char(pq, pgconn):
 )
 def test_unescape_bytea(pq, pgconn, data):
     enc = br"\x" + b"".join(b"%02x" % c for c in data)
-    rv = pq.Escaping(pgconn).unescape_bytea(enc)
+    esc = pq.Escaping(pgconn)
+    rv = esc.unescape_bytea(enc)
     assert rv == data
+
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.unescape_bytea(data)