From 491df7aa0cc0b255e7dc3041dd0397ea85f6d549 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 25 Jul 2021 21:26:12 +0200 Subject: [PATCH] Fix bytes quoting with non standard conforming strings --- psycopg/psycopg/types/string.py | 14 ++++++++ psycopg_c/psycopg_c/types/string.pyx | 54 ++++++++++++++++++++++++++++ tests/types/test_string.py | 18 +++++++--- 3 files changed, 81 insertions(+), 5 deletions(-) diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py index de02b42ab..d1702f334 100644 --- a/psycopg/psycopg/types/string.py +++ b/psycopg/psycopg/types/string.py @@ -85,6 +85,7 @@ class BytesDumper(Dumper): format = Format.TEXT _oid = postgres.types["bytea"].oid + _qprefix = b"" def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) @@ -97,6 +98,19 @@ class BytesDumper(Dumper): # probably dump return value should be extended to Buffer return self._esc.escape_bytea(obj) + def quote(self, obj: bytes) -> bytes: + if not self._qprefix: + if self.connection: + scs = self.connection.pgconn.parameter_status( + b"standard_conforming_strings" + ) + self._qprefix = b"'" if scs == b"on" else b" E'" + else: + self._qprefix = b" E'" + + escaped = self.dump(obj) + return self._qprefix + bytes(escaped) + b"'" + class BytesBinaryDumper(Dumper): diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx index 3e7a95847..84aa34458 100644 --- a/psycopg_c/psycopg_c/types/string.pyx +++ b/psycopg_c/psycopg_c/types/string.pyx @@ -148,8 +148,12 @@ cdef class BytesDumper(CDumper): format = PQ_TEXT + # 0: not set, 1: just single "'" quote, 3: " E'" qoute + cdef int _qplen + def __cinit__(self): self.oid = oids.BYTEA_OID + self._qplen = 0 cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: @@ -177,6 +181,56 @@ cdef class BytesDumper(CDumper): libpq.PQfreemem(out) return len_out + def quote(self, obj) -> bytearray: + cdef size_t len_esc + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t length + cdef bytearray rv + + _buffer_as_string_and_size(obj, &ptr, &length) + + if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + out = libpq.PQescapeByteaConn( + self._pgconn._pgconn_ptr, ptr, length, &len_esc) + else: + out = libpq.PQescapeBytea(ptr, length, &len_esc) + + if out is NULL: + raise MemoryError( + f"couldn't allocate for escape_bytea of {length} bytes" + ) + + if not self._qplen: + self._qplen = 3 + if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, + b"standard_conforming_strings") + if scs and scs[0] == b'o' and scs[1] == b"n": # == "on" + self._qplen = 1 + + # len_esc includes the final 0 + # So the final string is len_esc - 1 plus quotes, which might be 2 or 4 bytes + len_out = len_esc + self._qplen + rv = PyByteArray_FromStringAndSize("", 0) + cdef char *buf = CDumper.ensure_size(rv, 0, len_out) + memcpy(buf + self._qplen, out, len_esc) + libpq.PQfreemem(out) + + if self._qplen == 3: + # Quote as " E'content'" + buf[0] = b' ' + buf[1] = b'E' + buf[2] = b'\'' + else: + # Quote as "'content'" + buf[0] = b'\'' + + buf[len_out - 1] = b'\'' + + return rv + @cython.final cdef class BytesBinaryDumper(CDumper): diff --git a/tests/types/test_string.py b/tests/types/test_string.py index 04669f0fc..b3cdbfd3d 100644 --- a/tests/types/test_string.py +++ b/tests/types/test_string.py @@ -212,23 +212,31 @@ def test_dump_1byte(conn, fmt_in, pytype): cur = conn.cursor() for i in range(0, 256): obj = pytype(bytes([i])) - cur.execute(f"select %{fmt_in} = %s::bytea", (obj, fr"\x{i:02x}")) + cur.execute(f"select %{fmt_in} = set_byte('x', 0, %s)", (obj, i)) assert cur.fetchone()[0] is True, i -def test_quote_1byte(conn): +@pytest.mark.parametrize("scs", ["on", "off"]) +def test_quote_1byte(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + cur = conn.cursor() - query = sql.SQL("select {ch} = %s::bytea") + query = sql.SQL("select {ch} = set_byte('x', 0, %s)") for i in range(0, 256): - cur.execute(query.format(ch=sql.Literal(bytes([i]))), (fr"\x{i:02x}",)) + cur.execute(query.format(ch=sql.Literal(bytes([i]))), (i,)) assert cur.fetchone()[0] is True, i + # No "nonstandard use of \\ in a string literal" warning + assert not messages + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) def test_load_1byte(conn, fmt_out): cur = conn.cursor(binary=fmt_out) for i in range(0, 256): - cur.execute("select %s::bytea", (fr"\x{i:02x}",)) + cur.execute("select set_byte('x', 0, %s)", (i,)) assert cur.fetchone()[0] == bytes([i]) assert cur.pgresult.fformat(0) == fmt_out -- 2.47.3