From: Daniele Varrazzo Date: Sat, 8 Aug 2020 18:31:43 +0000 (+0100) Subject: Added Escaping.escape_identifier X-Git-Tag: 3.0.dev0~465 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=cb2f903bf915d5dedba086551fc3ef7a7219b1ca;p=thirdparty%2Fpsycopg.git Added Escaping.escape_identifier --- diff --git a/psycopg3/psycopg3/pq/_pq_ctypes.py b/psycopg3/psycopg3/pq/_pq_ctypes.py index a5eec231d..50237e9fe 100644 --- a/psycopg3/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/psycopg3/pq/_pq_ctypes.py @@ -387,11 +387,15 @@ PQoidValue.restype = Oid # 33.3.4. Escaping Strings for Inclusion in SQL Commands -# TODO: PQescapeIdentifier PQescapeStringConn +# TODO: PQescapeStringConn PQescapeLiteral = pq.PQescapeLiteral PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t] PQescapeLiteral.restype = POINTER(c_char) +PQescapeIdentifier = pq.PQescapeIdentifier +PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t] +PQescapeIdentifier.restype = POINTER(c_char) + # won't wrap: PQescapeString PQescapeByteaConn = pq.PQescapeByteaConn diff --git a/psycopg3/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/psycopg3/pq/_pq_ctypes.pyi index 1a942a1a3..ebfbaca95 100644 --- a/psycopg3/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/psycopg3/pq/_pq_ctypes.pyi @@ -160,6 +160,7 @@ def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ... def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... +def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ... def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ... def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ... def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ... diff --git a/psycopg3/psycopg3/pq/pq_ctypes.py b/psycopg3/psycopg3/pq/pq_ctypes.py index 2616e8a70..3d3f03e96 100644 --- a/psycopg3/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/psycopg3/pq/pq_ctypes.py @@ -760,6 +760,23 @@ class Escaping: else: raise PQerror("escape_literal failed: no connection provided") + def escape_identifier(self, data: bytes) -> bytes: + if self.conn is not None: + self.conn._ensure_pgconn() + out = impl.PQescapeIdentifier( + self.conn.pgconn_ptr, data, len(data) + ) + if not out: + raise PQerror( + f"escape_identifier failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return rv + + else: + raise PQerror("escape_identifier failed: no connection provided") + def escape_bytea(self, data: bytes) -> bytes: len_out = c_size_t() if self.conn is not None: diff --git a/psycopg3_c/psycopg3_c/libpq.pxd b/psycopg3_c/psycopg3_c/libpq.pxd index eae21a41c..73e9b7dc0 100644 --- a/psycopg3_c/psycopg3_c/libpq.pxd +++ b/psycopg3_c/psycopg3_c/libpq.pxd @@ -183,7 +183,8 @@ cdef extern from "libpq-fe.h": Oid PQoidValue(const PGresult *res) # 33.3.4. Escaping Strings for Inclusion in SQL Commands - # TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString + # TODO: PQescapeStringConn PQescapeString + char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length); char *PQescapeLiteral(PGconn *conn, const char *str, size_t length); unsigned char *PQescapeByteaConn(PGconn *conn, const unsigned char *src, diff --git a/psycopg3_c/psycopg3_c/pq_cython.pyx b/psycopg3_c/psycopg3_c/pq_cython.pyx index c45bf60a3..e9b04b726 100644 --- a/psycopg3_c/psycopg3_c/pq_cython.pyx +++ b/psycopg3_c/psycopg3_c/pq_cython.pyx @@ -814,6 +814,25 @@ cdef class Escaping: else: raise PQerror("escape_literal failed: no connection provided") + def escape_identifier(self, data: bytes) -> bytes: + cdef char *out + cdef bytes rv + + if self.conn is not None: + if self.conn.pgconn_ptr is NULL: + raise PQerror("the connection is closed") + out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data)) + if out is NULL: + raise PQerror( + f"escape_identifier failed: {error_message(self.conn)} bytes" + ) + rv = out + impl.PQfreemem(out) + return rv + + else: + raise PQerror("escape_identifier failed: no connection provided") + def escape_bytea(self, data: bytes) -> bytes: cdef size_t len_out cdef unsigned char *out diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py index 6d804724f..583ee5533 100644 --- a/tests/pq/test_escaping.py +++ b/tests/pq/test_escaping.py @@ -40,6 +40,42 @@ def test_escape_literal_noconn(pgconn): esc.escape_literal(b"hi") +@pytest.mark.parametrize( + "data, want", + [ + (b"", b'""'), + (b"hello", b'"hello"'), + (b'foo"bar', b'"foo""bar"'), + (b"foo\\bar", b'"foo\\bar"'), + ], +) +def test_escape_identifier(pgconn, data, want): + esc = pq.Escaping(pgconn) + out = esc.escape_identifier(data) + assert out == want + + +def test_escape_identifier_1char(pgconn): + esc = pq.Escaping(pgconn) + special = {b'"': b'""""', b"\\": b'"\\"'} + for c in range(1, 128): + data = bytes([c]) + rv = esc.escape_identifier(data) + exp = special.get(data) or b'"%s"' % data + assert rv == exp + + +def test_escape_identifier_noconn(pgconn): + esc = pq.Escaping() + with pytest.raises(psycopg3.OperationalError): + esc.escape_identifier(b"hi") + + esc = pq.Escaping(pgconn) + pgconn.finish() + with pytest.raises(psycopg3.OperationalError): + esc.escape_identifier(b"hi") + + @pytest.mark.parametrize( "data", [(b"hello\00world"), (b"\00\00\00\00")], )