]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added bytes adaptation
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 09:19:47 +0000 (22:19 +1300)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 2 Apr 2020 09:19:47 +0000 (22:19 +1300)
psycopg3/pq/_pq_ctypes.py
psycopg3/pq/_pq_ctypes.pyi
psycopg3/pq/pq_ctypes.py
psycopg3/types/text.py
tests/pq/test_pgconn.py

index c2debb5d5ead1e4c2a53e0e10c18a08afee4f899..59d498bea224e70bd1a229fca2970b95331f9498 100644 (file)
@@ -7,7 +7,7 @@ libpq access using ctypes
 import ctypes
 import ctypes.util
 from ctypes import Structure, POINTER
-from ctypes import c_char, c_char_p, c_int, c_uint, c_void_p
+from ctypes import c_char, c_char_p, c_int, c_size_t, c_ubyte, c_uint, c_void_p
 from typing import List, Tuple
 
 from psycopg3.errors import NotSupportedError
@@ -358,6 +358,22 @@ PQoidValue.argtypes = [PGresult_ptr]
 PQoidValue.restype = Oid
 
 
+# 33.3.4. Escaping Strings for Inclusion in SQL Commands
+
+# TODO: PQescapeLiteral PQescapeIdentifier PQescapeStringConn PQescapeString
+
+PQescapeByteaConn = pq.PQescapeByteaConn
+PQescapeByteaConn.argtypes = [
+    PGconn_ptr,
+    POINTER(c_char),  # actually POINTER(c_ubyte) but this is easier
+    c_size_t,
+    POINTER(c_size_t),
+]
+PQescapeByteaConn.restype = POINTER(c_ubyte)  # same, POINTER(c_ubyte)
+
+# TODO: PQescapeBytea PQunescapeBytea
+
+
 # 33.4. Asynchronous Command Processing
 
 PQsendQuery = pq.PQsendQuery
@@ -425,7 +441,7 @@ def generate_stub() -> None:
             return "None"
         elif t is c_void_p:
             return "Any"
-        elif t is c_int or t is c_uint:
+        elif t is c_int or t is c_uint or t is c_size_t:
             return "int"
         elif t is c_char_p:
             return "bytes"
index 99b8acca907d938691c8ddf73e58fd75577da21d..bc45cbb27369c97abef4bb8e3cb4d15f524e7409 100644 (file)
@@ -5,7 +5,8 @@ types stub for ctypes functions
 # Copyright (C) 2020 The Psycopg Team
 
 from typing import Any, Optional, Sequence, NewType
-from ctypes import Array, c_char, c_char_p, c_int, c_uint, pointer
+from ctypes import Array, pointer
+from ctypes import c_char, c_char_p, c_int, c_ubyte, c_uint, c_ulong
 
 Oid = c_uint
 
@@ -38,6 +39,12 @@ def PQprepare(
     arg4: int,
     arg5: Optional[Array[c_uint]],
 ) -> PGresult_struct: ...
+def PQescapeByteaConn(
+    arg1: Optional[PGconn_struct],
+    arg2: bytes,
+    arg3: int,
+    arg4: pointer[c_ulong],
+) -> pointer[c_ubyte]: ...
 
 # fmt: off
 # autogenerated: start
index ee053fb00c552254ff588ace49ad2a487eb445c8..e9c1d3311215b56773f59ae2e3ea1b097d6a0273 100644 (file)
@@ -8,9 +8,10 @@ implementation.
 
 # Copyright (C) 2020 The Psycopg Team
 
-from ctypes import string_at
-from ctypes import Array, c_char_p, c_int, pointer
+from ctypes import Array, pointer, string_at
+from ctypes import c_char_p, c_int, c_size_t, c_ulong
 from typing import Any, List, Optional, Sequence
+from typing import cast as t_cast
 
 from .enums import (
     ConnStatus,
@@ -361,6 +362,23 @@ class PGconn:
             raise MemoryError("couldn't allocate PGresult")
         return PGresult(rv)
 
+    def escape_bytea(self, data: bytes) -> bytes:
+        len_out = c_size_t()
+        out = impl.PQescapeByteaConn(
+            self.pgconn_ptr,
+            data,
+            len(data),
+            pointer(t_cast(c_ulong, len_out)),
+        )
+        if not out:
+            raise MemoryError(
+                f"couldn't allocate {len(data)} bytes for escape_bytea"
+            )
+
+        rv = string_at(out, len_out.value - 1)  # out includes final 0
+        impl.PQfreemem(out)
+        return rv
+
     def get_result(self) -> Optional["PGresult"]:
         rv = impl.PQgetResult(self.pgconn_ptr)
         return PGresult(rv) if rv else None
index 1fdbd7ce95498774b7791da05e4bdbe36e122db8..36d426817a3b201a4ce96f4b20ec64ce851f8477 100644 (file)
@@ -5,7 +5,7 @@ Adapters of textual types.
 # Copyright (C) 2020 The Psycopg Team
 
 import codecs
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 
 from ..adapt import (
     Adapter,
@@ -32,6 +32,17 @@ class StringAdapter(Adapter):
         return self._encode(obj)[0]
 
 
+@Adapter.text(bytes)
+class BytesAdapter(Adapter):
+    def adapt(self, obj: bytes) -> Tuple[bytes, Oid]:
+        return self.conn.pgconn.escape_bytea(obj), type_oid["bytea"]
+
+
+@Adapter.binary(bytes)
+def adapt_bytes(b: bytes) -> Tuple[bytes, Oid]:
+    return b, type_oid["bytea"]
+
+
 @Typecaster.text(type_oid["text"])
 @Typecaster.binary(type_oid["text"])
 class StringCaster(Typecaster):
index 131347935e4c8e1338c096c6d69bdceb9d251105..4d056c2b664603944aa70e1f4b896ae9cc92a874 100644 (file)
@@ -226,3 +226,19 @@ def test_make_empty_result(pq, pgconn):
     res = pgconn.make_empty_result(pq.ExecStatus.FATAL_ERROR)
     assert res.status == pq.ExecStatus.FATAL_ERROR
     assert b"wat" in res.error_message
+
+
+@pytest.mark.parametrize(
+    "data", [(b"hello\00world"), (b"\00\00\00\00")],
+)
+def test_escape_bytea(pgconn, data):
+    rv = pgconn.escape_bytea(data)
+    exp = br"\x" + b"".join(b"%02x" % c for c in data)
+    assert rv == exp
+
+
+def test_escape_1char(pgconn):
+    for c in range(256):
+        rv = pgconn.escape_bytea(bytes([c]))
+        exp = br"\x%02x" % c
+        assert rv == exp