PGresult = pq_module.PGresult
PQerror = pq_module.PQerror
Conninfo = pq_module.Conninfo
+Escaping = pq_module.Escaping
__all__ = (
"ConnStatus",
c_size_t,
POINTER(c_size_t),
]
-PQescapeByteaConn.restype = POINTER(c_ubyte) # same, POINTER(c_ubyte)
+PQescapeByteaConn.restype = POINTER(c_ubyte)
-# TODO: PQescapeBytea PQunescapeBytea
+# PQescapeBytea: deprecated
+
+PQunescapeBytea = pq.PQunescapeBytea
+PQunescapeBytea.argtypes = [
+ POINTER(c_char), # actually POINTER(c_ubyte) but this is easier
+ POINTER(c_size_t),
+]
+PQunescapeBytea.restype = POINTER(c_ubyte)
# 33.4. Asynchronous Command Processing
return "Any"
elif t is c_int or t is c_uint or t is c_size_t:
return "int"
- elif t is c_char_p:
+ elif t is c_char_p or t.__name__ == "LP_c_char":
return "bytes"
elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",):
return f"Sequence[{t.__name__[3:]}]"
elif t.__name__ in (
- "LP_c_char",
+ "LP_c_ubyte",
"LP_c_char_p",
"LP_c_int",
"LP_c_uint",
+ "LP_c_ulong",
):
return f"pointer[{t.__name__[3:]}]"
arg4: int,
arg5: Optional[Array[c_uint]],
) -> PGresult_struct: ...
-def PQescapeByteaConn(
- arg1: Optional[PGconn_struct],
- arg2: bytes,
- arg3: int,
- arg4: pointer[c_ulong],
-) -> pointer[c_ubyte]: ...
+def PQgetvalue(
+ arg1: Optional[PGresult_struct], arg2: int, arg3: int
+) -> pointer[c_char]: ...
# fmt: off
# autogenerated: start
def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
-def PQgetvalue(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> pointer[c_char]: ...
def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
def PQcmdStatus(arg1: Optional[PGresult_struct]) -> bytes: ...
def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
+def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ...
def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_uint], arg5: pointer[c_char_p], arg6: pointer[c_int], arg7: pointer[c_int], arg8: int) -> int: ...
def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
raise MemoryError("couldn't allocate PGresult")
return PGresult(rv)
- def escape_bytea(self, data: bytes) -> bytes:
- len_out = c_size_t()
- out = impl.PQescapeByteaConn(
- self.pgconn_ptr,
- data,
- len(data),
- pointer(t_cast(c_ulong, len_out)),
- )
- if not out:
- raise MemoryError(
- f"couldn't allocate {len(data)} bytes for escape_bytea"
- )
-
- rv = string_at(out, len_out.value - 1) # out includes final 0
- impl.PQfreemem(out)
- return rv
-
def get_result(self) -> Optional["PGresult"]:
rv = impl.PQgetResult(self.pgconn_ptr)
return PGresult(rv) if rv else None
rv.append(ConninfoOption(**d))
return rv
+
+
+class Escaping:
+ def __init__(self, conn: PGconn):
+ self.conn = conn
+
+ def escape_bytea(self, data: bytes) -> bytes:
+ len_out = c_size_t()
+ out = impl.PQescapeByteaConn(
+ self.conn.pgconn_ptr,
+ data,
+ len(data),
+ pointer(t_cast(c_ulong, len_out)),
+ )
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for escape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value - 1) # out includes final 0
+ impl.PQfreemem(out)
+ return rv
+
+ @classmethod
+ def unescape_bytea(cls, data: bytes) -> bytes:
+ len_out = c_size_t()
+ out = impl.PQunescapeBytea(data, pointer(t_cast(c_ulong, len_out)))
+ if not out:
+ raise MemoryError(
+ f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+ )
+
+ rv = string_at(out, len_out.value)
+ impl.PQfreemem(out)
+ return rv
)
from ..connection import BaseConnection
from ..utils.typing import EncodeFunc, DecodeFunc, Oid
+from ..pq import Escaping
from .oids import type_oid
return self._encode(obj)[0]
-@Adapter.text(bytes)
-class BytesAdapter(Adapter):
- def adapt(self, obj: bytes) -> Tuple[bytes, Oid]:
- return self.conn.pgconn.escape_bytea(obj), type_oid["bytea"]
-
-
-@Adapter.binary(bytes)
-def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]:
- return b, type_oid["bytea"]
-
-
@Typecaster.text(type_oid["text"])
@Typecaster.binary(type_oid["text"])
class StringCaster(Typecaster):
else:
# return bytes for SQL_ASCII db
return data
+
+
+@Adapter.text(bytes)
+class BytesAdapter(Adapter):
+ def __init__(self, cls: type, conn: BaseConnection):
+ super().__init__(cls, conn)
+ self.esc = Escaping(self.conn.pgconn)
+
+ def adapt(self, obj: bytes) -> Tuple[bytes, Oid]:
+ return self.esc.escape_bytea(obj), type_oid["bytea"]
+
+
+@Adapter.binary(bytes)
+def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]:
+ return b, type_oid["bytea"]
+
+
+@Typecaster.text(type_oid["bytea"])
+def cast_bytea(data: bytes) -> bytes:
+ return Escaping.unescape_bytea(data)
+
+
+@Typecaster.binary(type_oid["bytea"])
+def cast_bytea_binary(data: bytes) -> bytes:
+ return data
--- /dev/null
+import pytest
+
+
+@pytest.mark.parametrize(
+ "data", [(b"hello\00world"), (b"\00\00\00\00")],
+)
+def test_escape_bytea(pq, pgconn, data):
+ rv = pq.Escaping(pgconn).escape_bytea(data)
+ exp = br"\x" + b"".join(b"%02x" % c for c in data)
+ assert rv == exp
+
+
+def test_escape_1char(pq, pgconn):
+ esc = pq.Escaping(pgconn)
+ for c in range(256):
+ rv = esc.escape_bytea(bytes([c]))
+ exp = br"\x%02x" % c
+ assert rv == exp
+
+
+@pytest.mark.parametrize(
+ "data", [(b"hello\00world"), (b"\00\00\00\00")],
+)
+def test_unescape_bytea(pq, pgconn, data):
+ enc = br"\x" + b"".join(b"%02x" % c for c in data)
+ rv = pq.Escaping.unescape_bytea(enc)
+ assert rv == data
res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
assert res.status == pq.ExecStatus.FATAL_ERROR
assert b"wat" in res.error_message
-
-
-@pytest.mark.parametrize(
- "data", [(b"hello\00world"), (b"\00\00\00\00")],
-)
-def test_escape_bytea(pgconn, data):
- rv = pgconn.escape_bytea(data)
- exp = br"\x" + b"".join(b"%02x" % c for c in data)
- assert rv == exp
-
-
-def test_escape_1char(pgconn):
- for c in range(256):
- rv = pgconn.escape_bytea(bytes([c]))
- exp = br"\x%02x" % c
- assert rv == exp
--- /dev/null
+import pytest
+
+from psycopg3.adapt import Format
+
+
+#
+# tests with text
+#
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_adapt_1char(conn, format):
+ cur = conn.cursor()
+ query = "select %s = chr(%%s::int)" % (
+ "%s" if format == Format.TEXT else "%b"
+ )
+ for i in range(1, 256):
+ cur.execute(query, (chr(i), i))
+ assert cur.fetchone()[0], chr(i)
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_cast_1char(conn, format):
+ cur = conn.cursor(binary=format == Format.BINARY)
+ for i in range(1, 256):
+ cur.execute("select chr(%s::int)", (i,))
+ assert cur.fetchone()[0] == chr(i)
+
+ assert cur.pgresult.fformat(0) == format
+
+
+#
+# tests with bytea
+#
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_adapt_1byte(conn, format):
+ cur = conn.cursor()
+ query = "select %s = %%s::bytea" % (
+ "%s" if format == Format.TEXT else "%b"
+ )
+ for i in range(0, 256):
+ cur.execute(query, (bytes([i]), fr"\x{i:02x}"))
+ assert cur.fetchone()[0], i
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_cast_1byte(conn, format):
+ cur = conn.cursor(binary=format == Format.BINARY)
+ for i in range(0, 256):
+ cur.execute("select %s::bytea", (fr"\x{i:02x}",))
+ assert cur.fetchone()[0] == bytes([i])
+
+ assert cur.pgresult.fformat(0) == format