]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added pq.Escaping object and bytea adaptation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 13:59:35 +0000 (02:59 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 13:59:35 +0000 (02:59 +1300)
psycopg3/pq/__init__.py
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/pq_ctypes.py
psycopg3/types/text.py
tests/pq/test_escaping.py [new file with mode: 0644]
tests/pq/test_pgconn.py
tests/types/test_text.py [new file with mode: 0644]

index 70a0fd0d3d39ae77b769a10fd0394a74a0e9f51a..70223a342524dee70220e3e4a6efca1c40b2df24 100644 (file)
@@ -28,6 +28,7 @@ PGconn = pq_module.PGconn
 PGresult = pq_module.PGresult
 PQerror = pq_module.PQerror
 Conninfo = pq_module.Conninfo
+Escaping = pq_module.Escaping
 
 __all__ = (
     "ConnStatus",
index 59d498bea224e70bd1a229fca2970b95331f9498..135075aac3efa31df064a42718446dc47b5efcf5 100644 (file)
@@ -369,9 +369,16 @@ PQescapeByteaConn.argtypes = [
     c_size_t,
     POINTER(c_size_t),
 ]
-PQescapeByteaConn.restype = POINTER(c_ubyte)  # same, POINTER(c_ubyte)
+PQescapeByteaConn.restype = POINTER(c_ubyte)
 
-# TODO: PQescapeBytea PQunescapeBytea
+# PQescapeBytea: deprecated
+
+PQunescapeBytea = pq.PQunescapeBytea
+PQunescapeBytea.argtypes = [
+    POINTER(c_char),  # actually POINTER(c_ubyte) but this is easier
+    POINTER(c_size_t),
+]
+PQunescapeBytea.restype = POINTER(c_ubyte)
 
 
 # 33.4. Asynchronous Command Processing
@@ -443,7 +450,7 @@ def generate_stub() -> None:
             return "Any"
         elif t is c_int or t is c_uint or t is c_size_t:
             return "int"
-        elif t is c_char_p:
+        elif t is c_char_p or t.__name__ == "LP_c_char":
             return "bytes"
 
         elif t.__name__ in ("LP_PGconn_struct", "LP_PGresult_struct",):
@@ -456,10 +463,11 @@ def generate_stub() -> None:
             return f"Sequence[{t.__name__[3:]}]"
 
         elif t.__name__ in (
-            "LP_c_char",
+            "LP_c_ubyte",
             "LP_c_char_p",
             "LP_c_int",
             "LP_c_uint",
+            "LP_c_ulong",
         ):
             return f"pointer[{t.__name__[3:]}]"
 
index bc45cbb27369c97abef4bb8e3cb4d15f524e7409..9408d7f7edd367fe3ba56b4488614c853002bd8f 100644 (file)
@@ -39,12 +39,9 @@ def PQprepare(
     arg4: int,
     arg5: Optional[Array[c_uint]],
 ) -> PGresult_struct: ...
-def PQescapeByteaConn(
-    arg1: Optional[PGconn_struct],
-    arg2: bytes,
-    arg3: int,
-    arg4: pointer[c_ulong],
-) -> pointer[c_ubyte]: ...
+def PQgetvalue(
+    arg1: Optional[PGresult_struct], arg2: int, arg3: int
+) -> pointer[c_char]: ...
 
 # fmt: off
 # autogenerated: start
@@ -98,7 +95,6 @@ def PQftype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQfmod(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQfsize(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQbinaryTuples(arg1: Optional[PGresult_struct]) -> int: ...
-def PQgetvalue(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> pointer[c_char]: ...
 def PQgetisnull(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
 def PQgetlength(arg1: Optional[PGresult_struct], arg2: int, arg3: int) -> int: ...
 def PQnparams(arg1: Optional[PGresult_struct]) -> int: ...
@@ -106,6 +102,8 @@ def PQparamtype(arg1: Optional[PGresult_struct], arg2: int) -> int: ...
 def PQcmdStatus(arg1: Optional[PGresult_struct]) -> bytes: ...
 def PQcmdTuples(arg1: Optional[PGresult_struct]) -> bytes: ...
 def PQoidValue(arg1: Optional[PGresult_struct]) -> int: ...
+def PQescapeByteaConn(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_ulong]) -> pointer[c_ubyte]: ...
+def PQunescapeBytea(arg1: bytes, arg2: pointer[c_ulong]) -> pointer[c_ubyte]: ...
 def PQsendQuery(arg1: Optional[PGconn_struct], arg2: bytes) -> int: ...
 def PQsendQueryParams(arg1: Optional[PGconn_struct], arg2: bytes, arg3: int, arg4: pointer[c_uint], arg5: pointer[c_char_p], arg6: pointer[c_int], arg7: pointer[c_int], arg8: int) -> int: ...
 def PQgetResult(arg1: Optional[PGconn_struct]) -> PGresult_struct: ...
index e9c1d3311215b56773f59ae2e3ea1b097d6a0273..66030151b47fbf81af621d728b2aaaf838f51511 100644 (file)
@@ -362,23 +362,6 @@ class PGconn:
             raise MemoryError("couldn't allocate PGresult")
         return PGresult(rv)
 
-    def escape_bytea(self, data: bytes) -> bytes:
-        len_out = c_size_t()
-        out = impl.PQescapeByteaConn(
-            self.pgconn_ptr,
-            data,
-            len(data),
-            pointer(t_cast(c_ulong, len_out)),
-        )
-        if not out:
-            raise MemoryError(
-                f"couldn't allocate {len(data)} bytes for escape_bytea"
-            )
-
-        rv = string_at(out, len_out.value - 1)  # out includes final 0
-        impl.PQfreemem(out)
-        return rv
-
     def get_result(self) -> Optional["PGresult"]:
         rv = impl.PQgetResult(self.pgconn_ptr)
         return PGresult(rv) if rv else None
@@ -552,3 +535,38 @@ class Conninfo:
             rv.append(ConninfoOption(**d))
 
         return rv
+
+
+class Escaping:
+    def __init__(self, conn: PGconn):
+        self.conn = conn
+
+    def escape_bytea(self, data: bytes) -> bytes:
+        len_out = c_size_t()
+        out = impl.PQescapeByteaConn(
+            self.conn.pgconn_ptr,
+            data,
+            len(data),
+            pointer(t_cast(c_ulong, len_out)),
+        )
+        if not out:
+            raise MemoryError(
+                f"couldn't allocate for escape_bytea of {len(data)} bytes"
+            )
+
+        rv = string_at(out, len_out.value - 1)  # out includes final 0
+        impl.PQfreemem(out)
+        return rv
+
+    @classmethod
+    def unescape_bytea(cls, data: bytes) -> bytes:
+        len_out = c_size_t()
+        out = impl.PQunescapeBytea(data, pointer(t_cast(c_ulong, len_out)))
+        if not out:
+            raise MemoryError(
+                f"couldn't allocate for unescape_bytea of {len(data)} bytes"
+            )
+
+        rv = string_at(out, len_out.value)
+        impl.PQfreemem(out)
+        return rv
index 36d426817a3b201a4ce96f4b20ec64ce851f8477..421142ee6b8f51b7e03cd5126800a9442f36868f 100644 (file)
@@ -13,6 +13,7 @@ from ..adapt import (
 )
 from ..connection import BaseConnection
 from ..utils.typing import EncodeFunc, DecodeFunc, Oid
+from ..pq import Escaping
 from .oids import type_oid
 
 
@@ -32,17 +33,6 @@ class StringAdapter(Adapter):
         return self._encode(obj)[0]
 
 
-@Adapter.text(bytes)
-class BytesAdapter(Adapter):
-    def adapt(self, obj: bytes) -> Tuple[bytes, Oid]:
-        return self.conn.pgconn.escape_bytea(obj), type_oid["bytea"]
-
-
-@Adapter.binary(bytes)
-def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]:
-    return b, type_oid["bytea"]
-
-
 @Typecaster.text(type_oid["text"])
 @Typecaster.binary(type_oid["text"])
 class StringCaster(Typecaster):
@@ -66,3 +56,28 @@ class StringCaster(Typecaster):
         else:
             # return bytes for SQL_ASCII db
             return data
+
+
+@Adapter.text(bytes)
+class BytesAdapter(Adapter):
+    def __init__(self, cls: type, conn: BaseConnection):
+        super().__init__(cls, conn)
+        self.esc = Escaping(self.conn.pgconn)
+
+    def adapt(self, obj: bytes) -> Tuple[bytes, Oid]:
+        return self.esc.escape_bytea(obj), type_oid["bytea"]
+
+
+@Adapter.binary(bytes)
+def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]:
+    return b, type_oid["bytea"]
+
+
+@Typecaster.text(type_oid["bytea"])
+def cast_bytea(data: bytes) -> bytes:
+    return Escaping.unescape_bytea(data)
+
+
+@Typecaster.binary(type_oid["bytea"])
+def cast_bytea_binary(data: bytes) -> bytes:
+    return data
diff --git a/tests/pq/test_escaping.py b/tests/pq/test_escaping.py
new file mode 100644 (file)
index 0000000..67495aa
--- /dev/null
@@ -0,0 +1,27 @@
+import pytest
+
+
+@pytest.mark.parametrize(
+    "data", [(b"hello\00world"), (b"\00\00\00\00")],
+)
+def test_escape_bytea(pq, pgconn, data):
+    rv = pq.Escaping(pgconn).escape_bytea(data)
+    exp = br"\x" + b"".join(b"%02x" % c for c in data)
+    assert rv == exp
+
+
+def test_escape_1char(pq, pgconn):
+    esc = pq.Escaping(pgconn)
+    for c in range(256):
+        rv = esc.escape_bytea(bytes([c]))
+        exp = br"\x%02x" % c
+        assert rv == exp
+
+
+@pytest.mark.parametrize(
+    "data", [(b"hello\00world"), (b"\00\00\00\00")],
+)
+def test_unescape_bytea(pq, pgconn, data):
+    enc = br"\x" + b"".join(b"%02x" % c for c in data)
+    rv = pq.Escaping.unescape_bytea(enc)
+    assert rv == data
index 4d056c2b664603944aa70e1f4b896ae9cc92a874..131347935e4c8e1338c096c6d69bdceb9d251105 100644 (file)
@@ -226,19 +226,3 @@ def test_make_empty_result(pq, pgconn):
     res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
     assert res.status == pq.ExecStatus.FATAL_ERROR
     assert b"wat" in res.error_message
-
-
-@pytest.mark.parametrize(
-    "data", [(b"hello\00world"), (b"\00\00\00\00")],
-)
-def test_escape_bytea(pgconn, data):
-    rv = pgconn.escape_bytea(data)
-    exp = br"\x" + b"".join(b"%02x" % c for c in data)
-    assert rv == exp
-
-
-def test_escape_1char(pgconn):
-    for c in range(256):
-        rv = pgconn.escape_bytea(bytes([c]))
-        exp = br"\x%02x" % c
-        assert rv == exp
diff --git a/tests/types/test_text.py b/tests/types/test_text.py
new file mode 100644 (file)
index 0000000..ada260d
--- /dev/null
@@ -0,0 +1,55 @@
+import pytest
+
+from psycopg3.adapt import Format
+
+
+#
+# tests with text
+#
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_adapt_1char(conn, format):
+    cur = conn.cursor()
+    query = "select %s = chr(%%s::int)" % (
+        "%s" if format == Format.TEXT else "%b"
+    )
+    for i in range(1, 256):
+        cur.execute(query, (chr(i), i))
+        assert cur.fetchone()[0], chr(i)
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_cast_1char(conn, format):
+    cur = conn.cursor(binary=format == Format.BINARY)
+    for i in range(1, 256):
+        cur.execute("select chr(%s::int)", (i,))
+        assert cur.fetchone()[0] == chr(i)
+
+    assert cur.pgresult.fformat(0) == format
+
+
+#
+# tests with bytea
+#
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_adapt_1byte(conn, format):
+    cur = conn.cursor()
+    query = "select %s = %%s::bytea" % (
+        "%s" if format == Format.TEXT else "%b"
+    )
+    for i in range(0, 256):
+        cur.execute(query, (bytes([i]), fr"\x{i:02x}"))
+        assert cur.fetchone()[0], i
+
+
+@pytest.mark.parametrize("format", [Format.TEXT, Format.BINARY])
+def test_cast_1byte(conn, format):
+    cur = conn.cursor(binary=format == Format.BINARY)
+    for i in range(0, 256):
+        cur.execute("select %s::bytea", (fr"\x{i:02x}",))
+        assert cur.fetchone()[0] == bytes([i])
+
+    assert cur.pgresult.fformat(0) == format