]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
escape_string, escape_identifier use buffers too
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 17 Dec 2020 04:55:56 +0000 (05:55 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 17 Dec 2020 04:57:09 +0000 (05:57 +0100)
psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/pq/proto.py
psycopg3_c/psycopg3_c/pq_cython.pyx

index 04e7d637b95ae86639f556d08b5f464e6059aa0f..d7b51f18113201e161c3cace373b24639dc719db 100644 (file)
@@ -800,41 +800,43 @@ class Escaping:
         self.conn = conn
 
     def escape_literal(self, data: "proto.Buffer") -> memoryview:
-        if self.conn:
-            self.conn._ensure_pgconn()
-            # TODO: might be done without copy (however C does that)
-            if not isinstance(data, bytes):
-                data = bytes(data)
-            out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
-            if not out:
-                raise PQerror(
-                    f"escape_literal failed: {error_message(self.conn)} bytes"
-                )
-            rv = string_at(out)
-            impl.PQfreemem(out)
-            return memoryview(rv)
-
-        else:
+        if not self.conn:
             raise PQerror("escape_literal failed: no connection provided")
 
-    def escape_identifier(self, data: bytes) -> bytes:
-        if self.conn:
-            self.conn._ensure_pgconn()
-            out = impl.PQescapeIdentifier(
-                self.conn.pgconn_ptr, data, len(data)
+        self.conn._ensure_pgconn()
+        # TODO: might be done without copy (however C does that)
+        if not isinstance(data, bytes):
+            data = bytes(data)
+        out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
+        if not out:
+            raise PQerror(
+                f"escape_literal failed: {error_message(self.conn)} bytes"
             )
-            if not out:
-                raise PQerror(
-                    f"escape_identifier failed: {error_message(self.conn)} bytes"
-                )
-            rv = string_at(out)
-            impl.PQfreemem(out)
-            return rv
+        rv = string_at(out)
+        impl.PQfreemem(out)
+        return memoryview(rv)
 
-        else:
+    def escape_identifier(self, data: "proto.Buffer") -> memoryview:
+        if not self.conn:
             raise PQerror("escape_identifier failed: no connection provided")
 
-    def escape_string(self, data: bytes) -> bytes:
+        self.conn._ensure_pgconn()
+
+        if not isinstance(data, bytes):
+            data = bytes(data)
+        out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data))
+        if not out:
+            raise PQerror(
+                f"escape_identifier failed: {error_message(self.conn)} bytes"
+            )
+        rv = string_at(out)
+        impl.PQfreemem(out)
+        return memoryview(rv)
+
+    def escape_string(self, data: "proto.Buffer") -> memoryview:
+        if not isinstance(data, bytes):
+            data = bytes(data)
+
         if self.conn:
             self.conn._ensure_pgconn()
             error = c_int()
@@ -851,7 +853,6 @@ class Escaping:
                 raise PQerror(
                     f"escape_string failed: {error_message(self.conn)} bytes"
                 )
-            return out.value
 
         else:
             out = create_string_buffer(len(data) * 2 + 1)
@@ -860,7 +861,8 @@ class Escaping:
                 data,
                 len(data),
             )
-            return out.value
+
+        return memoryview(out.value)
 
     def escape_bytea(self, data: "proto.Buffer") -> memoryview:
         len_out = c_size_t()
index 82ed003b8b7f34837996396de5a40d0b113f0a28..4405af201e501e94bc2878a72b1997fed85ebc73 100644 (file)
@@ -341,10 +341,10 @@ class Escaping(Protocol):
     def escape_literal(self, data: Buffer) -> memoryview:
         ...
 
-    def escape_identifier(self, data: bytes) -> bytes:
+    def escape_identifier(self, data: Buffer) -> memoryview:
         ...
 
-    def escape_string(self, data: bytes) -> bytes:
+    def escape_string(self, data: Buffer) -> memoryview:
         ...
 
     def escape_bytea(self, data: Buffer) -> memoryview:
index 8582e6f31438abc4c185d756a9f91b3ef27455b9..7f0bd9cbe16ded96818b583b151219f578a6a023 100644 (file)
@@ -10,6 +10,8 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free
 from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize
 from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE
 from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release
+from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
+from cpython.bytearray cimport PyByteArray_AS_STRING
 
 import logging
 from typing import List, Optional, Sequence, Tuple
@@ -823,58 +825,55 @@ cdef class Escaping:
 
         return memoryview(PQBuffer._from_buffer(<unsigned char *>out, strlen(out)))
 
-    # TODO: return PQBuffer
-    def escape_identifier(self, data: bytes) -> bytes:
+    def escape_identifier(self, data: "Buffer") -> memoryview:
         cdef char *out
-        cdef bytes rv
+        cdef char *ptr
+        cdef Py_ssize_t length
 
-        if self.conn is not None:
-            if self.conn.pgconn_ptr is NULL:
-                raise PQerror("the connection is closed")
-            out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data))
-            if out is NULL:
-                raise PQerror(
-                    f"escape_identifier failed: {error_message(self.conn)}"
-                )
-            rv = out
-            impl.PQfreemem(out)
-            return rv
+        _buffer_as_string_and_size(data, &ptr, &length)
 
-        else:
+        if self.conn is None:
             raise PQerror("escape_identifier failed: no connection provided")
+        if self.conn.pgconn_ptr is NULL:
+            raise PQerror("the connection is closed")
+
+        out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, ptr, length)
+        if out is NULL:
+            raise PQerror(
+                f"escape_identifier failed: {error_message(self.conn)}"
+            )
+
+        return memoryview(PQBuffer._from_buffer(<unsigned char *>out, strlen(out)))
 
-    def escape_string(self, data: bytes) -> bytes:
+    def escape_string(self, data: "Buffer") -> memoryview:
         cdef int error
-        cdef size_t len_data = len(data)
-        cdef char *out
         cdef size_t len_out
-        cdef bytes rv
+        cdef char *ptr
+        cdef Py_ssize_t length
+        cdef bytearray rv
+
+        _buffer_as_string_and_size(data, &ptr, &length)
+
+        rv = PyByteArray_FromStringAndSize("", 0)
+        PyByteArray_Resize(rv, length * 2 + 1)
 
         if self.conn is not None:
             if self.conn.pgconn_ptr is NULL:
                 raise PQerror("the connection is closed")
 
-            out = <char *>PyMem_Malloc(len_data * 2 + 1)
             len_out = impl.PQescapeStringConn(
-                self.conn.pgconn_ptr, out, data, len_data, &error
+                self.conn.pgconn_ptr, PyByteArray_AS_STRING(rv),
+                ptr, length, &error
             )
-
             if error:
-                PyMem_Free(out)
                 raise PQerror(
                     f"escape_string failed: {error_message(self.conn)}"
                 )
 
-            rv = out[:len_out]
-            PyMem_Free(out)
-            return rv
-
         else:
-            out = <char *>PyMem_Malloc(len_data * 2 + 1)
-            len_out = impl.PQescapeString(out, data, len_data)
-            rv = out[:len_out]
-            PyMem_Free(out)
-            return rv
+            len_out = impl.PQescapeString(PyByteArray_AS_STRING(rv), ptr, length)
+
+        return memoryview(rv)
 
     def escape_bytea(self, data: "Buffer") -> memoryview:
         cdef size_t len_out