]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Escaping.escape_string
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Aug 2020 19:42:00 +0000 (20:42 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Aug 2020 19:42:00 +0000 (20:42 +0100)
psycopg3/psycopg3/pq/_pq_ctypes.py
psycopg3/psycopg3/pq/_pq_ctypes.pyi
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3_c/psycopg3_c/libpq.pxd
psycopg3_c/psycopg3_c/pq_cython.pyx
tests/pq/test_escaping.py

index 50237e9fe5a84025b8b1f96d48b89e8a3293dfbc..fa45ef7febac1faea4c974d3093d516def4d535d 100644 (file)
@@ -387,7 +387,6 @@ PQoidValue.restype = Oid
 
 # 33.3.4. Escaping Strings for Inclusion in SQL Commands
 
-# TODO: PQescapeStringConn
 PQescapeLiteral = pq.PQescapeLiteral
 PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
 PQescapeLiteral.restype = POINTER(c_char)
@@ -396,6 +395,13 @@ PQescapeIdentifier = pq.PQescapeIdentifier
 PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
 PQescapeIdentifier.restype = POINTER(c_char)
 
+PQescapeStringConn = pq.PQescapeStringConn
+# TODO: raises "wrong type" error
+# PQescapeStringConn.argtypes = [
+#     PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
+# ]
+PQescapeStringConn.restype = c_size_t
+
 # won't wrap: PQescapeString
 
 PQescapeByteaConn = pq.PQescapeByteaConn
index ebfbaca950f9016ead7682a0bc01426869182470..90804f538ea8b17a89d8804ef2d73389aa2eca2e 100644 (file)
@@ -60,6 +60,13 @@ def PQgetvalue(
     arg1: Optional[PGresult_struct], arg2: int, arg3: int
 ) -> pointer[c_char]: ...
 def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQescapeStringConn(
+    arg1: Optional[PGconn_struct],
+    arg2: c_char_p,
+    arg3: bytes,
+    arg4: int,
+    arg5: pointer[c_int],
+) -> int: ...
 def PQsendPrepare(
     arg1: Optional[PGconn_struct],
     arg2: bytes,
index 3d3f03e969d396f17115d69f75ca50f58e2a90cb..09c9ffb6234cd417c606dcd2d6294f2a40fbaedc 100644 (file)
@@ -777,6 +777,28 @@ class Escaping:
         else:
             raise PQerror("escape_identifier failed: no connection provided")
 
+    def escape_string(self, data: bytes) -> bytes:
+        if self.conn is not None:
+            self.conn._ensure_pgconn()
+            error = c_int()
+            out = create_string_buffer(len(data) * 2 + 1)
+            impl.PQescapeStringConn(
+                self.conn.pgconn_ptr,
+                pointer(out),  # type: ignore
+                data,
+                len(data),
+                pointer(error),
+            )
+
+            if error:
+                raise PQerror(
+                    f"escape_string failed: {error_message(self.conn)} bytes"
+                )
+            return out.value
+
+        else:
+            raise PQerror("escape_identifier failed: no connection provided")
+
     def escape_bytea(self, data: bytes) -> bytes:
         len_out = c_size_t()
         if self.conn is not None:
index 73e9b7dc0daa5e17bbd187d17b5b1aa3553776ea..908f4e3f258dccf663b217748943cbcc8cf99f1b 100644 (file)
@@ -184,8 +184,11 @@ cdef extern from "libpq-fe.h":
 
     # 33.3.4. Escaping Strings for Inclusion in SQL Commands
     # TODO: PQescapeStringConn PQescapeString
-    char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length);
-    char *PQescapeLiteral(PGconn *conn, const char *str, size_t length);
+    char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length)
+    char *PQescapeLiteral(PGconn *conn, const char *str, size_t length)
+    size_t PQescapeStringConn(PGconn *conn,
+                              char *to, const char *from_, size_t length,
+                              int *error)
     unsigned char *PQescapeByteaConn(PGconn *conn,
                                      const unsigned char *src,
                                      size_t from_length,
index e9b04b726de9f1691dee44b88701eb7689924faa..ca0b0c576b4705dd3cdc2027ba08385d0fa0973a 100644 (file)
@@ -805,7 +805,7 @@ cdef class Escaping:
             out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
             if out is NULL:
                 raise PQerror(
-                    f"escape_literal failed: {error_message(self.conn)} bytes"
+                    f"escape_literal failed: {error_message(self.conn)}"
                 )
             rv = out
             impl.PQfreemem(out)
@@ -824,7 +824,7 @@ cdef class Escaping:
             out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data))
             if out is NULL:
                 raise PQerror(
-                    f"escape_identifier failed: {error_message(self.conn)} bytes"
+                    f"escape_identifier failed: {error_message(self.conn)}"
                 )
             rv = out
             impl.PQfreemem(out)
@@ -833,6 +833,29 @@ cdef class Escaping:
         else:
             raise PQerror("escape_identifier failed: no connection provided")
 
+    def escape_string(self, data: bytes) -> bytes:
+        cdef int error
+        cdef size_t len_data = len(data)
+        cdef char *out
+        cdef size_t len_out
+        if self.conn is not None:
+            if self.conn.pgconn_ptr is NULL:
+                raise PQerror("the connection is closed")
+
+            out = <char *>PyMem_Malloc(len_data * 2 + 1)
+            len_out = impl.PQescapeStringConn(
+                self.conn.pgconn_ptr, out, data, len_data, &error
+            )
+
+            if error:
+                raise PQerror(
+                    f"escape_string failed: {error_message(self.conn)}"
+                )
+            return out[:len_out]
+
+        else:
+            raise PQerror("escape_identifier failed: no connection provided")
+
     def escape_bytea(self, data: bytes) -> bytes:
         cdef size_t len_out
         cdef unsigned char *out
index 583ee55333e802bbd3b8c1e7ff0d86c8c46bccd4..30015559e7043d2e9022857895a560c9c6a43d5f 100644 (file)
@@ -76,6 +76,42 @@ def test_escape_identifier_noconn(pgconn):
         esc.escape_identifier(b"hi")
 
 
+@pytest.mark.parametrize(
+    "data, want",
+    [
+        (b"", b""),
+        (b"hello", b"hello"),
+        (b"foo'bar", b"foo''bar"),
+        (b"foo\\bar", b"foo\\bar"),
+    ],
+)
+def test_escape_string(pgconn, data, want):
+    esc = pq.Escaping(pgconn)
+    out = esc.escape_string(data)
+    assert out == want
+
+
+def test_escape_string_1char(pgconn):
+    esc = pq.Escaping(pgconn)
+    special = {b"'": b"''", b"\\": b"\\"}
+    for c in range(1, 128):
+        data = bytes([c])
+        rv = esc.escape_string(data)
+        exp = special.get(data) or b"%s" % data
+        assert rv == exp
+
+
+def test_escape_string_noconn(pgconn):
+    esc = pq.Escaping()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_string(b"hi")
+
+    esc = pq.Escaping(pgconn)
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_string(b"hi")
+
+
 @pytest.mark.parametrize(
     "data", [(b"hello\00world"), (b"\00\00\00\00")],
 )