]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Escaping.escape_literal
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Aug 2020 18:07:52 +0000 (19:07 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Aug 2020 18:07:52 +0000 (19:07 +0100)
psycopg3/psycopg3/pq/_pq_ctypes.py
psycopg3/psycopg3/pq/_pq_ctypes.pyi
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/pq/proto.py
psycopg3_c/psycopg3_c/libpq.pxd
psycopg3_c/psycopg3_c/pq_cython.pyx
tests/pq/test_escaping.py

index 79e857cf9c9853a21b22780e942b200fd18c91e8..a5eec231d7b86e1f798efd063d227677f22fb84a 100644 (file)
@@ -387,7 +387,12 @@ PQoidValue.restype = Oid
 
 # 33.3.4. Escaping Strings for Inclusion in SQL Commands
 
-# TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+# TODO: PQescapeIdentifier PQescapeStringConn
+PQescapeLiteral = pq.PQescapeLiteral
+PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeLiteral.restype = POINTER(c_char)
+
+# won't wrap: PQescapeString
 
 PQescapeByteaConn = pq.PQescapeByteaConn
 PQescapeByteaConn.argtypes = [
index e3d223a5a8ceba725c297b340081f182e071e32d..1a942a1a3b2d7d6e32480d0b597520bc1caffad8 100644 (file)
@@ -159,6 +159,7 @@ def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
 def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
 def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
 def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
 def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ...
 def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ...
index b85eaa7a846f51a45f5835b192a2b5b0493c9283..2616e8a706fcee69da6e716a05b7fa4ef29f6f3c 100644 (file)
@@ -745,6 +745,21 @@ class Escaping:
     def __init__(self, conn: Optional[PGconn] = None):
         self.conn = conn
 
+    def escape_literal(self, data: bytes) -> bytes:
+        if self.conn is not None:
+            self.conn._ensure_pgconn()
+            out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
+            if not out:
+                raise PQerror(
+                    f"escape_literal failed: {error_message(self.conn)} bytes"
+                )
+            rv = string_at(out)
+            impl.PQfreemem(out)
+            return rv
+
+        else:
+            raise PQerror("escape_literal failed: no connection provided")
+
     def escape_bytea(self, data: bytes) -> bytes:
         len_out = c_size_t()
         if self.conn is not None:
index ad8d108c7414b3b8c82267041b475abf6a6e4545..cc7e2d3f685e1de706015cd64dabeff9734f245c 100644 (file)
@@ -342,6 +342,9 @@ class Escaping(Protocol):
     def __init__(self, conn: Optional[PGconn] = None):
         ...
 
+    def escape_literal(self, data: bytes) -> bytes:
+        ...
+
     def escape_bytea(self, data: bytes) -> bytes:
         ...
 
index 7e0f4f4bc5ec60de01fe29d1859acec3a5d94225..eae21a41cf15ebd9aac0db722db3a08215aeb85c 100644 (file)
@@ -184,6 +184,7 @@ cdef extern from "libpq-fe.h":
 
     # 33.3.4. Escaping Strings for Inclusion in SQL Commands
     # TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+    char *PQescapeLiteral(PGconn *conn, const char *str, size_t length);
     unsigned char *PQescapeByteaConn(PGconn *conn,
                                      const unsigned char *src,
                                      size_t from_length,
index a1daa8ee8a5d9f6fca4f43f4f9d248ad4974ab4b..c45bf60a33d995cbdf5ada7db9970a64c986bef8 100644 (file)
@@ -795,6 +795,25 @@ cdef class Escaping:
     def __init__(self, conn: Optional[PGconn] = None):
         self.conn = conn
 
+    def escape_literal(self, data: bytes) -> bytes:
+        cdef char *out
+        cdef bytes rv
+
+        if self.conn is not None:
+            if self.conn.pgconn_ptr is NULL:
+                raise PQerror("the connection is closed")
+            out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
+            if out is NULL:
+                raise PQerror(
+                    f"escape_literal failed: {error_message(self.conn)} bytes"
+                )
+            rv = out
+            impl.PQfreemem(out)
+            return rv
+
+        else:
+            raise PQerror("escape_literal failed: no connection provided")
+
     def escape_bytea(self, data: bytes) -> bytes:
         cdef size_t len_out
         cdef unsigned char *out
index 3dd28210860152796ed9b47ce2d8b32f95f92982..6d804724fad3e3580106e310d04445a3d9afaabd 100644 (file)
@@ -4,6 +4,42 @@ import psycopg3
 from psycopg3 import pq
 
 
+@pytest.mark.parametrize(
+    "data, want",
+    [
+        (b"", b"''"),
+        (b"hello", b"'hello'"),
+        (b"foo'bar", b"'foo''bar'"),
+        (b"foo\\bar", b" E'foo\\\\bar'"),
+    ],
+)
+def test_escape_literal(pgconn, data, want):
+    esc = pq.Escaping(pgconn)
+    out = esc.escape_literal(data)
+    assert out == want
+
+
+def test_escape_literal_1char(pgconn):
+    esc = pq.Escaping(pgconn)
+    special = {b"'": b"''''", b"\\": b" E'\\\\'"}
+    for c in range(1, 128):
+        data = bytes([c])
+        rv = esc.escape_literal(data)
+        exp = special.get(data) or b"'%s'" % data
+        assert rv == exp
+
+
+def test_escape_literal_noconn(pgconn):
+    esc = pq.Escaping()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_literal(b"hi")
+
+    esc = pq.Escaping(pgconn)
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_literal(b"hi")
+
+
 @pytest.mark.parametrize(
     "data", [(b"hello\00world"), (b"\00\00\00\00")],
 )