# 33.3.4. Escaping Strings for Inclusion in SQL Commands
-# TODO: PQescapeIdentifier PQescapeStringConn
+# TODO: PQescapeStringConn
PQescapeLiteral = pq.PQescapeLiteral
PQescapeLiteral.argtypes = [PGconn_ptr, c_char_p, c_size_t]
PQescapeLiteral.restype = POINTER(c_char)
+PQescapeIdentifier = pq.PQescapeIdentifier
+PQescapeIdentifier.argtypes = [PGconn_ptr, c_char_p, c_size_t]
+PQescapeIdentifier.restype = POINTER(c_char)
+
# won't wrap: PQescapeString
PQescapeByteaConn = pq.PQescapeByteaConn
def PQcmdStatus(arg1: Optional[PGresult_struct]) -> Optional[bytes]: ...
def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
def PQescapeLiteral(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
+def PQescapeIdentifier(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int) -> Optional[bytes]: ...
def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
def PQescapeBytea(arg1: bytes, arg2: int, arg3: pointer[c_ulong]) -> pointer[c_ubyte]: ...
def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ...
else:
raise PQerror("escape_literal failed: no connection provided")
+ def escape_identifier(self, data: bytes) -> bytes:
+ if self.conn is not None:
+ self.conn._ensure_pgconn()
+ out = impl.PQescapeIdentifier(
+ self.conn.pgconn_ptr, data, len(data)
+ )
+ if not out:
+ raise PQerror(
+ f"escape_identifier failed: {error_message(self.conn)} bytes"
+ )
+ rv = string_at(out)
+ impl.PQfreemem(out)
+ return rv
+
+ else:
+ raise PQerror("escape_identifier failed: no connection provided")
+
def escape_bytea(self, data: bytes) -> bytes:
len_out = c_size_t()
if self.conn is not None:
Oid PQoidValue(const PGresult *res)
# 33.3.4. Escaping Strings for Inclusion in SQL Commands
- # TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+ # TODO: PQescapeStringConn PQescapeString
+ char *PQescapeIdentifier(PGconn *conn, const char *str, size_t length);
char *PQescapeLiteral(PGconn *conn, const char *str, size_t length);
unsigned char *PQescapeByteaConn(PGconn *conn,
const unsigned char *src,
else:
raise PQerror("escape_literal failed: no connection provided")
+ def escape_identifier(self, data: bytes) -> bytes:
+ cdef char *out
+ cdef bytes rv
+
+ if self.conn is not None:
+ if self.conn.pgconn_ptr is NULL:
+ raise PQerror("the connection is closed")
+ out = impl.PQescapeIdentifier(self.conn.pgconn_ptr, data, len(data))
+ if out is NULL:
+ raise PQerror(
+ f"escape_identifier failed: {error_message(self.conn)} bytes"
+ )
+ rv = out
+ impl.PQfreemem(out)
+ return rv
+
+ else:
+ raise PQerror("escape_identifier failed: no connection provided")
+
def escape_bytea(self, data: bytes) -> bytes:
cdef size_t len_out
cdef unsigned char *out
esc.escape_literal(b"hi")
+@pytest.mark.parametrize(
+ "data, want",
+ [
+ (b"", b'""'),
+ (b"hello", b'"hello"'),
+ (b'foo"bar', b'"foo""bar"'),
+ (b"foo\\bar", b'"foo\\bar"'),
+ ],
+)
+def test_escape_identifier(pgconn, data, want):
+ esc = pq.Escaping(pgconn)
+ out = esc.escape_identifier(data)
+ assert out == want
+
+
+def test_escape_identifier_1char(pgconn):
+ esc = pq.Escaping(pgconn)
+ special = {b'"': b'""""', b"\\": b'"\\"'}
+ for c in range(1, 128):
+ data = bytes([c])
+ rv = esc.escape_identifier(data)
+ exp = special.get(data) or b'"%s"' % data
+ assert rv == exp
+
+
+def test_escape_identifier_noconn(pgconn):
+ esc = pq.Escaping()
+ with pytest.raises(psycopg3.OperationalError):
+ esc.escape_identifier(b"hi")
+
+ esc = pq.Escaping(pgconn)
+ pgconn.finish()
+ with pytest.raises(psycopg3.OperationalError):
+ esc.escape_identifier(b"hi")
+
+
@pytest.mark.parametrize(
"data", [(b"hello\00world"), (b"\00\00\00\00")],
)