From 3d175f9c957dd3d49d65b8de7daa4d7117cf68f6 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 26 Jul 2021 02:19:46 +0200 Subject: [PATCH] Emit quoted values valid for every standard_conforming_string value Work around the unpredictable return value from the libpq. --- psycopg/psycopg/adapt.py | 35 +++++++++++- psycopg/psycopg/types/string.py | 24 ++++++-- psycopg_c/psycopg_c/_psycopg/adapt.pyx | 79 +++++++++++++++++--------- psycopg_c/psycopg_c/pq.pxd | 6 ++ psycopg_c/psycopg_c/pq/escaping.pyx | 10 ++-- psycopg_c/psycopg_c/types/string.pyx | 75 ++++++++++++------------ tests/test_sql.py | 62 ++++++++++++++++++-- tests/types/test_string.py | 34 ++++++----- 8 files changed, 231 insertions(+), 94 deletions(-) diff --git a/psycopg/psycopg/adapt.py b/psycopg/psycopg/adapt.py index 29c59f396..b3d655ca8 100644 --- a/psycopg/psycopg/adapt.py +++ b/psycopg/psycopg/adapt.py @@ -18,6 +18,8 @@ if TYPE_CHECKING: AdaptersMap = _adapters_map.AdaptersMap Buffer = abc.Buffer +ORD_BS = ord("\\") + class Dumper(abc.Dumper, ABC): """ @@ -60,10 +62,37 @@ class Dumper(abc.Dumper, ABC): if self.connection: esc = pq.Escaping(self.connection.pgconn) + # escaping and quoting return esc.escape_literal(value) - else: - esc = pq.Escaping() - return b"'%s'" % esc.escape_string(value) + + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + # No quoting, only quote escaping, random bs escaping. See further. + esc = pq.Escaping() + out = esc.escape_string(value) + + # b"\\" in memoryview doesn't work so search for the ascii value + if ORD_BS not in out: + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + return b"'" + out + b"'" + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + rv: bytes = b" E'" + out + b"'" + if esc.escape_string(b"\\") == b"\\": + rv = rv.replace(b"\\", b"\\\\") + return rv def get_key( self, obj: Any, format: PyFormat diff --git a/psycopg/psycopg/types/string.py b/psycopg/psycopg/types/string.py index d1702f334..3833f17e2 100644 --- a/psycopg/psycopg/types/string.py +++ b/psycopg/psycopg/types/string.py @@ -99,17 +99,29 @@ class BytesDumper(Dumper): return self._esc.escape_bytea(obj) def quote(self, obj: bytes) -> bytes: - if not self._qprefix: - if self.connection: + escaped = self.dump(obj) + + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self.connection: + if not self._qprefix: scs = self.connection.pgconn.parameter_status( b"standard_conforming_strings" ) self._qprefix = b"'" if scs == b"on" else b" E'" - else: - self._qprefix = b" E'" - escaped = self.dump(obj) - return self._qprefix + bytes(escaped) + b"'" + return self._qprefix + escaped + b"'" + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. + rv: bytes = b" E'" + escaped + b"'" + if self._esc.escape_bytea(b"\x00") == b"\\000": + rv = rv.replace(b"\\", b"\\\\") + return rv class BytesBinaryDumper(Dumper): diff --git a/psycopg_c/psycopg_c/_psycopg/adapt.pyx b/psycopg_c/psycopg_c/_psycopg/adapt.pyx index ce640f8e2..51d75ce69 100644 --- a/psycopg_c/psycopg_c/_psycopg/adapt.pyx +++ b/psycopg_c/psycopg_c/_psycopg/adapt.pyx @@ -16,10 +16,12 @@ equivalent C implementations. from typing import Any cimport cython + +from libc.string cimport memcpy, memchr from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING -from psycopg_c.pq cimport _buffer_as_string_and_size +from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping from psycopg import errors as e from psycopg.pq.misc import error_message @@ -57,44 +59,67 @@ cdef class CDumper: """ raise NotImplementedError() - def dump(self, obj: Any) -> bytearray: + def dump(self, obj): """Return the Postgres representation of *obj* as Python array of bytes""" cdef rv = PyByteArray_FromStringAndSize("", 0) cdef Py_ssize_t length = self.cdump(obj, rv, 0) PyByteArray_Resize(rv, length) return rv - def quote(self, obj: Any) -> bytearray: + def quote(self, obj): cdef char *ptr cdef char *ptr_out - cdef Py_ssize_t length, len_out - cdef int error - cdef bytearray rv + cdef Py_ssize_t length - pyout = self.dump(obj) - _buffer_as_string_and_size(pyout, &ptr, &length) - rv = PyByteArray_FromStringAndSize("", 0) - PyByteArray_Resize(rv, length * 2 + 3) # Must include the quotes - ptr_out = PyByteArray_AS_STRING(rv) + value = self.dump(obj) if self._pgconn is not None: - if self._pgconn._pgconn_ptr == NULL: - raise e.OperationalError("the connection is closed") - - len_out = libpq.PQescapeStringConn( - self._pgconn._pgconn_ptr, ptr_out + 1, ptr, length, &error - ) - if error: - raise e.OperationalError( - f"escape_string failed: {error_message(self._pgconn)}" - ) - else: - len_out = libpq.PQescapeString(ptr_out + 1, ptr, length) - - ptr_out[0] = b'\'' - ptr_out[len_out + 1] = b'\'' - PyByteArray_Resize(rv, len_out + 2) + esc = Escaping(self._pgconn) + # escaping and quoting + return esc.escape_literal(value) + # This path is taken when quote is asked without a connection, + # usually it means by psycopg.sql.quote() or by + # 'Composible.as_string(None)'. Most often than not this is done by + # someone generating a SQL file to consume elsewhere. + + rv = PyByteArray_FromStringAndSize("", 0) + + # No quoting, only quote escaping, random bs escaping. See further. + esc = Escaping() + out = esc.escape_string(value) + + _buffer_as_string_and_size(out, &ptr, &length) + + if not memchr(ptr, b'\\', length): + # If the string has no backslash, the result is correct and we + # don't need to bother with standard_conforming_strings. + PyByteArray_Resize(rv, length + 2) # Must include the quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b"'" + memcpy(ptr_out + 1, ptr, length) + ptr_out[length + 1] = b"'" + return rv + + # The libpq has a crazy behaviour: PQescapeString uses the last + # standard_conforming_strings setting seen on a connection. This + # means that backslashes might be escaped or might not. + # + # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH, + # if scs is off, '\\' raises a warning and '\' is an error. + # + # Check what the libpq does, and if it doesn't escape the backslash + # let's do it on our own. Never mind the race condition. + PyByteArray_Resize(rv, length + 4) # Must include " E'...'" quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + 3, ptr, length) + ptr_out[length + 3] = b"'" + + if esc.escape_string(b"\\") == b"\\": + rv = bytes(rv).replace(b"\\", b"\\\\") return rv cpdef object get_key(self, object obj, object format): diff --git a/psycopg_c/psycopg_c/pq.pxd b/psycopg_c/psycopg_c/pq.pxd index cead72b21..cf30906ea 100644 --- a/psycopg_c/psycopg_c/pq.pxd +++ b/psycopg_c/psycopg_c/pq.pxd @@ -47,6 +47,12 @@ cdef class PGcancel: cdef class Escaping: cdef PGconn conn + cpdef escape_literal(self, data) + cpdef escape_identifier(self, data) + cpdef escape_string(self, data) + cpdef escape_bytea(self, data) + cpdef unescape_bytea(self, const unsigned char *data) + cdef class PQBuffer: cdef unsigned char *buf diff --git a/psycopg_c/psycopg_c/pq/escaping.pyx b/psycopg_c/psycopg_c/pq/escaping.pyx index 836412ba1..eaa2d0794 100644 --- a/psycopg_c/psycopg_c/pq/escaping.pyx +++ b/psycopg_c/psycopg_c/pq/escaping.pyx @@ -14,7 +14,7 @@ cdef class Escaping: def __init__(self, PGconn conn = None): self.conn = conn - def escape_literal(self, data: "Buffer") -> memoryview: + cpdef escape_literal(self, data): cdef char *out cdef bytes rv cdef char *ptr @@ -37,7 +37,7 @@ cdef class Escaping: PQBuffer._from_buffer(out, strlen(out)) ) - def escape_identifier(self, data: "Buffer") -> memoryview: + cpdef escape_identifier(self, data): cdef char *out cdef char *ptr cdef Py_ssize_t length @@ -59,7 +59,7 @@ cdef class Escaping: PQBuffer._from_buffer(out, strlen(out)) ) - def escape_string(self, data: "Buffer") -> memoryview: + cpdef escape_string(self, data): cdef int error cdef size_t len_out cdef char *ptr @@ -91,7 +91,7 @@ cdef class Escaping: PyByteArray_Resize(rv, len_out) return PyMemoryView_FromObject(rv) - def escape_bytea(self, data: "Buffer") -> memoryview: + cpdef escape_bytea(self, data): cdef size_t len_out cdef unsigned char *out cdef char *ptr @@ -117,7 +117,7 @@ cdef class Escaping: PQBuffer._from_buffer(out, len_out - 1) # out includes final 0 ) - def unescape_bytea(self, const unsigned char *data) -> memoryview: + cpdef unescape_bytea(self, const unsigned char *data): # not needed, but let's keep it symmetric with the escaping: # if a connection is passed in, it must be valid. if self.conn is not None: diff --git a/psycopg_c/psycopg_c/types/string.pyx b/psycopg_c/psycopg_c/types/string.pyx index 84aa34458..27e062cd7 100644 --- a/psycopg_c/psycopg_c/types/string.pyx +++ b/psycopg_c/psycopg_c/types/string.pyx @@ -181,53 +181,58 @@ cdef class BytesDumper(CDumper): libpq.PQfreemem(out) return len_out - def quote(self, obj) -> bytearray: - cdef size_t len_esc + def quote(self, obj): cdef size_t len_out cdef unsigned char *out cdef char *ptr cdef Py_ssize_t length - cdef bytearray rv + cdef const char *scs - _buffer_as_string_and_size(obj, &ptr, &length) - - if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: - out = libpq.PQescapeByteaConn( - self._pgconn._pgconn_ptr, ptr, length, &len_esc) - else: - out = libpq.PQescapeBytea(ptr, length, &len_esc) + escaped = self.dump(obj) + _buffer_as_string_and_size(escaped, &ptr, &length) - if out is NULL: - raise MemoryError( - f"couldn't allocate for escape_bytea of {length} bytes" - ) + rv = PyByteArray_FromStringAndSize("", 0) - if not self._qplen: - self._qplen = 3 - if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL: + # We cannot use the base quoting because escape_bytea already returns + # the quotes content. if scs is off it will escape the backslashes in + # the format, otherwise it won't, but it doesn't tell us what quotes to + # use. + if self._pgconn is not None: + if not self._qplen: scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr, b"standard_conforming_strings") if scs and scs[0] == b'o' and scs[1] == b"n": # == "on" self._qplen = 1 + else: + self._qplen = 3 - # len_esc includes the final 0 - # So the final string is len_esc - 1 plus quotes, which might be 2 or 4 bytes - len_out = len_esc + self._qplen - rv = PyByteArray_FromStringAndSize("", 0) - cdef char *buf = CDumper.ensure_size(rv, 0, len_out) - memcpy(buf + self._qplen, out, len_esc) - libpq.PQfreemem(out) - - if self._qplen == 3: - # Quote as " E'content'" - buf[0] = b' ' - buf[1] = b'E' - buf[2] = b'\'' - else: - # Quote as "'content'" - buf[0] = b'\'' - - buf[len_out - 1] = b'\'' + PyByteArray_Resize(rv, length + self._qplen + 1) # Include quotes + ptr_out = PyByteArray_AS_STRING(rv) + if self._qplen == 1: + ptr_out[0] = b"'" + else: + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + self._qplen, ptr, length) + ptr_out[length + self._qplen] = b"'" + return rv + + # We don't have a connection, so someone is using us to generate a file + # to use off-line or something like that. PQescapeBytea, like its + # string counterpart, is not predictable whether it will escape + # backslashes. + PyByteArray_Resize(rv, length + 4) # Include quotes + ptr_out = PyByteArray_AS_STRING(rv) + ptr_out[0] = b" " + ptr_out[1] = b"E" + ptr_out[2] = b"'" + memcpy(ptr_out + 3, ptr, length) + ptr_out[length + 3] = b"'" + + esc = Escaping() + if esc.escape_bytea(b"\x00") == b"\\000": + rv = bytes(rv).replace(b"\\", b"\\\\") return rv diff --git a/tests/test_sql.py b/tests/test_sql.py index 2232b336a..ee8411520 100644 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -7,18 +7,70 @@ import datetime as dt import pytest -from psycopg import sql, ProgrammingError +from psycopg import pq, sql, ProgrammingError from psycopg.adapt import PyFormat as Format @pytest.mark.parametrize( "obj, quoted", - [("hello", "'hello'"), (42, "42"), (True, "true"), (None, "NULL")], + [ + ("foo\\bar", " E'foo\\\\bar'"), + ("hello", "'hello'"), + (42, "42"), + (True, "true"), + (None, "NULL"), + ], ) def test_quote(obj, quoted): assert sql.quote(obj) == quoted +@pytest.mark.parametrize("scs", ["on", "off"]) +def test_quote_roundtrip(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + + for i in range(1, 256): + want = chr(i) + quoted = sql.quote(want) + got = conn.execute(f"select {quoted}::text").fetchone()[0] + assert want == got + + # No "nonstandard use of \\ in a string literal" warning + assert not messages, f"error with {want!r}" + + +def test_quote_stable_despite_deranged_libpq(conn): + # Verify the libpq behaviour of PQescapeString using the last setting seen. + # Check that we are not affected by it. + good_str = " E'\\\\'" + good_bytes = " E'\\\\000'" + conn.execute("set standard_conforming_strings to on") + assert pq.Escaping().escape_string(b"\\") == b"\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\000" + assert sql.quote(b"\x00") == good_bytes + + conn.execute("set standard_conforming_strings to off") + assert pq.Escaping().escape_string(b"\\") == b"\\\\" + assert sql.quote("\\") == good_str + assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000" + assert sql.quote(b"\x00") == good_bytes + + # Verify that the good values are actually good + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute("set escape_string_warning to on") + for scs in ("on", "off"): + conn.execute(f"set standard_conforming_strings to {scs}") + cur = conn.execute(f"select {good_str}, {good_bytes}::bytea") + assert cur.fetchone() == ("\\", b"\x00") + + # No "nonstandard use of \\ in a string literal" warning + assert not messages + + class TestSqlFormat: def test_pos(self, conn): s = sql.SQL("select {} from {}").format( @@ -188,7 +240,7 @@ class TestIdentifier: def test_init(self): assert isinstance(sql.Identifier("foo"), sql.Identifier) - assert isinstance(sql.Identifier(u"foo"), sql.Identifier) + assert isinstance(sql.Identifier("foo"), sql.Identifier) assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier) with pytest.raises(TypeError): sql.Identifier() @@ -230,7 +282,7 @@ class TestLiteral: def test_init(self): assert isinstance(sql.Literal("foo"), sql.Literal) - assert isinstance(sql.Literal(u"foo"), sql.Literal) + assert isinstance(sql.Literal("foo"), sql.Literal) assert isinstance(sql.Literal(b"foo"), sql.Literal) assert isinstance(sql.Literal(42), sql.Literal) assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal) @@ -267,7 +319,7 @@ class TestSQL: def test_init(self): assert isinstance(sql.SQL("foo"), sql.SQL) - assert isinstance(sql.SQL(u"foo"), sql.SQL) + assert isinstance(sql.SQL("foo"), sql.SQL) with pytest.raises(TypeError): sql.SQL(10) with pytest.raises(TypeError): diff --git a/tests/types/test_string.py b/tests/types/test_string.py index b3cdbfd3d..4c8c14449 100644 --- a/tests/types/test_string.py +++ b/tests/types/test_string.py @@ -18,19 +18,28 @@ eur = "\u20ac" def test_dump_1char(conn, fmt_in): cur = conn.cursor() for i in range(1, 256): - cur.execute(f"select %{fmt_in} = chr(%s::int)", (chr(i), i)) + cur.execute(f"select %{fmt_in} = chr(%s)", (chr(i), i)) assert cur.fetchone()[0] is True, chr(i) -def test_quote_1char(conn): +@pytest.mark.parametrize("scs", ["on", "off"]) +def test_quote_1char(conn, scs): + messages = [] + conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) + conn.execute(f"set standard_conforming_strings to {scs}") + conn.execute("set escape_string_warning to on") + cur = conn.cursor() - query = sql.SQL("select {ch} = chr(%s::int)") + query = sql.SQL("select {ch} = chr(%s)") for i in range(1, 256): if chr(i) == "%": continue cur.execute(query.format(ch=sql.Literal(chr(i))), (i,)) assert cur.fetchone()[0] is True, chr(i) + # No "nonstandard use of \\ in a string literal" warning + assert not messages + @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) def test_dump_zero(conn, fmt_in): @@ -57,7 +66,7 @@ def test_quote_percent(conn): assert cur.fetchone()[0] == "%" cur.execute( - sql.SQL("select {ch} = chr(%s::int)").format(ch=sql.Literal("%")), + sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")), (ord("%"),), ) assert cur.fetchone()[0] is True @@ -68,7 +77,7 @@ def test_quote_percent(conn): def test_load_1char(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out) for i in range(1, 256): - cur.execute(f"select chr(%s::int)::{typename}", (i,)) + cur.execute(f"select chr(%s)::{typename}", (i,)) res = cur.fetchone()[0] assert res == chr(i) @@ -126,12 +135,10 @@ def test_load_enc(conn, typename, encoding, fmt_out): cur = conn.cursor(binary=fmt_out) conn.client_encoding = encoding - (res,) = cur.execute( - f"select chr(%s::int)::{typename}", (ord(eur),) - ).fetchone() + (res,) = cur.execute(f"select chr(%s)::{typename}", (ord(eur),)).fetchone() assert res == eur - stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format( + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( ord(eur), sql.SQL(fmt_out.name) ) with cur.copy(stmt) as copy: @@ -149,9 +156,9 @@ def test_load_badenc(conn, typename, fmt_out): conn.client_encoding = "latin1" with pytest.raises(psycopg.DataError): - cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) + cur.execute(f"select chr(%s)::{typename}", (ord(eur),)) - stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format( + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( ord(eur), sql.SQL(fmt_out.name) ) with cur.copy(stmt) as copy: @@ -166,10 +173,10 @@ def test_load_ascii(conn, typename, fmt_out): cur = conn.cursor(binary=fmt_out) conn.client_encoding = "ascii" - cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),)) + cur.execute(f"select chr(%s)::{typename}", (ord(eur),)) assert cur.fetchone()[0] == eur.encode("utf8") - stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format( + stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format( ord(eur), sql.SQL(fmt_out.name) ) with cur.copy(stmt) as copy: @@ -221,6 +228,7 @@ def test_quote_1byte(conn, scs): messages = [] conn.add_notice_handler(lambda msg: messages.append(msg.message_primary)) conn.execute(f"set standard_conforming_strings to {scs}") + conn.execute("set escape_string_warning to on") cur = conn.cursor() query = sql.SQL("select {ch} = set_byte('x', 0, %s)") -- 2.47.3