From: Daniele Varrazzo Date: Sat, 8 Aug 2020 19:42:00 +0000 (+0100) Subject: Added Escaping.escape_string X-Git-Tag: 3.0.dev0~464 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9d5132d5d3f692bec0460f7d379cc4e870d1f759;p=thirdparty%2Fpsycopg.git Added Escaping.escape_string --- diff --git a/psycopg3/psycopg3/pq/_pq_ctypes.py b/psycopg3/psycopg3/pq/_pq_ctypes.py index 50237e9fe..fa45ef7fe 100644 --- a/psycopg3/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/psycopg3/pq/_pq_ctypes.py @@ -387,7 +387,6 @@ PQoidValue.restype = Oid # 33.3.4. Escaping Strings for Inclusion in SQL Commands -# TODO: PQescapeStringConn PQescapeLiteral = pq.PQescapeLiteral PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t] PQescapeLiteral.restype = POINTER(c_char) @@ -396,6 +395,13 @@ PQescapeIdentifier = pq.PQescapeIdentifier PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t] PQescapeIdentifier.restype = POINTER(c_char) +PQescapeStringConn = pq.PQescapeStringConn +# TODO: raises "wrong type" error +# PQescapeStringConn.argtypes = [ +# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int) +# ] +PQescapeStringConn.restype = c_size_t + # won't wrap: PQescapeString PQescapeByteaConn = pq.PQescapeByteaConn diff --git a/psycopg3/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/psycopg3/pq/_pq_ctypes.pyi index ebfbaca95..90804f538 100644 --- a/psycopg3/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/psycopg3/pq/_pq_ctypes.pyi @@ -60,6 +60,13 @@ def PQgetvalue( arg1: Optional[PGresult_struct], arg2: int, arg3: int ) -> pointer[c_char]: ... def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... +def PQescapeStringConn( + arg1: Optional[PGconn_struct], + arg2: c_char_p, + arg3: bytes, + arg4: int, + arg5: pointer[c_int], +) -> int: ... def PQsendPrepare( arg1: Optional[PGconn_struct], arg2: bytes, diff --git a/psycopg3/psycopg3/pq/pq_ctypes.py b/psycopg3/psycopg3/pq/pq_ctypes.py index 3d3f03e96..09c9ffb62 100644 --- a/psycopg3/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/psycopg3/pq/pq_ctypes.py @@ -777,6 +777,28 @@ class Escaping: else: raise PQerror("escape_identifier failed: no connection provided") + def escape_string(self, data: bytes) -> bytes: + if self.conn is not None: + self.conn._ensure_pgconn() + error = c_int() + out = create_string_buffer(len(data) * 2 + 1) + impl.PQescapeStringConn( + self.conn.pgconn_ptr, + pointer(out), # type: ignore + data, + len(data), + pointer(error), + ) + + if error: + raise PQerror( + f"escape_string failed: {error_message(self.conn)} bytes" + ) + return out.value + + else: + raise PQerror("escape_identifier failed: no connection provided") + def escape_bytea(self, data: bytes) -> bytes: len_out = c_size_t() if self.conn is not None: diff --git a/psycopg3_c/psycopg3_c/libpq.pxd b/psycopg3_c/psycopg3_c/libpq.pxd index 73e9b7dc0..908f4e3f2 100644 --- a/psycopg3_c/psycopg3_c/libpq.pxd +++ b/psycopg3_c/psycopg3_c/libpq.pxd @@ -184,8 +184,11 @@ cdef extern from "libpq-fe.h": # 33.3.4. Escaping Strings for Inclusion in SQL Commands # TODO: PQescapeStringConn PQescapeString - char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length); - char *PQescapeLiteral(PGconn *conn, const char *str, size_t length); + char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length) + char *PQescapeLiteral(PGconn *conn, const char *str, size_t length) + size_t PQescapeStringConn(PGconn *conn, + char *to, const char *from_, size_t length, + int *error) unsigned char *PQescapeByteaConn(PGconn *conn, const unsigned char *src, size_t from_length, diff --git a/psycopg3_c/psycopg3_c/pq_cython.pyx b/psycopg3_c/psycopg3_c/pq_cython.pyx index e9b04b726..ca0b0c576 100644 --- a/psycopg3_c/psycopg3_c/pq_cython.pyx +++ b/psycopg3_c/psycopg3_c/pq_cython.pyx @@ -805,7 +805,7 @@ cdef class Escaping: out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data)) if out is NULL: raise PQerror( - f"escape_literal failed: {error_message(self.conn)} bytes" + f"escape_literal failed: {error_message(self.conn)}" ) rv = out impl.PQfreemem(out) @@ -824,7 +824,7 @@ cdef class Escaping: out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data)) if out is NULL: raise PQerror( - f"escape_identifier failed: {error_message(self.conn)} bytes" + f"escape_identifier failed: {error_message(self.conn)}" ) rv = out impl.PQfreemem(out) @@ -833,6 +833,29 @@ cdef class Escaping: else: raise PQerror("escape_identifier failed: no connection provided") + def escape_string(self, data: bytes) -> bytes: + cdef int error + cdef size_t len_data = len(data) + cdef char *out + cdef size_t len_out + if self.conn is not None: + if self.conn.pgconn_ptr is NULL: + raise PQerror("the connection is closed") + + out = PyMem_Malloc(len_data * 2 + 1) + len_out = impl.PQescapeStringConn( + self.conn.pgconn_ptr, out, data, len_data, &error + ) + + if error: + raise PQerror( + f"escape_string failed: {error_message(self.conn)}" + ) + return out[:len_out] + + else: + raise PQerror("escape_identifier failed: no connection provided") + def escape_bytea(self, data: bytes) -> bytes: cdef size_t len_out cdef unsigned char *out diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py index 583ee5533..30015559e 100644 --- a/tests/pq/test_escaping.py +++ b/tests/pq/test_escaping.py @@ -76,6 +76,42 @@ def test_escape_identifier_noconn(pgconn): esc.escape_identifier(b"hi") +@pytest.mark.parametrize( + "data, want", + [ + (b"", b""), + (b"hello", b"hello"), + (b"foo'bar", b"foo''bar"), + (b"foo\\bar", b"foo\\bar"), + ], +) +def test_escape_string(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_string(data) + assert out == want + + +def test_escape_string_1char(pgconn): + esc = pq.Escaping(pgconn) + special = {b"'": b"''", b"\\": b"\\"} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_string(data) + exp = special.get(data) or b"%s" % data + assert rv == exp + + +def test_escape_string_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg3.OperationalError): + esc.escape_string(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + esc.escape_string(b"hi") + + @pytest.mark.parametrize( "data", [(b"hello\00world"), (b"\00\00\00\00")], )