From: Daniele Varrazzo Date: Sat, 11 Apr 2020 08:23:28 +0000 (+1200) Subject: Ensure a valid connection with escaping functions X-Git-Tag: 3.0.dev0~576 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=65d977ba88d936a30697429006e5e621bdb75331;p=thirdparty%2Fpsycopg.git Ensure a valid connection with escaping functions --- diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index 7cb91c003..b40e8f31c 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -585,6 +585,7 @@ class Escaping: def escape_bytea(self, data: bytes) -> bytes: len_out = c_size_t() if self.conn is not None: + self.conn._ensure_pgconn() out = impl.PQescapeByteaConn( self.conn.pgconn_ptr, data, @@ -605,6 +606,11 @@ class Escaping: return rv def unescape_bytea(self, data: bytes) -> bytes: + # not needed, but let's keep it symmetric with the escaping: + # if a connection is passed in, it must be valid. + if self.conn is not None: + self.conn._ensure_pgconn() + len_out = c_size_t() out = impl.PQunescapeBytea(data, pointer(t_cast(c_ulong, len_out))) if not out: diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py index 01ec91722..86062175f 100644 --- a/tests/pq/test_escaping.py +++ b/tests/pq/test_escaping.py @@ -1,14 +1,21 @@ import pytest +import psycopg3 + @pytest.mark.parametrize( "data", [(b"hello\00world"), (b"\00\00\00\00")], ) def test_escape_bytea(pq, pgconn, data): - rv = pq.Escaping(pgconn).escape_bytea(data) exp = br"\x" + b"".join(b"%02x" % c for c in data) + esc = pq.Escaping(pgconn) + rv = esc.escape_bytea(data) assert rv == exp + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + esc.escape_bytea(data) + def test_escape_noconn(pq, pgconn): data = bytes(range(256)) @@ -34,5 +41,10 @@ def test_escape_1char(pq, pgconn): ) def test_unescape_bytea(pq, pgconn, data): enc = br"\x" + b"".join(b"%02x" % c for c in data) - rv = pq.Escaping(pgconn).unescape_bytea(enc) + esc = pq.Escaping(pgconn) + rv = esc.unescape_bytea(enc) assert rv == data + + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + esc.unescape_bytea(data)