]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fix bytes quoting with non standard conforming strings
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 25 Jul 2021 19:26:12 +0000 (21:26 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 25 Jul 2021 19:26:12 +0000 (21:26 +0200)
psycopg/psycopg/types/string.py
psycopg_c/psycopg_c/types/string.pyx
tests/types/test_string.py

index de02b42ab3f8b972399b68b8ac71fb50b7d36bc6..d1702f33481eced0a06380a2e76f64a7d5466771 100644 (file)
@@ -85,6 +85,7 @@ class BytesDumper(Dumper):
 
     format = Format.TEXT
     _oid = postgres.types["bytea"].oid
+    _qprefix = b""
 
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
@@ -97,6 +98,19 @@ class BytesDumper(Dumper):
         # probably dump return value should be extended to Buffer
         return self._esc.escape_bytea(obj)
 
+    def quote(self, obj: bytes) -> bytes:
+        if not self._qprefix:
+            if self.connection:
+                scs = self.connection.pgconn.parameter_status(
+                    b"standard_conforming_strings"
+                )
+                self._qprefix = b"'" if scs == b"on" else b" E'"
+            else:
+                self._qprefix = b" E'"
+
+        escaped = self.dump(obj)
+        return self._qprefix + bytes(escaped) + b"'"
+
 
 class BytesBinaryDumper(Dumper):
 
index 3e7a9584761a70086c623c941ea689ebe41925d6..84aa34458bcf46d96d5f80908845de806c4dcd74 100644 (file)
@@ -148,8 +148,12 @@ cdef class BytesDumper(CDumper):
 
     format = PQ_TEXT
 
+    # 0: not set, 1: just  single "'" quote, 3: " E'" qoute
+    cdef int _qplen
+
     def __cinit__(self):
         self.oid = oids.BYTEA_OID
+        self._qplen = 0
 
     cdef Py_ssize_t cdump(self, obj, bytearray rv, Py_ssize_t offset) except -1:
 
@@ -177,6 +181,56 @@ cdef class BytesDumper(CDumper):
         libpq.PQfreemem(out)
         return len_out
 
+    def quote(self, obj) -> bytearray:
+        cdef size_t len_esc
+        cdef size_t len_out
+        cdef unsigned char *out
+        cdef char *ptr
+        cdef Py_ssize_t length
+        cdef bytearray rv
+
+        _buffer_as_string_and_size(obj, &ptr, &length)
+
+        if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
+            out = libpq.PQescapeByteaConn(
+                self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_esc)
+        else:
+            out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_esc)
+
+        if out is NULL:
+            raise MemoryError(
+                f"couldn't allocate for escape_bytea of {length} bytes"
+            )
+
+        if not self._qplen:
+            self._qplen = 3
+            if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
+                scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr,
+                    b"standard_conforming_strings")
+                if scs and scs[0] == b'o' and scs[1] == b"n":  # == "on"
+                    self._qplen = 1
+
+        # len_esc includes the final 0
+        # So the final string is len_esc - 1 plus quotes, which might be 2 or 4 bytes
+        len_out = len_esc + self._qplen
+        rv = PyByteArray_FromStringAndSize("", 0)
+        cdef char *buf = CDumper.ensure_size(rv, 0, len_out)
+        memcpy(buf + self._qplen, out, len_esc)
+        libpq.PQfreemem(out)
+
+        if self._qplen == 3:
+            # Quote as " E'content'"
+            buf[0] = b' '
+            buf[1] = b'E'
+            buf[2] = b'\''
+        else:
+            # Quote as "'content'"
+            buf[0] = b'\''
+
+        buf[len_out - 1] = b'\''
+
+        return rv
+
 
 @cython.final
 cdef class BytesBinaryDumper(CDumper):
index 04669f0fc94839a9a88b5055b72cc126d840232d..b3cdbfd3d083c4cd5869a47b07c53be42548298f 100644 (file)
@@ -212,23 +212,31 @@ def test_dump_1byte(conn, fmt_in, pytype):
     cur = conn.cursor()
     for i in range(0, 256):
         obj = pytype(bytes([i]))
-        cur.execute(f"select %{fmt_in} = %s::bytea", (obj, fr"\x{i:02x}"))
+        cur.execute(f"select %{fmt_in} = set_byte('x', 0, %s)", (obj, i))
         assert cur.fetchone()[0] is True, i
 
 
-def test_quote_1byte(conn):
+@pytest.mark.parametrize("scs", ["on", "off"])
+def test_quote_1byte(conn, scs):
+    messages = []
+    conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+    conn.execute(f"set standard_conforming_strings to {scs}")
+
     cur = conn.cursor()
-    query = sql.SQL("select {ch} = %s::bytea")
+    query = sql.SQL("select {ch} = set_byte('x', 0, %s)")
     for i in range(0, 256):
-        cur.execute(query.format(ch=sql.Literal(bytes([i]))), (fr"\x{i:02x}",))
+        cur.execute(query.format(ch=sql.Literal(bytes([i]))), (i,))
         assert cur.fetchone()[0] is True, i
 
+    # No "nonstandard use of \\ in a string literal" warning
+    assert not messages
+
 
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 def test_load_1byte(conn, fmt_out):
     cur = conn.cursor(binary=fmt_out)
     for i in range(0, 256):
-        cur.execute("select %s::bytea", (fr"\x{i:02x}",))
+        cur.execute("select set_byte('x', 0, %s)", (i,))
         assert cur.fetchone()[0] == bytes([i])
 
     assert cur.pgresult.fformat(0) == fmt_out