]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added bytearray and memoryview dumpers
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 17 Dec 2020 01:29:09 +0000 (02:29 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 17 Dec 2020 04:57:09 +0000 (05:57 +0100)
Also manage these objects when they rich the libp param passing.

psycopg3/psycopg3/pq/pq_ctypes.py
psycopg3/psycopg3/types/text.py
psycopg3_c/psycopg3_c/pq_cython.pyx
tests/types/test_text.py

index 9cb8bcad7bdd9144f4c7cb77f8bdbb01ff32e2a0..546d9c210895e56c896265ab0f95ed4f21caae7c 100644 (file)
@@ -333,7 +333,16 @@ class PGconn:
         alenghts: Optional[Array[c_int]]
         if param_values:
             nparams = len(param_values)
-            aparams = (c_char_p * nparams)(*param_values)
+            aparams = (c_char_p * nparams)(
+                *(
+                    # convert bytearray/memoryview to bytes
+                    # TODO: avoid copy, at least in the C implementation.
+                    b
+                    if b is None or isinstance(b, bytes)
+                    else bytes(b)  # type: ignore[arg-type]
+                    for b in param_values
+                )
+            )
             alenghts = (c_int * nparams)(
                 *(len(p) if p else 0 for p in param_values)
             )
index df0993b61e5781f1a5b70706d8ca39d66520f82d..bd04dc14b0532e93bcd15aea6f176de89687f2af 100644 (file)
@@ -86,9 +86,7 @@ class UnknownLoader(Loader):
         return data.decode(self.encoding)
 
 
-@Dumper.text(bytes)
-class BytesDumper(Dumper):
-
+class _BinaryDumper(Dumper):
     oid = builtins["bytea"].oid
 
     def __init__(self, src: type, context: AdaptContext = None):
@@ -97,17 +95,36 @@ class BytesDumper(Dumper):
             self.connection.pgconn if self.connection else None
         )
 
+
+@Dumper.text(bytes)
+class BytesDumper(_BinaryDumper):
     def dump(self, obj: bytes) -> bytes:
         return self.esc.escape_bytea(obj)
 
 
+@Dumper.text(bytearray)
+class BytearrayDumper(_BinaryDumper):
+    def dump(self, obj: bytearray) -> bytes:
+        return self.esc.escape_bytea(bytes(obj))
+
+
+@Dumper.text(memoryview)
+class MemoryviewDumper(_BinaryDumper):
+    def dump(self, obj: memoryview) -> bytes:
+        return self.esc.escape_bytea(bytes(obj))
+
+
 @Dumper.binary(bytes)
+@Dumper.binary(bytearray)
+@Dumper.binary(memoryview)
 class BytesBinaryDumper(Dumper):
 
     oid = builtins["bytea"].oid
 
-    def dump(self, b: bytes) -> bytes:
-        return b
+    def dump(
+        self, obj: Union[bytes, bytearray, memoryview]
+    ) -> Union[bytes, bytearray, memoryview]:
+        return obj
 
 
 @Loader.text(builtins["bytea"].oid)
index 3de96b0a624482d6046778188426e1b394805aab..f4c31972028c3d14584d84718c933fe0d6f35207 100644 (file)
@@ -6,7 +6,9 @@ libpq Python wrapper using cython bindings.
 
 from posix.unistd cimport getpid
 from cpython.mem cimport PyMem_Malloc, PyMem_Free
-from cpython.bytes cimport PyBytes_AsString
+from cpython.bytes cimport PyBytes_AsString, PyBytes_AsStringAndSize
+from cpython.buffer cimport PyObject_CheckBuffer, PyBUF_SIMPLE
+from cpython.buffer cimport PyObject_GetBuffer, PyBuffer_Release
 
 import logging
 from typing import List, Optional, Sequence, Tuple
@@ -510,9 +512,9 @@ cdef PGconn _connect_start(const char *conninfo):
 
 
 cdef (int, Oid *, char * const*, int *, int *) _query_params_args(
-    param_values: Optional[Sequence[Optional[bytes]]],
-    param_types: Optional[Sequence[int]],
-    param_formats: Optional[Sequence[Format]],
+    list param_values: Optional[Sequence[Optional[bytes]]],
+    list param_types: Optional[Sequence[int]],
+    list param_formats: Optional[Sequence[Format]],
 ) except *:
     cdef int i
 
@@ -530,16 +532,29 @@ cdef (int, Oid *, char * const*, int *, int *) _query_params_args(
 
     cdef char **aparams = NULL
     cdef int *alenghts = NULL
+    cdef char *ptr
+    cdef Py_ssize_t length
+    cdef Py_buffer buf
+
     if nparams:
         aparams = <char **>PyMem_Malloc(nparams * sizeof(char *))
         alenghts = <int *>PyMem_Malloc(nparams * sizeof(int))
         for i in range(nparams):
-            if param_values[i] is not None:
-                aparams[i] = param_values[i]
-                alenghts[i] = len(param_values[i])
-            else:
+            obj = param_values[i]
+            if obj is None:
                 aparams[i] = NULL
                 alenghts[i] = 0
+            elif isinstance(obj, bytes):
+                PyBytes_AsStringAndSize(obj, &ptr, &length)
+                aparams[i] = ptr
+                alenghts[i] = length
+            elif PyObject_CheckBuffer(obj):
+                PyObject_GetBuffer(obj, &buf, PyBUF_SIMPLE)
+                aparams[i] = <char *>buf.buf
+                alenghts[i] = buf.len
+                PyBuffer_Release(&buf)
+            else:
+                raise TypeError(f"bytes or buffer expected, got {type(obj)}")
 
     cdef Oid *atypes = NULL
     if param_types is not None:
index 268b407a83b964d955c01868c9a1e84b7d92d3ab..fd35fb889d53b760fb1c514a54f894c9f88c3caf 100644 (file)
@@ -171,11 +171,13 @@ def test_text_array_ascii(conn, fmt_in, fmt_out):
 
 
 @pytest.mark.parametrize("fmt_in", [Format.TEXT, Format.BINARY])
-def test_dump_1byte(conn, fmt_in):
+@pytest.mark.parametrize("pytype", [bytes, bytearray, memoryview])
+def test_dump_1byte(conn, fmt_in, pytype):
     cur = conn.cursor()
     ph = "%s" if fmt_in == Format.TEXT else "%b"
     for i in range(0, 256):
-        cur.execute(f"select {ph} = %s::bytea", (bytes([i]), fr"\x{i:02x}"))
+        obj = pytype(bytes([i]))
+        cur.execute(f"select {ph} = %s::bytea", (obj, fr"\x{i:02x}"))
         assert cur.fetchone()[0] is True, i