Work around the unpredictable return value from the libpq.
AdaptersMap = _adapters_map.AdaptersMap
Buffer = abc.Buffer
+ORD_BS = ord("\\")
+
class Dumper(abc.Dumper, ABC):
"""
if self.connection:
esc = pq.Escaping(self.connection.pgconn)
+ # escaping and quoting
return esc.escape_literal(value)
- else:
- esc = pq.Escaping()
- return b"'%s'" % esc.escape_string(value)
+
+ # This path is taken when quoting is requested without a connection,
+ # which usually means via psycopg.sql.quote() or
+ # 'Composable.as_string(None)'. More often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ # No quoting here, only quote escaping; backslashes may or may not be
+ # escaped (see below).
+ esc = pq.Escaping()
+ out = esc.escape_string(value)
+
+ # b"\\" in memoryview doesn't work so search for the ascii value
+ if ORD_BS not in out:
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ return b"'" + out + b"'"
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ rv: bytes = b" E'" + out + b"'"
+ if esc.escape_string(b"\\") == b"\\":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
def get_key(
self, obj: Any, format: PyFormat
return self._esc.escape_bytea(obj)
def quote(self, obj: bytes) -> bytes:
- if not self._qprefix:
- if self.connection:
+ escaped = self.dump(obj)
+
+ # We cannot use the base quoting because escape_bytea already returns
+ # the quotes content. If scs is off it will escape the backslashes in
+ # the format, otherwise it won't, but it doesn't tell us what quotes to
+ # use.
+ if self.connection:
+ if not self._qprefix:
scs = self.connection.pgconn.parameter_status(
b"standard_conforming_strings"
)
self._qprefix = b"'" if scs == b"on" else b" E'"
- else:
- self._qprefix = b" E'"
- escaped = self.dump(obj)
- return self._qprefix + bytes(escaped) + b"'"
+ return self._qprefix + escaped + b"'"
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. Like its string counterpart,
+ # PQescapeBytea is unpredictable as to whether it will escape
+ # backslashes.
+ rv: bytes = b" E'" + escaped + b"'"
+ if self._esc.escape_bytea(b"\x00") == b"\\000":
+ rv = rv.replace(b"\\", b"\\\\")
+ return rv
class BytesBinaryDumper(Dumper):
from typing import Any
cimport cython
+
+from libc.string cimport memcpy, memchr
from cpython.bytearray cimport PyByteArray_FromStringAndSize, PyByteArray_Resize
from cpython.bytearray cimport PyByteArray_GET_SIZE, PyByteArray_AS_STRING
-from psycopg_c.pq cimport _buffer_as_string_and_size
+from psycopg_c.pq cimport _buffer_as_string_and_size, Escaping
from psycopg import errors as e
from psycopg.pq.misc import error_message
"""
raise NotImplementedError()
- def dump(self, obj: Any) -> bytearray:
+ def dump(self, obj):
"""Return the Postgres representation of *obj* as Python array of bytes"""
cdef rv = PyByteArray_FromStringAndSize("", 0)
cdef Py_ssize_t length = self.cdump(obj, rv, 0)
PyByteArray_Resize(rv, length)
return rv
- def quote(self, obj: Any) -> bytearray:
+ def quote(self, obj):
cdef char *ptr
cdef char *ptr_out
- cdef Py_ssize_t length, len_out
- cdef int error
- cdef bytearray rv
+ cdef Py_ssize_t length
- pyout = self.dump(obj)
- _buffer_as_string_and_size(pyout, &ptr, &length)
- rv = PyByteArray_FromStringAndSize("", 0)
- PyByteArray_Resize(rv, length * 2 + 3) # Must include the quotes
- ptr_out = PyByteArray_AS_STRING(rv)
+ value = self.dump(obj)
if self._pgconn is not None:
- if self._pgconn._pgconn_ptr == NULL:
- raise e.OperationalError("the connection is closed")
-
- len_out = libpq.PQescapeStringConn(
- self._pgconn._pgconn_ptr, ptr_out + 1, ptr, length, &error
- )
- if error:
- raise e.OperationalError(
- f"escape_string failed: {error_message(self._pgconn)}"
- )
- else:
- len_out = libpq.PQescapeString(ptr_out + 1, ptr, length)
-
- ptr_out[0] = b'\''
- ptr_out[len_out + 1] = b'\''
- PyByteArray_Resize(rv, len_out + 2)
+ esc = Escaping(self._pgconn)
+ # escaping and quoting
+ return esc.escape_literal(value)
+ # This path is taken when quoting is requested without a connection,
+ # which usually means via psycopg.sql.quote() or
+ # 'Composable.as_string(None)'. More often than not this is done by
+ # someone generating a SQL file to consume elsewhere.
+
+ rv = PyByteArray_FromStringAndSize("", 0)
+
+ # No quoting here, only quote escaping; backslashes may or may not be
+ # escaped (see below).
+ esc = Escaping()
+ out = esc.escape_string(value)
+
+ _buffer_as_string_and_size(out, &ptr, &length)
+
+ if not memchr(ptr, b'\\', length):
+ # If the string has no backslash, the result is correct and we
+ # don't need to bother with standard_conforming_strings.
+ PyByteArray_Resize(rv, length + 2) # Must include the quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b"'"
+ memcpy(ptr_out + 1, ptr, length)
+ ptr_out[length + 1] = b"'"
+ return rv
+
+ # The libpq has a crazy behaviour: PQescapeString uses the last
+ # standard_conforming_strings setting seen on a connection. This
+ # means that backslashes might be escaped or might not.
+ #
+ # A syntax E'\\' works everywhere, whereas E'\' is an error. OTOH,
+ # if scs is off, '\\' raises a warning and '\' is an error.
+ #
+ # Check what the libpq does, and if it doesn't escape the backslash
+ # let's do it on our own. Never mind the race condition.
+ PyByteArray_Resize(rv, length + 4) # Must include " E'...'" quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + 3, ptr, length)
+ ptr_out[length + 3] = b"'"
+
+ if esc.escape_string(b"\\") == b"\\":
+ rv = bytes(rv).replace(b"\\", b"\\\\")
return rv
cpdef object get_key(self, object obj, object format):
cdef class Escaping:
cdef PGconn conn
+ cpdef escape_literal(self, data)
+ cpdef escape_identifier(self, data)
+ cpdef escape_string(self, data)
+ cpdef escape_bytea(self, data)
+ cpdef unescape_bytea(self, const unsigned char *data)
+
cdef class PQBuffer:
cdef unsigned char *buf
def __init__(self, PGconn conn = None):
self.conn = conn
- def escape_literal(self, data: "Buffer") -> memoryview:
+ cpdef escape_literal(self, data):
cdef char *out
cdef bytes rv
cdef char *ptr
PQBuffer._from_buffer(<unsigned char *>out, strlen(out))
)
- def escape_identifier(self, data: "Buffer") -> memoryview:
+ cpdef escape_identifier(self, data):
cdef char *out
cdef char *ptr
cdef Py_ssize_t length
PQBuffer._from_buffer(<unsigned char *>out, strlen(out))
)
- def escape_string(self, data: "Buffer") -> memoryview:
+ cpdef escape_string(self, data):
cdef int error
cdef size_t len_out
cdef char *ptr
PyByteArray_Resize(rv, len_out)
return PyMemoryView_FromObject(rv)
- def escape_bytea(self, data: "Buffer") -> memoryview:
+ cpdef escape_bytea(self, data):
cdef size_t len_out
cdef unsigned char *out
cdef char *ptr
PQBuffer._from_buffer(out, len_out - 1) # out includes final 0
)
- def unescape_bytea(self, const unsigned char *data) -> memoryview:
+ cpdef unescape_bytea(self, const unsigned char *data):
# not needed, but let's keep it symmetric with the escaping:
# if a connection is passed in, it must be valid.
if self.conn is not None:
libpq.PQfreemem(out)
return len_out
- def quote(self, obj) -> bytearray:
- cdef size_t len_esc
+ def quote(self, obj):
cdef size_t len_out
cdef unsigned char *out
cdef char *ptr
cdef Py_ssize_t length
- cdef bytearray rv
+ cdef const char *scs
- _buffer_as_string_and_size(obj, &ptr, &length)
-
- if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
- out = libpq.PQescapeByteaConn(
- self._pgconn._pgconn_ptr, <unsigned char *>ptr, length, &len_esc)
- else:
- out = libpq.PQescapeBytea(<unsigned char *>ptr, length, &len_esc)
+ escaped = self.dump(obj)
+ _buffer_as_string_and_size(escaped, &ptr, &length)
- if out is NULL:
- raise MemoryError(
- f"couldn't allocate for escape_bytea of {length} bytes"
- )
+ rv = PyByteArray_FromStringAndSize("", 0)
- if not self._qplen:
- self._qplen = 3
- if self._pgconn is not None and self._pgconn._pgconn_ptr != NULL:
+ # We cannot use the base quoting because escape_bytea already returns
+ # the quotes content. If scs is off it will escape the backslashes in
+ # the format, otherwise it won't, but it doesn't tell us what quotes to
+ # use.
+ if self._pgconn is not None:
+ if not self._qplen:
scs = libpq.PQparameterStatus(self._pgconn._pgconn_ptr,
b"standard_conforming_strings")
if scs and scs[0] == b'o' and scs[1] == b"n": # == "on"
self._qplen = 1
+ else:
+ self._qplen = 3
- # len_esc includes the final 0
- # So the final string is len_esc - 1 plus quotes, which might be 2 or 4 bytes
- len_out = len_esc + self._qplen
- rv = PyByteArray_FromStringAndSize("", 0)
- cdef char *buf = CDumper.ensure_size(rv, 0, len_out)
- memcpy(buf + self._qplen, out, len_esc)
- libpq.PQfreemem(out)
-
- if self._qplen == 3:
- # Quote as " E'content'"
- buf[0] = b' '
- buf[1] = b'E'
- buf[2] = b'\''
- else:
- # Quote as "'content'"
- buf[0] = b'\''
-
- buf[len_out - 1] = b'\''
+ PyByteArray_Resize(rv, length + self._qplen + 1) # Include quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ if self._qplen == 1:
+ ptr_out[0] = b"'"
+ else:
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + self._qplen, ptr, length)
+ ptr_out[length + self._qplen] = b"'"
+ return rv
+
+ # We don't have a connection, so someone is using us to generate a file
+ # to use off-line or something like that. Like its string counterpart,
+ # PQescapeBytea is unpredictable as to whether it will escape
+ # backslashes.
+ PyByteArray_Resize(rv, length + 4) # Include quotes
+ ptr_out = PyByteArray_AS_STRING(rv)
+ ptr_out[0] = b" "
+ ptr_out[1] = b"E"
+ ptr_out[2] = b"'"
+ memcpy(ptr_out + 3, ptr, length)
+ ptr_out[length + 3] = b"'"
+
+ esc = Escaping()
+ if esc.escape_bytea(b"\x00") == b"\\000":
+ rv = bytes(rv).replace(b"\\", b"\\\\")
return rv
import pytest
-from psycopg import sql, ProgrammingError
+from psycopg import pq, sql, ProgrammingError
from psycopg.adapt import PyFormat as Format
@pytest.mark.parametrize(
"obj, quoted",
- [("hello", "'hello'"), (42, "42"), (True, "true"), (None, "NULL")],
+ [
+ ("foo\\bar", " E'foo\\\\bar'"),
+ ("hello", "'hello'"),
+ (42, "42"),
+ (True, "true"),
+ (None, "NULL"),
+ ],
)
def test_quote(obj, quoted):
assert sql.quote(obj) == quoted
+@pytest.mark.parametrize("scs", ["on", "off"])
+def test_quote_roundtrip(conn, scs):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+
+ for i in range(1, 256):
+ want = chr(i)
+ quoted = sql.quote(want)
+ got = conn.execute(f"select {quoted}::text").fetchone()[0]
+ assert want == got
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages, f"error with {want!r}"
+
+
+def test_quote_stable_despite_deranged_libpq(conn):
+ # Verify the libpq behaviour of PQescapeString using the last setting seen.
+ # Check that we are not affected by it.
+ good_str = " E'\\\\'"
+ good_bytes = " E'\\\\000'"
+ conn.execute("set standard_conforming_strings to on")
+ assert pq.Escaping().escape_string(b"\\") == b"\\"
+ assert sql.quote("\\") == good_str
+ assert pq.Escaping().escape_bytea(b"\x00") == b"\\000"
+ assert sql.quote(b"\x00") == good_bytes
+
+ conn.execute("set standard_conforming_strings to off")
+ assert pq.Escaping().escape_string(b"\\") == b"\\\\"
+ assert sql.quote("\\") == good_str
+ assert pq.Escaping().escape_bytea(b"\x00") == b"\\\\000"
+ assert sql.quote(b"\x00") == good_bytes
+
+ # Verify that the good values are actually good
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute("set escape_string_warning to on")
+ for scs in ("on", "off"):
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ cur = conn.execute(f"select {good_str}, {good_bytes}::bytea")
+ assert cur.fetchone() == ("\\", b"\x00")
+
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
+
class TestSqlFormat:
def test_pos(self, conn):
s = sql.SQL("select {} from {}").format(
def test_init(self):
assert isinstance(sql.Identifier("foo"), sql.Identifier)
- assert isinstance(sql.Identifier(u"foo"), sql.Identifier)
+ assert isinstance(sql.Identifier("foo"), sql.Identifier)
assert isinstance(sql.Identifier("foo", "bar", "baz"), sql.Identifier)
with pytest.raises(TypeError):
sql.Identifier()
def test_init(self):
assert isinstance(sql.Literal("foo"), sql.Literal)
- assert isinstance(sql.Literal(u"foo"), sql.Literal)
+ assert isinstance(sql.Literal("foo"), sql.Literal)
assert isinstance(sql.Literal(b"foo"), sql.Literal)
assert isinstance(sql.Literal(42), sql.Literal)
assert isinstance(sql.Literal(dt.date(2016, 12, 31)), sql.Literal)
def test_init(self):
assert isinstance(sql.SQL("foo"), sql.SQL)
- assert isinstance(sql.SQL(u"foo"), sql.SQL)
+ assert isinstance(sql.SQL("foo"), sql.SQL)
with pytest.raises(TypeError):
sql.SQL(10)
with pytest.raises(TypeError):
def test_dump_1char(conn, fmt_in):
cur = conn.cursor()
for i in range(1, 256):
- cur.execute(f"select %{fmt_in} = chr(%s::int)", (chr(i), i))
+ cur.execute(f"select %{fmt_in} = chr(%s)", (chr(i), i))
assert cur.fetchone()[0] is True, chr(i)
-def test_quote_1char(conn):
+@pytest.mark.parametrize("scs", ["on", "off"])
+def test_quote_1char(conn, scs):
+ messages = []
+ conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
+ conn.execute(f"set standard_conforming_strings to {scs}")
+ conn.execute("set escape_string_warning to on")
+
cur = conn.cursor()
- query = sql.SQL("select {ch} = chr(%s::int)")
+ query = sql.SQL("select {ch} = chr(%s)")
for i in range(1, 256):
if chr(i) == "%":
continue
cur.execute(query.format(ch=sql.Literal(chr(i))), (i,))
assert cur.fetchone()[0] is True, chr(i)
+ # No "nonstandard use of \\ in a string literal" warning
+ assert not messages
+
@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
def test_dump_zero(conn, fmt_in):
assert cur.fetchone()[0] == "%"
cur.execute(
- sql.SQL("select {ch} = chr(%s::int)").format(ch=sql.Literal("%")),
+ sql.SQL("select {ch} = chr(%s)").format(ch=sql.Literal("%")),
(ord("%"),),
)
assert cur.fetchone()[0] is True
def test_load_1char(conn, typename, fmt_out):
cur = conn.cursor(binary=fmt_out)
for i in range(1, 256):
- cur.execute(f"select chr(%s::int)::{typename}", (i,))
+ cur.execute(f"select chr(%s)::{typename}", (i,))
res = cur.fetchone()[0]
assert res == chr(i)
cur = conn.cursor(binary=fmt_out)
conn.client_encoding = encoding
- (res,) = cur.execute(
- f"select chr(%s::int)::{typename}", (ord(eur),)
- ).fetchone()
+ (res,) = cur.execute(f"select chr(%s)::{typename}", (ord(eur),)).fetchone()
assert res == eur
- stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
ord(eur), sql.SQL(fmt_out.name)
)
with cur.copy(stmt) as copy:
conn.client_encoding = "latin1"
with pytest.raises(psycopg.DataError):
- cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+ cur.execute(f"select chr(%s)::{typename}", (ord(eur),))
- stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
ord(eur), sql.SQL(fmt_out.name)
)
with cur.copy(stmt) as copy:
cur = conn.cursor(binary=fmt_out)
conn.client_encoding = "ascii"
- cur.execute(f"select chr(%s::int)::{typename}", (ord(eur),))
+ cur.execute(f"select chr(%s)::{typename}", (ord(eur),))
assert cur.fetchone()[0] == eur.encode("utf8")
- stmt = sql.SQL("copy (select chr({}::int)) to stdout (format {})").format(
+ stmt = sql.SQL("copy (select chr({})) to stdout (format {})").format(
ord(eur), sql.SQL(fmt_out.name)
)
with cur.copy(stmt) as copy:
messages = []
conn.add_notice_handler(lambda msg: messages.append(msg.message_primary))
conn.execute(f"set standard_conforming_strings to {scs}")
+ conn.execute("set escape_string_warning to on")
cur = conn.cursor()
query = sql.SQL("select {ch} = set_byte('x', 0, %s)")