From: Daniele Varrazzo Date: Thu, 17 Dec 2020 04:55:56 +0000 (+0100) Subject: escape_string, escape_identifier use buffers too X-Git-Tag: 3.0.dev0~262 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=9e17aa2004b51ebb0eb45633f67664f188dd9887;p=thirdparty%2Fpsycopg.git escape_string, escape_identifier use buffers too --- diff --git a/psycopg3/psycopg3/pq/pq_ctypes.py b/psycopg3/psycopg3/pq/pq_ctypes.py index 04e7d637b..d7b51f181 100644 --- a/psycopg3/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/psycopg3/pq/pq_ctypes.py @@ -800,41 +800,43 @@ class Escaping: self.conn = conn def escape_literal(self, data: "proto.Buffer") -> memoryview: - if self.conn: - self.conn._ensure_pgconn() - # TODO: might be done without copy (however C does that) - if not isinstance(data, bytes): - data = bytes(data) - out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data)) - if not out: - raise PQerror( - f"escape_literal failed: {error_message(self.conn)} bytes" - ) - rv = string_at(out) - impl.PQfreemem(out) - return memoryview(rv) - - else: + if not self.conn: raise PQerror("escape_literal failed: no connection provided") - def escape_identifier(self, data: bytes) -> bytes: - if self.conn: - self.conn._ensure_pgconn() - out = impl.PQescapeIdentifier( - self.conn.pgconn_ptr, data, len(data) + self.conn._ensure_pgconn() + # TODO: might be done without copy (however C does that) + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data)) + if not out: + raise PQerror( + f"escape_literal failed: {error_message(self.conn)} bytes" ) - if not out: - raise PQerror( - f"escape_identifier failed: {error_message(self.conn)} bytes" - ) - rv = string_at(out) - impl.PQfreemem(out) - return rv + rv = string_at(out) + impl.PQfreemem(out) + return memoryview(rv) - else: + def escape_identifier(self, data: "proto.Buffer") -> memoryview: + if not self.conn: raise PQerror("escape_identifier failed: no connection provided") - def escape_string(self, data: bytes) -> bytes: + self.conn._ensure_pgconn() + + if not isinstance(data, bytes): + data = bytes(data) + out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data)) + if not out: + raise PQerror( + f"escape_identifier failed: {error_message(self.conn)} bytes" + ) + rv = string_at(out) + impl.PQfreemem(out) + return memoryview(rv) + + def escape_string(self, data: "proto.Buffer") -> memoryview: + if not isinstance(data, bytes): + data = bytes(data) + if self.conn: self.conn._ensure_pgconn() error = c_int() @@ -851,7 +853,6 @@ class Escaping: raise PQerror( f"escape_string failed: {error_message(self.conn)} bytes" ) - return out.value else: out = create_string_buffer(len(data) * 2 + 1) @@ -860,7 +861,8 @@ class Escaping: data, len(data), ) - return out.value + + return memoryview(out.value) def escape_bytea(self, data: "proto.Buffer") -> memoryview: len_out = c_size_t() diff --git a/psycopg3/psycopg3/pq/proto.py b/psycopg3/psycopg3/pq/proto.py index 82ed003b8..4405af201 100644 --- a/psycopg3/psycopg3/pq/proto.py +++ b/psycopg3/psycopg3/pq/proto.py @@ -341,10 +341,10 @@ class Escaping(Protocol): def escape_literal(self, data: Buffer) -> memoryview: ... - def escape_identifier(self, data: bytes) -> bytes: + def escape_identifier(self, data: Buffer) -> memoryview: ... - def escape_string(self, data: bytes) -> bytes: + def escape_string(self, data: Buffer) -> memoryview: ... def escape_bytea(self, data: Buffer) -> memoryview: diff --git a/psycopg3_c/psycopg3_c/pq_cython.pyx b/psycopg3_c/psycopg3_c/pq_cython.pyx index 8582e6f31..7f0bd9cbe 100644 --- a/psycopg3_c/psycopg3_c/pq_cython.pyx +++ b/psycopg3_c/psycopg3_c/pq_cython.pyx @@ -10,6 +10,8 @@ from cpython.mem cimport PyMem_Malloc, PyMem_Free from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release +from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize +from cpython.bytearray cimport PyByteArray_AS_STRING import logging from typing import List, Optional, Sequence, Tuple @@ -823,58 +825,55 @@ cdef class Escaping: return memoryview(PQBuffer._from_buffer(out, strlen(out))) - # TODO: return PQBuffer - def escape_identifier(self, data: bytes) -> bytes: + def escape_identifier(self, data: "Buffer") -> memoryview: cdef char *out - cdef bytes rv + cdef char *ptr + cdef Py_ssize_t length - if self.conn is not None: - if self.conn.pgconn_ptr is NULL: - raise PQerror("the connection is closed") - out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data)) - if out is NULL: - raise PQerror( - f"escape_identifier failed: {error_message(self.conn)}" - ) - rv = out - impl.PQfreemem(out) - return rv + _buffer_as_string_and_size(data, &ptr, &length) - else: + if self.conn is None: raise PQerror("escape_identifier failed: no connection provided") + if self.conn.pgconn_ptr is NULL: + raise PQerror("the connection is closed") + + out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, ptr, length) + if out is NULL: + raise PQerror( + f"escape_identifier failed: {error_message(self.conn)}" + ) + + return memoryview(PQBuffer._from_buffer(out, strlen(out))) - def escape_string(self, data: bytes) -> bytes: + def escape_string(self, data: "Buffer") -> memoryview: cdef int error - cdef size_t len_data = len(data) - cdef char *out cdef size_t len_out - cdef bytes rv + cdef char *ptr + cdef Py_ssize_t length + cdef bytearray rv + + _buffer_as_string_and_size(data, &ptr, &length) + + rv = PyByteArray_FromStringAndSize("", 0) + PyByteArray_Resize(rv, length * 2 + 1) if self.conn is not None: if self.conn.pgconn_ptr is NULL: raise PQerror("the connection is closed") - out = PyMem_Malloc(len_data * 2 + 1) len_out = impl.PQescapeStringConn( - self.conn.pgconn_ptr, out, data, len_data, &error + self.conn.pgconn_ptr, PyByteArray_AS_STRING(rv), + ptr, length, &error ) - if error: - PyMem_Free(out) raise PQerror( f"escape_string failed: {error_message(self.conn)}" ) - rv = out[:len_out] - PyMem_Free(out) - return rv - else: - out = PyMem_Malloc(len_data * 2 + 1) - len_out = impl.PQescapeString(out, data, len_data) - rv = out[:len_out] - PyMem_Free(out) - return rv + len_out = impl.PQescapeString(PyByteArray_AS_STRING(rv), ptr, length) + + return memoryview(rv) def escape_bytea(self, data: "Buffer") -> memoryview: cdef size_t len_out