From: Daniele Varrazzo Date: Thu, 2 Apr 2020 13:59:35 +0000 (+1300) Subject: Added pq.Escaping object and bytea adaptation X-Git-Tag: 3.0.dev0~625 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=66e769985a59baa62f8bf7502682c050c76217b4;p=thirdparty%2Fpsycopg.git Added pq.Escaping object and bytea adaptation --- diff --git a/psycopg3/pq/__init__.py b/psycopg3/pq/__init__.py index 70a0fd0d3..70223a342 100644 --- a/psycopg3/pq/__init__.py +++ b/psycopg3/pq/__init__.py @@ -28,6 +28,7 @@ PGconn = pq_module.PGconn PGresult = pq_module.PGresult PQerror = pq_module.PQerror Conninfo = pq_module.Conninfo +Escaping = pq_module.Escaping __all__ = ( "ConnStatus", diff --git a/psycopg3/pq/_pq_ctypes.py b/psycopg3/pq/_pq_ctypes.py index 59d498bea..135075aac 100644 --- a/psycopg3/pq/_pq_ctypes.py +++ b/psycopg3/pq/_pq_ctypes.py @@ -369,9 +369,16 @@ PQescapeByteaConn.argtypes = [ c_size_t, POINTER(c_size_t), ] -PQescapeByteaConn.restype = POINTER(c_ubyte) # same, POINTER(c_ubyte) +PQescapeByteaConn.restype = POINTER(c_ubyte) -# TODO: PQescapeBytea PQunescapeBytea +# PQescapeBytea: deprecated + +PQunescapeBytea = pq.PQunescapeBytea +PQunescapeBytea.argtypes = [ + POINTER(c_char), # actually POINTER(c_ubyte) but this is easier + POINTER(c_size_t), +] +PQunescapeBytea.restype = POINTER(c_ubyte) # 33.4. Asynchronous Command Processing @@ -443,7 +450,7 @@ def generate_stub() -> None: return "Any" elif t is c_int or t is c_uint or t is c_size_t: return "int" - elif t is c_char_p: + elif t is c_char_p or t.__name__ == "LP_c_char": return "bytes" elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",): @@ -456,10 +463,11 @@ def generate_stub() -> None: return f"Sequence[{t.__name__[3:]}]" elif t.__name__ in ( - "LP_c_char", + "LP_c_ubyte", "LP_c_char_p", "LP_c_int", "LP_c_uint", + "LP_c_ulong", ): return f"pointer[{t.__name__[3:]}]" diff --git a/psycopg3/pq/_pq_ctypes.pyi b/psycopg3/pq/_pq_ctypes.pyi index bc45cbb27..9408d7f7e 100644 --- a/psycopg3/pq/_pq_ctypes.pyi +++ b/psycopg3/pq/_pq_ctypes.pyi @@ -39,12 +39,9 @@ def PQprepare( arg4: int, arg5: Optional[Array[c_uint]], ) -> PGresult_struct: ... -def PQescapeByteaConn( - arg1: Optional[PGconn_struct], - arg2: bytes, - arg3: int, - arg4: pointer[c_ulong], -) -> pointer[c_ubyte]: ... +def PQgetvalue( + arg1: Optional[PGresult_struct], arg2: int, arg3: int +) -> pointer[c_char]: ... # fmt: off # autogenerated: start @@ -98,7 +95,6 @@ def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ... -def PQgetvalue(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> pointer[c_char]: ... def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ... def PQnparams(arg1: Optional[PGresult_struct]) -> int: ... @@ -106,6 +102,8 @@ def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ... def PQcmdStatus(arg1: Optional[PGresult_struct]) -> bytes: ... def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ... def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ... +def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ... +def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ... def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ... def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_uint], arg5: pointer[c_char_p], arg6: pointer[c_int], arg7: pointer[c_int], arg8: int) -> int: ... def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ... diff --git a/psycopg3/pq/pq_ctypes.py b/psycopg3/pq/pq_ctypes.py index e9c1d3311..66030151b 100644 --- a/psycopg3/pq/pq_ctypes.py +++ b/psycopg3/pq/pq_ctypes.py @@ -362,23 +362,6 @@ class PGconn: raise MemoryError("couldn't allocate PGresult") return PGresult(rv) - def escape_bytea(self, data: bytes) -> bytes: - len_out = c_size_t() - out = impl.PQescapeByteaConn( - self.pgconn_ptr, - data, - len(data), - pointer(t_cast(c_ulong, len_out)), - ) - if not out: - raise MemoryError( - f"couldn't allocate {len(data)} bytes for escape_bytea" - ) - - rv = string_at(out, len_out.value - 1) # out includes final 0 - impl.PQfreemem(out) - return rv - def get_result(self) -> Optional["PGresult"]: rv = impl.PQgetResult(self.pgconn_ptr) return PGresult(rv) if rv else None @@ -552,3 +535,38 @@ class Conninfo: rv.append(ConninfoOption(**d)) return rv + + +class Escaping: + def __init__(self, conn: PGconn): + self.conn = conn + + def escape_bytea(self, data: bytes) -> bytes: + len_out = c_size_t() + out = impl.PQescapeByteaConn( + self.conn.pgconn_ptr, + data, + len(data), + pointer(t_cast(c_ulong, len_out)), + ) + if not out: + raise MemoryError( + f"couldn't allocate for escape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value - 1) # out includes final 0 + impl.PQfreemem(out) + return rv + + @classmethod + def unescape_bytea(cls, data: bytes) -> bytes: + len_out = c_size_t() + out = impl.PQunescapeBytea(data, pointer(t_cast(c_ulong, len_out))) + if not out: + raise MemoryError( + f"couldn't allocate for unescape_bytea of {len(data)} bytes" + ) + + rv = string_at(out, len_out.value) + impl.PQfreemem(out) + return rv diff --git a/psycopg3/types/text.py b/psycopg3/types/text.py index 36d426817..421142ee6 100644 --- a/psycopg3/types/text.py +++ b/psycopg3/types/text.py @@ -13,6 +13,7 @@ from ..adapt import ( ) from ..connection import BaseConnection from ..utils.typing import EncodeFunc, DecodeFunc, Oid +from ..pq import Escaping from .oids import type_oid @@ -32,17 +33,6 @@ class StringAdapter(Adapter): return self._encode(obj)[0] -@Adapter.text(bytes) -class BytesAdapter(Adapter): - def adapt(self, obj: bytes) -> Tuple[bytes, Oid]: - return self.conn.pgconn.escape_bytea(obj), type_oid["bytea"] - - -@Adapter.binary(bytes) -def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]: - return b, type_oid["bytea"] - - @Typecaster.text(type_oid["text"]) @Typecaster.binary(type_oid["text"]) class StringCaster(Typecaster): @@ -66,3 +56,28 @@ class StringCaster(Typecaster): else: # return bytes for SQL_ASCII db return data + + +@Adapter.text(bytes) +class BytesAdapter(Adapter): + def __init__(self, cls: type, conn: BaseConnection): + super().__init__(cls, conn) + self.esc = Escaping(self.conn.pgconn) + + def adapt(self, obj: bytes) -> Tuple[bytes, Oid]: + return self.esc.escape_bytea(obj), type_oid["bytea"] + + +@Adapter.binary(bytes) +def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]: + return b, type_oid["bytea"] + + +@Typecaster.text(type_oid["bytea"]) +def cast_bytea(data: bytes) -> bytes: + return Escaping.unescape_bytea(data) + + +@Typecaster.binary(type_oid["bytea"]) +def cast_bytea_binary(data: bytes) -> bytes: + return data diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py new file mode 100644 index 000000000..67495aa2b --- /dev/null +++ b/tests/pq/test_escaping.py @@ -0,0 +1,27 @@ +import pytest + + +@pytest.mark.parametrize( + "data", [(b"hello\00world"), (b"\00\00\00\00")], +) +def test_escape_bytea(pq, pgconn, data): + rv = pq.Escaping(pgconn).escape_bytea(data) + exp = br"\x" + b"".join(b"%02x" % c for c in data) + assert rv == exp + + +def test_escape_1char(pq, pgconn): + esc = pq.Escaping(pgconn) + for c in range(256): + rv = esc.escape_bytea(bytes([c])) + exp = br"\x%02x" % c + assert rv == exp + + +@pytest.mark.parametrize( + "data", [(b"hello\00world"), (b"\00\00\00\00")], +) +def test_unescape_bytea(pq, pgconn, data): + enc = br"\x" + b"".join(b"%02x" % c for c in data) + rv = pq.Escaping.unescape_bytea(enc) + assert rv == data diff --git a/tests/pq/test_pgconn.py b/tests/pq/test_pgconn.py index 4d056c2b6..131347935 100644 --- a/tests/pq/test_pgconn.py +++ b/tests/pq/test_pgconn.py @@ -226,19 +226,3 @@ def test_make_empty_result(pq, pgconn): res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR) assert res.status == pq.ExecStatus.FATAL_ERROR assert b"wat" in res.error_message - - -@pytest.mark.parametrize( - "data", [(b"hello\00world"), (b"\00\00\00\00")], -) -def test_escape_bytea(pgconn, data): - rv = pgconn.escape_bytea(data) - exp = br"\x" + b"".join(b"%02x" % c for c in data) - assert rv == exp - - -def test_escape_1char(pgconn): - for c in range(256): - rv = pgconn.escape_bytea(bytes([c])) - exp = br"\x%02x" % c - assert rv == exp diff --git a/tests/types/test_text.py b/tests/types/test_text.py new file mode 100644 index 000000000..ada260d8f --- /dev/null +++ b/tests/types/test_text.py @@ -0,0 +1,55 @@ +import pytest + +from psycopg3.adapt import Format + + +# +# tests with text +# + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_adapt_1char(conn, format): + cur = conn.cursor() + query = "select %s = chr(%%s::int)" % ( + "%s" if format == Format.TEXT else "%b" + ) + for i in range(1, 256): + cur.execute(query, (chr(i), i)) + assert cur.fetchone()[0], chr(i) + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_cast_1char(conn, format): + cur = conn.cursor(binary=format == Format.BINARY) + for i in range(1, 256): + cur.execute("select chr(%s::int)", (i,)) + assert cur.fetchone()[0] == chr(i) + + assert cur.pgresult.fformat(0) == format + + +# +# tests with bytea +# + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_adapt_1byte(conn, format): + cur = conn.cursor() + query = "select %s = %%s::bytea" % ( + "%s" if format == Format.TEXT else "%b" + ) + for i in range(0, 256): + cur.execute(query, (bytes([i]), fr"\x{i:02x}")) + assert cur.fetchone()[0], i + + +@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY]) +def test_cast_1byte(conn, format): + cur = conn.cursor(binary=format == Format.BINARY) + for i in range(0, 256): + cur.execute("select %s::bytea", (fr"\x{i:02x}",)) + assert cur.fetchone()[0] == bytes([i]) + + assert cur.pgresult.fformat(0) == format