From: Daniele Varrazzo Date: Sun, 25 Jul 2021 19:26:12 +0000 (+0200) Subject: Fix bytes quoting with non standard conforming strings X-Git-Tag: 3.0.dev2~33 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=491df7aa0cc0b255e7dc3041dd0397ea85f6d549;p=thirdparty%2Fpsycopg.git Fix bytes quoting with non standard conforming strings --- diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py index de02b42ab..d1702f334 100644 --- a/psycopg/psycopg/types/string.py +++ b/psycopg/psycopg/types/string.py @@ -85,6 +85,7 @@ class BytesDumper(Dumper): format = Format.TEXT _oid = postgres.types["bytea"].oid + _qprefix = b"" def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) @@ -97,6 +98,19 @@ class BytesDumper(Dumper): # probably dump return value should be extended to Buffer return self._esc.escape_bytea(obj) + def quote(self, obj: bytes) -> bytes: + if not self._qprefix: + if self.connection: + scs = self.connection.pgconn.parameter_status( + b"standard_conforming_strings" + ) + self._qprefix = b"'" if scs == b"on" else b" E'" + else: + self._qprefix = b" E'" + + escaped = self.dump(obj) + return self._qprefix + bytes(escaped) + b"'" + class BytesBinaryDumper(Dumper): diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx index 3e7a95847..84aa34458 100644 --- a/psycopg_c/psycopg_c/types/string.pyx +++ b/psycopg_c/psycopg_c/types/string.pyx @@ -148,8 +148,12 @@ cdef class BytesDumper(CDumper): format = PQ_TEXT + # 0: not set, 1: just single "'" quote, 3: " E'" quote + cdef int _qplen + def __cinit__(self): self.oid = oids.BYTEA_OID + self._qplen = 0 cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1: @@ -177,6 +181,56 @@ cdef class BytesDumper(CDumper): libpq.PQfreemem(out) return len_out + def quote(self, obj) -> bytearray: + cdef size_t len_esc + cdef size_t len_out + cdef unsigned char *out + cdef char *ptr + cdef Py_ssize_t 
length + cdef bytearray rv + + _buffer_as_string_and_size(obj, &ptr, &length) + + if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + out = libpq.PQescapeByteaConn( + self._pgconn._pgconn_ptr, ptr, length, &len_esc) + else: + out = libpq.PQescapeBytea(ptr, length, &len_esc) + + if out is NULL: + raise MemoryError( + f"couldn't allocate for escape_bytea of {length} bytes" + ) + + if not self._qplen: + self._qplen = 3 + if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, + b"standard_conforming_strings") + if scs and scs[0] == b'o' and scs[1] == b"n": # == "on" + self._qplen = 1 + + # len_esc includes the final 0 + # So the final string is len_esc - 1 plus quotes, which might be 2 or 4 bytes + len_out = len_esc + self._qplen + rv = PyByteArray_FromStringAndSize("", 0) + cdef char *buf = CDumper.ensure_size(rv, 0, len_out) + memcpy(buf + self._qplen, out, len_esc) + libpq.PQfreemem(out) + + if self._qplen == 3: + # Quote as " E'content'" + buf[0] = b' ' + buf[1] = b'E' + buf[2] = b'\'' + else: + # Quote as "'content'" + buf[0] = b'\'' + + buf[len_out - 1] = b'\'' + + return rv + @cython.final cdef class BytesBinaryDumper(CDumper): diff --git a/tests/types/test_string.py b/tests/types/test_string.py index 04669f0fc..b3cdbfd3d 100644 --- a/tests/types/test_string.py +++ b/tests/types/test_string.py @@ -212,23 +212,31 @@ def test_dump_1byte(conn, fmt_in, pytype): cur = conn.cursor() for i in range(0, 256): obj = pytype(bytes([i])) - cur.execute(f"select %{fmt_in} = %s::bytea", (obj, fr"\x{i:02x}")) + cur.execute(f"select %{fmt_in} = set_byte('x', 0, %s)", (obj, i)) assert cur.fetchone()[0] is True, i -def test_quote_1byte(conn): +@pytest.mark.parametrize("scs", ["on", "off"]) +def test_quote_1byte(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + cur = conn.cursor() - 
query = sql.SQL("select {ch} = %s::bytea") + query = sql.SQL("select {ch} = set_byte('x', 0, %s)") for i in range(0, 256): - cur.execute(query.format(ch=sql.Literal(bytes([i]))), (fr"\x{i:02x}",)) + cur.execute(query.format(ch=sql.Literal(bytes([i]))), (i,)) assert cur.fetchone()[0] is True, i + # No "nonstandard use of \\ in a string literal" warning + assert not messages + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) def test_load_1byte(conn, fmt_out): cur = conn.cursor(binary=fmt_out) for i in range(0, 256): - cur.execute("select %s::bytea", (fr"\x{i:02x}",)) + cur.execute("select set_byte('x', 0, %s)", (i,)) assert cur.fetchone()[0] == bytes([i]) assert cur.pgresult.fformat(0) == fmt_out