# 33.3.4. Escaping Strings for Inclusion in SQL Commands
-# TODO: PQescapeStringConn
PQescapeLiteral = pq.PQescapeLiteral
PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
PQescapeLiteral.restype = POINTER(c_char)
PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
PQescapeIdentifier.restype = POINTER(c_char)
+PQescapeStringConn = pq.PQescapeStringConn
+# TODO: raises "wrong type" error
+# PQescapeStringConn.argtypes = [
+# PGconn_ptr, c_char_p, c_char_p, c_size_t, POINTER(c_int)
+# ]
+PQescapeStringConn.restype = c_size_t
+
# won't wrap: PQescapeString
PQescapeByteaConn = pq.PQescapeByteaConn
arg1: Optional[PGresult_struct], arg2: int, arg3: int
) -> pointer[c_char]: ...
def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
+def PQescapeStringConn(
+ arg1: Optional[PGconn_struct],
+ arg2: c_char_p,
+ arg3: bytes,
+ arg4: int,
+ arg5: pointer[c_int],
+) -> int: ...
def PQsendPrepare(
arg1: Optional[PGconn_struct],
arg2: bytes,
else:
raise PQerror("escape_identifier failed: no connection provided")
+ def escape_string(self, data: bytes) -> bytes:
+ if self.conn is not None:
+ self.conn._ensure_pgconn()
+ error = c_int()
+ out = create_string_buffer(len(data) * 2 + 1)
+ impl.PQescapeStringConn(
+ self.conn.pgconn_ptr,
+ pointer(out), # type: ignore
+ data,
+ len(data),
+ pointer(error),
+ )
+
+ if error:
+ raise PQerror(
+ f"escape_string failed: {error_message(self.conn)} bytes"
+ )
+ return out.value
+
+ else:
+ raise PQerror("escape_identifier failed: no connection provided")
+
def escape_bytea(self, data: bytes) -> bytes:
len_out = c_size_t()
if self.conn is not None:
# 33.3.4. Escaping Strings for Inclusion in SQL Commands
# TODO: PQescapeStringConn PQescapeString
- char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length);
- char *PQescapeLiteral(PGconn *conn, const char *str, size_t length);
+ char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length)
+ char *PQescapeLiteral(PGconn *conn, const char *str, size_t length)
+ size_t PQescapeStringConn(PGconn *conn,
+ char *to, const char *from_, size_t length,
+ int *error)
unsigned char *PQescapeByteaConn(PGconn *conn,
const unsigned char *src,
size_t from_length,
out = impl.PQescapeLiteral(self.conn.pgconn_ptr, data, len(data))
if out is NULL:
raise PQerror(
- f"escape_literal failed: {error_message(self.conn)} bytes"
+ f"escape_literal failed: {error_message(self.conn)}"
)
rv = out
impl.PQfreemem(out)
out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data))
if out is NULL:
raise PQerror(
- f"escape_identifier failed: {error_message(self.conn)} bytes"
+ f"escape_identifier failed: {error_message(self.conn)}"
)
rv = out
impl.PQfreemem(out)
else:
raise PQerror("escape_identifier failed: no connection provided")
+ def escape_string(self, data: bytes) -> bytes:
+ cdef int error
+ cdef size_t len_data = len(data)
+ cdef char *out
+ cdef size_t len_out
+ if self.conn is not None:
+ if self.conn.pgconn_ptr is NULL:
+ raise PQerror("the connection is closed")
+
+ out = <char *>PyMem_Malloc(len_data * 2 + 1)
+ len_out = impl.PQescapeStringConn(
+ self.conn.pgconn_ptr, out, data, len_data, &error
+ )
+
+ if error:
+ raise PQerror(
+ f"escape_string failed: {error_message(self.conn)}"
+ )
+ return out[:len_out]
+
+ else:
+ raise PQerror("escape_identifier failed: no connection provided")
+
def escape_bytea(self, data: bytes) -> bytes:
cdef size_t len_out
cdef unsigned char *out
esc.escape_identifier(b"hi")
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b""),
+ (b"hello", b"hello"),
+ (b"foo'bar", b"foo''bar"),
+ (b"foo\\bar", b"foo\\bar"),
+ ],
+)
+def test_escape_string(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_string(data)
+ assert out == want
+
+
+def test_escape_string_1char(pgconn):
+ esc = pq.Escaping(pgconn)
+ special = {b"'": b"''", b"\\": b"\\"}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_string(data)
+ exp = special.get(data) or b"%s" % data
+ assert rv == exp
+
+
+def test_escape_string_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg3.OperationalError):
+ esc.escape_string(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg3.OperationalError):
+ esc.escape_string(b"hi")
+
+
@pytest.mark.parametrize(
"data", [(b"hello\00world"), (b"\00\00\00\00")],
)