]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Emit quoted values valid for every standard_conforming_string value
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Jul 2021 00:19:46 +0000 (02:19 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 26 Jul 2021 00:19:46 +0000 (02:19 +0200)
Work around the unpredictable return value from the libpq.

psycopg/psycopg/adapt.py
psycopg/psycopg/types/string.py
psycopg_c/psycopg_c/_psycopg/adapt.pyx
psycopg_c/psycopg_c/pq.pxd
psycopg_c/psycopg_c/pq/escaping.pyx
psycopg_c/psycopg_c/types/string.pyx
tests/test_sql.py
tests/types/test_string.py

index 29c59f3960128dbe56d5fe5a089976e3c5d7cb02..b3d655ca80ba0a762576a1c123bb39041a7c0dab 100644 (file)
@@ -18,6 +18,8 @@ if TYPE_CHECKING:
 AdaptersMap = _adapters_map.AdaptersMap
 Buffer = abc.Buffer
 
+ORD_BS = ord("\\")
+
 
 class Dumper(abc.Dumper, ABC):
     """
@@ -60,10 +62,37 @@ class Dumper(abc.Dumper, ABC):
 
         if self.connection:
             esc = pq.Escaping(self.connection.pgconn)
+            # escaping and quoting
             return esc.escape_literal(value)
-        else:
-            esc = pq.Escaping()
-            return b"'%s'" % esc.escape_string(value)
+
+        # This path is taken when quote is asked without a connection,
+        # usually it means by psycopg.sql.quote() or by
+        # 'Composible.as_string(None)'. Most often than not this is done by
+        # someone generating a SQL file to consume elsewhere.
+
+        # No quoting, only quote escaping, random bs escaping. See further.
+        esc = pq.Escaping()
+        out = esc.escape_string(value)
+
+        # b"\\" in memoryview doesn't work so search for the ascii value
+        if ORD_BS not in out:
+            # If the string has no backslash, the result is correct and we
+            # don't need to bother with standard_conforming_strings.
+            return b"'" + out + b"'"
+
+        # The libpq has a crazy behaviour: PQescapeString uses the last
+        # standard_conforming_strings setting seen on a connection. This
+        # means that backslashes might be escaped or might not.
+        #
+        # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+        # if scs is off, '\\' raises a warning and '\' is an error.
+        #
+        # Check what the libpq does, and if it doesn't escape the backslash
+        # let's do it on our own. Never mind the race condition.
+        rv: bytes = b" E'" + out + b"'"
+        if esc.escape_string(b"\\") == b"\\":
+            rv = rv.replace(b"\\", b"\\\\")
+        return rv
 
     def get_key(
         self, obj: Any, format: PyFormat
index d1702f33481eced0a06380a2e76f64a7d5466771..3833f17e2f62719dde71504e9c0cdeeec1aea6a3 100644 (file)
@@ -99,17 +99,29 @@ class BytesDumper(Dumper):
         return self._esc.escape_bytea(obj)
 
     def quote(self, obj: bytes) -> bytes:
-        if not self._qprefix:
-            if self.connection:
+        escaped = self.dump(obj)
+
+        # We cannot use the base quoting because escape_bytea already returns
+        # the quotes content. if scs is off it will escape the backslashes in
+        # the format, otherwise it won't, but it doesn't tell us what quotes to
+        # use.
+        if self.connection:
+            if not self._qprefix:
                 scs = self.connection.pgconn.parameter_status(
                     b"standard_conforming_strings"
                 )
                 self._qprefix = b"'" if scs == b"on" else b" E'"
-            else:
-                self._qprefix = b" E'"
 
-        escaped = self.dump(obj)
-        return self._qprefix + bytes(escaped) + b"'"
+            return self._qprefix + escaped + b"'"
+
+        # We don't have a connection, so someone is using us to generate a file
+        # to use off-line or something like that. PQescapeBytea, like its
+        # string counterpart, is not predictable whether it will escape
+        # backslashes.
+        rv: bytes = b" E'" + escaped + b"'"
+        if self._esc.escape_bytea(b"\x00") == b"\\000":
+            rv = rv.replace(b"\\", b"\\\\")
+        return rv
 
 
 class BytesBinaryDumper(Dumper):
index ce640f8e29c9041a4ca40d7cd2659bbfb4ab553a..51d75ce690f7c6b47eab4381020013815c4c18b1 100644 (file)
@@ -16,10 +16,12 @@ equivalent C implementations.
 from typing import Any
 
 cimport cython
+
+from libc.string cimport memcpy, memchr
 from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
 from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING
 
-from psycopg_c.pq cimport _buffer_as_string_and_size
+from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping
 
 from psycopg import errors as e
 from psycopg.pq.misc import error_message
@@ -57,44 +59,67 @@ cdef class CDumper:
         """
         raise NotImplementedError()
 
-    def dump(self, obj: Any) -> bytearray:
+    def dump(self, obj):
         """Return the Postgres representation of *obj* as Python array of bytes"""
         cdef rv = PyByteArray_FromStringAndSize("", 0)
         cdef Py_ssize_t length = self.cdump(obj, rv, 0)
         PyByteArray_Resize(rv, length)
         return rv
 
-    def quote(self, obj: Any) -> bytearray:
+    def quote(self, obj):
         cdef char *ptr
         cdef char *ptr_out
-        cdef Py_ssize_t length, len_out
-        cdef int error
-        cdef bytearray rv
+        cdef Py_ssize_t length
 
-        pyout = self.dump(obj)
-        _buffer_as_string_and_size(pyout, &ptr, &length)
-        rv = PyByteArray_FromStringAndSize("", 0)
-        PyByteArray_Resize(rv, length * 2 + 3)  # Must include the quotes
-        ptr_out = PyByteArray_AS_STRING(rv)
+        value = self.dump(obj)
 
         if self._pgconn is not None:
-            if self._pgconn._pgconn_ptr == NULL:
-                raise e.OperationalError("the connection is closed")
-
-            len_out = libpq.PQescapeStringConn(
-                self._pgconn._pgconn_ptr, ptr_out + 1, ptr, length, &error
-            )
-            if error:
-                raise e.OperationalError(
-                    f"escape_string failed: {error_message(self._pgconn)}"
-                )
-        else:
-            len_out = libpq.PQescapeString(ptr_out + 1, ptr, length)
-
-        ptr_out[0] = b'\''
-        ptr_out[len_out + 1] = b'\''
-        PyByteArray_Resize(rv, len_out + 2)
+            esc = Escaping(self._pgconn)
+            # escaping and quoting
+            return esc.escape_literal(value)
 
+        # This path is taken when quote is asked without a connection,
+        # usually it means by psycopg.sql.quote() or by
+        # 'Composible.as_string(None)'. Most often than not this is done by
+        # someone generating a SQL file to consume elsewhere.
+
+        rv = PyByteArray_FromStringAndSize("", 0)
+
+        # No quoting, only quote escaping, random bs escaping. See further.
+        esc = Escaping()
+        out = esc.escape_string(value)
+
+        _buffer_as_string_and_size(out, &ptr, &length)
+
+        if not memchr(ptr, b'\\', length):
+            # If the string has no backslash, the result is correct and we
+            # don't need to bother with standard_conforming_strings.
+            PyByteArray_Resize(rv, length + 2)  # Must include the quotes
+            ptr_out = PyByteArray_AS_STRING(rv)
+            ptr_out[0] = b"'"
+            memcpy(ptr_out + 1, ptr, length)
+            ptr_out[length + 1] = b"'"
+            return rv
+
+        # The libpq has a crazy behaviour: PQescapeString uses the last
+        # standard_conforming_strings setting seen on a connection. This
+        # means that backslashes might be escaped or might not.
+        #
+        # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+        # if scs is off, '\\' raises a warning and '\' is an error.
+        #
+        # Check what the libpq does, and if it doesn't escape the backslash
+        # let's do it on our own. Never mind the race condition.
+        PyByteArray_Resize(rv, length + 4)  # Must include " E'...'" quotes
+        ptr_out = PyByteArray_AS_STRING(rv)
+        ptr_out[0] = b" "
+        ptr_out[1] = b"E"
+        ptr_out[2] = b"'"
+        memcpy(ptr_out + 3, ptr, length)
+        ptr_out[length + 3] = b"'"
+
+        if esc.escape_string(b"\\") == b"\\":
+            rv = bytes(rv).replace(b"\\", b"\\\\")
         return rv
 
     cpdef object get_key(self, object obj, object format):
index cead72b21b2ad7777283836e1cdca908db98b06a..cf30906ea1a4393854dea8329a8c5e6c7625aff4 100644 (file)
@@ -47,6 +47,12 @@ cdef class PGcancel:
 cdef class Escaping:
     cdef PGconn conn
 
+    cpdef escape_literal(self, data)
+    cpdef escape_identifier(self, data)
+    cpdef escape_string(self, data)
+    cpdef escape_bytea(self, data)
+    cpdef unescape_bytea(self, const unsigned char *data)
+
 
 cdef class PQBuffer:
     cdef unsigned char *buf
index 836412ba19c37d5f46567fff50d36b13552a37f9..eaa2d07946d4107708024f34e61c934483ef62e9 100644 (file)
@@ -14,7 +14,7 @@ cdef class Escaping:
     def __init__(self, PGconn conn = None):
         self.conn = conn
 
-    def escape_literal(self, data: "Buffer") -> memoryview:
+    cpdef escape_literal(self, data):
         cdef char *out
         cdef bytes rv
         cdef char *ptr
@@ -37,7 +37,7 @@ cdef class Escaping:
             PQBuffer._from_buffer(<unsigned char *>out, strlen(out))
         )
 
-    def escape_identifier(self, data: "Buffer") -> memoryview:
+    cpdef escape_identifier(self, data):
         cdef char *out
         cdef char *ptr
         cdef Py_ssize_t length
@@ -59,7 +59,7 @@ cdef class Escaping:
             PQBuffer._from_buffer(<unsigned char *>out, strlen(out))
         )
 
-    def escape_string(self, data: "Buffer") -> memoryview:
+    cpdef escape_string(self, data):
         cdef int error
         cdef size_t len_out
         cdef char *ptr
@@ -91,7 +91,7 @@ cdef class Escaping:
         PyByteArray_Resize(rv, len_out)
         return PyMemoryView_FromObject(rv)
 
-    def escape_bytea(self, data: "Buffer") -> memoryview:
+    cpdef escape_bytea(self, data):
         cdef size_t len_out
         cdef unsigned char *out
         cdef char *ptr
@@ -117,7 +117,7 @@ cdef class Escaping:
             PQBuffer._from_buffer(out, len_out - 1)  # out includes final 0
         )
 
-    def unescape_bytea(self, const unsigned char *data) -> memoryview:
+    cpdef unescape_bytea(self, const unsigned char *data):
         # not needed, but let's keep it symmetric with the escaping:
         # if a connection is passed in, it must be valid.
         if self.conn is not None:
index 84aa34458bcf46d96d5f80908845de806c4dcd74..27e062cd7d5881f96d84a6314a11d24b3241c580 100644 (file)
@@ -181,53 +181,58 @@ cdef class BytesDumper(CDumper):
         libpq.PQfreemem(out)
         return len_out
 
-    def quote(self, obj) -> bytearray:
-        cdef size_t len_esc
+    def quote(self, obj):
         cdef size_t len_out
         cdef unsigned char *out
         cdef char *ptr
         cdef Py_ssize_t length
-        cdef bytearray rv
+        cdef const char *scs
 
-        _buffer_as_string_and_size(obj, &ptr, &length)
-
-        if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
-            out = libpq.PQescapeByteaConn(
-                self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_esc)
-        else:
-            out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_esc)
+        escaped = self.dump(obj)
+        _buffer_as_string_and_size(escaped, &ptr, &length)
 
-        if out is NULL:
-            raise MemoryError(
-                f"couldn't allocate for escape_bytea of {length} bytes"
-            )
+        rv = PyByteArray_FromStringAndSize("", 0)
 
-        if not self._qplen:
-            self._qplen = 3
-            if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
+        # We cannot use the base quoting because escape_bytea already returns
+        # the quotes content. if scs is off it will escape the backslashes in
+        # the format, otherwise it won't, but it doesn't tell us what quotes to
+        # use.
+        if self._pgconn is not None:
+            if not self._qplen:
                 scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr,
                     b"standard_conforming_strings")
                 if scs and scs[0] == b'o' and scs[1] == b"n":  # == "on"
                     self._qplen = 1
+                else:
+                    self._qplen = 3
 
-        # len_esc includes the final 0
-        # So the final string is len_esc - 1 plus quotes, which might be 2 or 4 bytes
-        len_out = len_esc + self._qplen
-        rv = PyByteArray_FromStringAndSize("", 0)
-        cdef char *buf = CDumper.ensure_size(rv, 0, len_out)
-        memcpy(buf + self._qplen, out, len_esc)
-        libpq.PQfreemem(out)
-
-        if self._qplen == 3:
-            # Quote as " E'content'"
-            buf[0] = b' '
-            buf[1] = b'E'
-            buf[2] = b'\''
-        else:
-            # Quote as "'content'"
-            buf[0] = b'\''
-
-        buf[len_out - 1] = b'\''
+            PyByteArray_Resize(rv, length + self._qplen + 1)  # Include quotes
+            ptr_out = PyByteArray_AS_STRING(rv)
+            if self._qplen == 1:
+                ptr_out[0] = b"'"
+            else:
+                ptr_out[0] = b" "
+                ptr_out[1] = b"E"
+                ptr_out[2] = b"'"
+            memcpy(ptr_out + self._qplen, ptr, length)
+            ptr_out[length + self._qplen] = b"'"
+            return rv
+
+        # We don't have a connection, so someone is using us to generate a file
+        # to use off-line or something like that. PQescapeBytea, like its
+        # string counterpart, is not predictable whether it will escape
+        # backslashes.
+        PyByteArray_Resize(rv, length + 4)  # Include quotes
+        ptr_out = PyByteArray_AS_STRING(rv)
+        ptr_out[0] = b" "
+        ptr_out[1] = b"E"
+        ptr_out[2] = b"'"
+        memcpy(ptr_out + 3, ptr, length)
+        ptr_out[length + 3] = b"'"
+
+        esc = Escaping()
+        if esc.escape_bytea(b"\x00") == b"\\000":
+            rv = bytes(rv).replace(b"\\", b"\\\\")
 
         return rv
 
index 2232b336ae44a901fd3eef1c0b3a1164388d08c7..ee8411520d712c0081aa6319a88ff6629131b110 100644 (file)
@@ -7,18 +7,70 @@ import datetime as dt
 
 import pytest
 
-from psycopg import sql, ProgrammingError
+from psycopg import pq, sql, ProgrammingError
 from psycopg.adapt import PyFormat as Format
 
 
 @pytest.mark.parametrize(
     "obj, quoted",
-    [("hello", "'hello'"), (42, "42"), (True, "true"), (None, "NULL")],
+    [
+        ("foo\\bar", " E'foo\\\\bar'"),
+        ("hello", "'hello'"),
+        (42, "42"),
+        (True, "true"),
+        (None, "NULL"),
+    ],
 )
 def test_quote(obj, quoted):
     assert sql.quote(obj) == quoted
 
 
+@pytest.mark.parametrize("scs", ["on", "off"])
+def test_quote_roundtrip(conn, scs):
+    messages = []
+    conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+    conn.execute(f"set standard_conforming_strings to {scs}")
+
+    for i in range(1, 256):
+        want = chr(i)
+        quoted = sql.quote(want)
+        got = conn.execute(f"select {quoted}::text").fetchone()[0]
+        assert want == got
+
+        # No "nonstandard use of \\ in a string literal" warning
+        assert not messages, f"error with {want!r}"
+
+
+def test_quote_stable_despite_deranged_libpq(conn):
+    # Verify the libpq behaviour of PQescapeString using the last setting seen.
+    # Check that we are not affected by it.
+    good_str = " E'\\\\'"
+    good_bytes = " E'\\\\000'"
+    conn.execute("set standard_conforming_strings to on")
+    assert pq.Escaping().escape_string(b"\\") == b"\\"
+    assert sql.quote("\\") == good_str
+    assert pq.Escaping().escape_bytea(b"\x00") == b"\\000"
+    assert sql.quote(b"\x00") == good_bytes
+
+    conn.execute("set standard_conforming_strings to off")
+    assert pq.Escaping().escape_string(b"\\") == b"\\\\"
+    assert sql.quote("\\") == good_str
+    assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000"
+    assert sql.quote(b"\x00") == good_bytes
+
+    # Verify that the good values are actually good
+    messages = []
+    conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+    conn.execute("set escape_string_warning to on")
+    for scs in ("on", "off"):
+        conn.execute(f"set standard_conforming_strings to {scs}")
+        cur = conn.execute(f"select {good_str}, {good_bytes}::bytea")
+        assert cur.fetchone() == ("\\", b"\x00")
+
+    # No "nonstandard use of \\ in a string literal" warning
+    assert not messages
+
+
 class TestSqlFormat:
     def test_pos(self, conn):
         s = sql.SQL("select {} from {}").format(
@@ -188,7 +240,7 @@ class TestIdentifier:
 
     def test_init(self):
         assert isinstance(sql.Identifier("foo"), sql.Identifier)
-        assert isinstance(sql.Identifier(u"foo"), sql.Identifier)
+        assert isinstance(sql.Identifier("foo"), sql.Identifier)
         assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier)
         with pytest.raises(TypeError):
             sql.Identifier()
@@ -230,7 +282,7 @@ class TestLiteral:
 
     def test_init(self):
         assert isinstance(sql.Literal("foo"), sql.Literal)
-        assert isinstance(sql.Literal(u"foo"), sql.Literal)
+        assert isinstance(sql.Literal("foo"), sql.Literal)
         assert isinstance(sql.Literal(b"foo"), sql.Literal)
         assert isinstance(sql.Literal(42), sql.Literal)
         assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal)
@@ -267,7 +319,7 @@ class TestSQL:
 
     def test_init(self):
         assert isinstance(sql.SQL("foo"), sql.SQL)
-        assert isinstance(sql.SQL(u"foo"), sql.SQL)
+        assert isinstance(sql.SQL("foo"), sql.SQL)
         with pytest.raises(TypeError):
             sql.SQL(10)
         with pytest.raises(TypeError):
index b3cdbfd3d083c4cd5869a47b07c53be42548298f..4c8c1444949a6af06324fe3463b74cc77a10b733 100644 (file)
@@ -18,19 +18,28 @@ eur = "\u20ac"
 def test_dump_1char(conn, fmt_in):
     cur = conn.cursor()
     for i in range(1, 256):
-        cur.execute(f"select %{fmt_in} = chr(%s::int)", (chr(i), i))
+        cur.execute(f"select %{fmt_in} = chr(%s)", (chr(i), i))
         assert cur.fetchone()[0] is True, chr(i)
 
 
-def test_quote_1char(conn):
+@pytest.mark.parametrize("scs", ["on", "off"])
+def test_quote_1char(conn, scs):
+    messages = []
+    conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+    conn.execute(f"set standard_conforming_strings to {scs}")
+    conn.execute("set escape_string_warning to on")
+
     cur = conn.cursor()
-    query = sql.SQL("select {ch} = chr(%s::int)")
+    query = sql.SQL("select {ch} = chr(%s)")
     for i in range(1, 256):
         if chr(i) == "%":
             continue
         cur.execute(query.format(ch=sql.Literal(chr(i))), (i,))
         assert cur.fetchone()[0] is True, chr(i)
 
+    # No "nonstandard use of \\ in a string literal" warning
+    assert not messages
+
 
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
 def test_dump_zero(conn, fmt_in):
@@ -57,7 +66,7 @@ def test_quote_percent(conn):
     assert cur.fetchone()[0] == "%"
 
     cur.execute(
-        sql.SQL("select {ch} = chr(%s::int)").format(ch=sql.Literal("%")),
+        sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")),
         (ord("%"),),
     )
     assert cur.fetchone()[0] is True
@@ -68,7 +77,7 @@ def test_quote_percent(conn):
 def test_load_1char(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out)
     for i in range(1, 256):
-        cur.execute(f"select chr(%s::int)::{typename}", (i,))
+        cur.execute(f"select chr(%s)::{typename}", (i,))
         res = cur.fetchone()[0]
         assert res == chr(i)
 
@@ -126,12 +135,10 @@ def test_load_enc(conn, typename, encoding, fmt_out):
     cur = conn.cursor(binary=fmt_out)
 
     conn.client_encoding = encoding
-    (res,) = cur.execute(
-        f"select chr(%s::int)::{typename}", (ord(eur),)
-    ).fetchone()
+    (res,) = cur.execute(f"select chr(%s)::{typename}", (ord(eur),)).fetchone()
     assert res == eur
 
-    stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+    stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
         ord(eur), sql.SQL(fmt_out.name)
     )
     with cur.copy(stmt) as copy:
@@ -149,9 +156,9 @@ def test_load_badenc(conn, typename, fmt_out):
 
     conn.client_encoding = "latin1"
     with pytest.raises(psycopg.DataError):
-        cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+        cur.execute(f"select chr(%s)::{typename}", (ord(eur),))
 
-    stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+    stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
         ord(eur), sql.SQL(fmt_out.name)
     )
     with cur.copy(stmt) as copy:
@@ -166,10 +173,10 @@ def test_load_ascii(conn, typename, fmt_out):
     cur = conn.cursor(binary=fmt_out)
 
     conn.client_encoding = "ascii"
-    cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+    cur.execute(f"select chr(%s)::{typename}", (ord(eur),))
     assert cur.fetchone()[0] == eur.encode("utf8")
 
-    stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+    stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
         ord(eur), sql.SQL(fmt_out.name)
     )
     with cur.copy(stmt) as copy:
@@ -221,6 +228,7 @@ def test_quote_1byte(conn, scs):
     messages = []
     conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
     conn.execute(f"set standard_conforming_strings to {scs}")
+    conn.execute("set escape_string_warning to on")
 
     cur = conn.cursor()
     query = sql.SQL("select {ch} = set_byte('x', 0, %s)")