# 33.3.4. Escaping Strings for Inclusion in SQL Commands
-# TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+# TODO: PQescapeIdentifier PQescapeStringConn
+PQescapeLiteral = pq.PQescapeLiteral
+PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeLiteral.restype = POINTER(c_char)
+
+# won't wrap: PQescapeString
PQescapeByteaConn = pq.PQescapeByteaConn
PQescapeByteaConn.argtypes = [
def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ...
def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ...
def __init__(self, conn: Optional[PGconn] = None):
self.conn = conn
+ def escape_literal(self, data: bytes) -> bytes:
+ if self.conn is not None:
+ self.conn._ensure_pgconn()
+ out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
+ if not out:
+ raise PQerror(
+ f"escape_literal failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ else:
+ raise PQerror("escape_literal failed: no connection provided")
+
def escape_bytea(self, data: bytes) -> bytes:
len_out = c_size_t()
if self.conn is not None:
def __init__(self, conn: Optional[PGconn] = None):
...
+ def escape_literal(self, data: bytes) -> bytes:
+ ...
+
def escape_bytea(self, data: bytes) -> bytes:
...
# 33.3.4. Escaping Strings for Inclusion in SQL Commands
# TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+ char *PQescapeLiteral(PGconn *conn, const char *str, size_t length);
unsigned char *PQescapeByteaConn(PGconn *conn,
const unsigned char *src,
size_t from_length,
def __init__(self, conn: Optional[PGconn] = None):
self.conn = conn
+ def escape_literal(self, data: bytes) -> bytes:
+ cdef char *out
+ cdef bytes rv
+
+ if self.conn is not None:
+ if self.conn.pgconn_ptr is NULL:
+ raise PQerror("the connection is closed")
+ out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
+ if out is NULL:
+ raise PQerror(
+ f"escape_literal failed: {error_message(self.conn)} bytes"
+ )
+ rv = out
+ impl.PQfreemem(out)
+ return rv
+
+ else:
+ raise PQerror("escape_literal failed: no connection provided")
+
def escape_bytea(self, data: bytes) -> bytes:
cdef size_t len_out
cdef unsigned char *out
from psycopg3 import pq
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b"''"),
+ (b"hello", b"'hello'"),
+ (b"foo'bar", b"'foo''bar'"),
+ (b"foo\\bar", b" E'foo\\\\bar'"),
+ ],
+)
+def test_escape_literal(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_literal(data)
+ assert out == want
+
+
+def test_escape_literal_1char(pgconn):
+ esc = pq.Escaping(pgconn)
+ special = {b"'": b"''''", b"\\": b" E'\\\\'"}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_literal(data)
+ exp = special.get(data) or b"'%s'" % data
+ assert rv == exp
+
+
+def test_escape_literal_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg3.OperationalError):
+ esc.escape_literal(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg3.OperationalError):
+ esc.escape_literal(b"hi")
+
+
@pytest.mark.parametrize(
"data", [(b"hello\00world"), (b"\00\00\00\00")],
)