]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added Escaping.escape_identifier
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Aug 2020 18:31:43 +0000 (19:31 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 8 Aug 2020 18:31:43 +0000 (19:31 +0100)
psycopg3/psycopg3/pq/_pq_ctypes.py
psycopg3/psycopg3/pq/_pq_ctypes.pyi
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3_c/psycopg3_c/libpq.pxd
psycopg3_c/psycopg3_c/pq_cython.pyx
tests/pq/test_escaping.py

index a5eec231d7b86e1f798efd063d227677f22fb84a..50237e9fe5a84025b8b1f96d48b89e8a3293dfbc 100644 (file)
@@ -387,11 +387,15 @@ PQoidValue.restype = Oid
 
 # 33.3.4. Escaping Strings for Inclusion in SQL Commands
 
-# TODO: PQescapeIdentifier PQescapeStringConn
+# TODO: PQescapeStringConn
 PQescapeLiteral = pq.PQescapeLiteral
 PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
 PQescapeLiteral.restype = POINTER(c_char)
 
+PQescapeIdentifier = pq.PQescapeIdentifier
+PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeIdentifier.restype = POINTER(c_char)
+
 # won't wrap: PQescapeString
 
 PQescapeByteaConn = pq.PQescapeByteaConn
index 1a942a1a3b2d7d6e32480d0b597520bc1caffad8..ebfbaca950f9016ead7682a0bc01426869182470 100644 (file)
@@ -160,6 +160,7 @@ def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
 def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
 def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
 def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
 def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ...
 def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ...
index 2616e8a706fcee69da6e716a05b7fa4ef29f6f3c..3d3f03e969d396f17115d69f75ca50f58e2a90cb 100644 (file)
@@ -760,6 +760,23 @@ class Escaping:
         else:
             raise PQerror("escape_literal failed: no connection provided")
 
+    def escape_identifier(self, data: bytes) -> bytes:
+        if self.conn is not None:
+            self.conn._ensure_pgconn()
+            out = impl.PQescapeIdentifier(
+                self.conn.pgconn_ptr, data, len(data)
+            )
+            if not out:
+                raise PQerror(
+                    f"escape_identifier failed: {error_message(self.conn)} bytes"
+                )
+            rv = string_at(out)
+            impl.PQfreemem(out)
+            return rv
+
+        else:
+            raise PQerror("escape_identifier failed: no connection provided")
+
     def escape_bytea(self, data: bytes) -> bytes:
         len_out = c_size_t()
         if self.conn is not None:
index eae21a41cf15ebd9aac0db722db3a08215aeb85c..73e9b7dc0daa5e17bbd187d17b5b1aa3553776ea 100644 (file)
@@ -183,7 +183,8 @@ cdef extern from "libpq-fe.h":
     Oid PQoidValue(const PGresult *res)
 
     # 33.3.4. Escaping Strings for Inclusion in SQL Commands
-    # TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+    # TODO: PQescapeStringConn PQescapeString
+    char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length);
     char *PQescapeLiteral(PGconn *conn, const char *str, size_t length);
     unsigned char *PQescapeByteaConn(PGconn *conn,
                                      const unsigned char *src,
index c45bf60a33d995cbdf5ada7db9970a64c986bef8..e9b04b726de9f1691dee44b88701eb7689924faa 100644 (file)
@@ -814,6 +814,25 @@ cdef class Escaping:
         else:
             raise PQerror("escape_literal failed: no connection provided")
 
+    def escape_identifier(self, data: bytes) -> bytes:
+        cdef char *out
+        cdef bytes rv
+
+        if self.conn is not None:
+            if self.conn.pgconn_ptr is NULL:
+                raise PQerror("the connection is closed")
+            out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data))
+            if out is NULL:
+                raise PQerror(
+                    f"escape_identifier failed: {error_message(self.conn)} bytes"
+                )
+            rv = out
+            impl.PQfreemem(out)
+            return rv
+
+        else:
+            raise PQerror("escape_identifier failed: no connection provided")
+
     def escape_bytea(self, data: bytes) -> bytes:
         cdef size_t len_out
         cdef unsigned char *out
index 6d804724fad3e3580106e310d04445a3d9afaabd..583ee55333e802bbd3b8c1e7ff0d86c8c46bccd4 100644 (file)
@@ -40,6 +40,42 @@ def test_escape_literal_noconn(pgconn):
         esc.escape_literal(b"hi")
 
 
+@pytest.mark.parametrize(
+    "data, want",
+    [
+        (b"", b'""'),
+        (b"hello", b'"hello"'),
+        (b'foo"bar', b'"foo""bar"'),
+        (b"foo\\bar", b'"foo\\bar"'),
+    ],
+)
+def test_escape_identifier(pgconn, data, want):
+    esc = pq.Escaping(pgconn)
+    out = esc.escape_identifier(data)
+    assert out == want
+
+
+def test_escape_identifier_1char(pgconn):
+    esc = pq.Escaping(pgconn)
+    special = {b'"': b'""""', b"\\": b'"\\"'}
+    for c in range(1, 128):
+        data = bytes([c])
+        rv = esc.escape_identifier(data)
+        exp = special.get(data) or b'"%s"' % data
+        assert rv == exp
+
+
+def test_escape_identifier_noconn(pgconn):
+    esc = pq.Escaping()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_identifier(b"hi")
+
+    esc = pq.Escaping(pgconn)
+    pgconn.finish()
+    with pytest.raises(psycopg3.OperationalError):
+        esc.escape_identifier(b"hi")
+
+
 @pytest.mark.parametrize(
     "data", [(b"hello\00world"), (b"\00\00\00\00")],
 )