]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add dumping by oid in the C implementation too
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 22:20:51 +0000 (00:20 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 22:23:47 +0000 (00:23 +0200)
psycopg/psycopg/_transform.py
psycopg/psycopg/abc.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/copy.pyx
psycopg_c/psycopg_c/_psycopg/transform.pyx
tests/test_query.py

index b15511e4b5a60e321262b3da2a1aa6feaa2ad2a7..cd87687b38f48931e1b8eda4a468e6297e32b7cd 100644 (file)
@@ -42,12 +42,11 @@ class Transformer(AdaptContext):
     _adapters: "AdaptersMap"
     _pgresult: Optional["PGresult"] = None
 
-    types: Tuple[int, ...]
-    formats: List[pq.Format]
+    types: Optional[Tuple[int, ...]]
+    formats: Optional[List[pq.Format]]
 
     def __init__(self, context: Optional[AdaptContext] = None):
-        self.types = ()
-        self.formats = []
+        self.types = self.formats = None
 
         # WARNING: don't store context, or you'll create a loop with the Cursor
         if context:
@@ -137,7 +136,8 @@ class Transformer(AdaptContext):
     def dump_sequence(
         self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Sequence[Optional[Buffer]]:
-        out: List[Optional[Buffer]] = [None] * len(params)
+        nparams = len(params)
+        out: List[Optional[Buffer]] = [None] * nparams
 
         change_state = False
 
@@ -152,22 +152,28 @@ class Transformer(AdaptContext):
         # an executemany the first records has a null, the second has a value.
         if not dumpers:
             change_state = True
-            dumpers = [None] * len(params)
-            types = [INVALID_OID] * len(params)
-            pqformats = [pq.Format.TEXT] * len(params)
+            dumpers = [None] * nparams
+            types = [INVALID_OID] * nparams
+            pqformats = [pq.Format.TEXT] * nparams
 
-        for i in range(len(params)):
+        for i in range(nparams):
             param = params[i]
             if param is not None:
                 dumper = dumpers[i]
                 if not dumper:
-                    dumper = dumpers[i] = self.get_dumper(param, formats[i])
                     change_state = True
+                    dumper = dumpers[i] = self.get_dumper(param, formats[i])
+
                     if not types:
-                        types = list(self.types)
+                        types = (
+                            list(self.types)
+                            if self.types
+                            else [INVALID_OID] * nparams
+                        )
                     types[i] = dumper.oid
+
                     if not pqformats:
-                        pqformats = list(self.formats)
+                        pqformats = self.formats or [pq.Format.TEXT] * nparams
                     pqformats[i] = dumper.format
 
                 out[i] = dumper.dump(param)
index d6c164998ecf2a07e9a699fea0f8edd5f040e40d..137acd4bde229745f887c237143c1f475b0bcc88 100644 (file)
@@ -170,8 +170,8 @@ class Loader(Protocol):
 
 class Transformer(Protocol):
 
-    types: Tuple[int, ...]
-    formats: Sequence[pq.Format]
+    types: Optional[Tuple[int, ...]]
+    formats: Optional[List[pq.Format]]
 
     def __init__(self, context: Optional[AdaptContext] = None):
         ...
@@ -215,9 +215,6 @@ class Transformer(Protocol):
     def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
         ...
 
-    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Dumper:
-        ...
-
     def load_rows(
         self, row0: int, row1: int, make_row: "RowMaker[Row]"
     ) -> List["Row"]:
index ebcb6addeda61bf917fb6604f8c280c9f73ce0b1..9a9d316d6fbd271b955107a2da40133c488d0752 100644 (file)
@@ -17,8 +17,8 @@ from psycopg.pq.abc import PGconn, PGresult
 from psycopg.connection import BaseConnection
 
 class Transformer(abc.AdaptContext):
-    types: Tuple[int, ...]
-    formats: Sequence[pq.Format]
+    types: Optional[Tuple[int, ...]]
+    formats: Optional[List[pq.Format]]
     def __init__(self, context: Optional[abc.AdaptContext] = None): ...
     @property
     def connection(self) -> Optional[BaseConnection[Any]]: ...
@@ -43,7 +43,6 @@ class Transformer(abc.AdaptContext):
         self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Sequence[Optional[abc.Buffer]]: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
-    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper: ...
     def load_rows(
         self, row0: int, row1: int, make_row: RowMaker[Row]
     ) -> List[Row]: ...
index a52afc4cb3cab16fdb8439235a0dec7cf180586a..dfd6aefc59210ba2bcbce17ba41a664cb483a2d8 100644 (file)
@@ -49,34 +49,44 @@ def format_row_binary(
     cdef PyObject *fmt = <PyObject *>PG_BINARY
     cdef PyObject *row_dumper
 
-    for i in range(rowlen):
-        item = row[i]
-        if item is not None:
-            row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
-            if (<RowDumper>row_dumper).cdumper is not None:
-                # A cdumper can resize if necessary and copy in place
-                size = (<RowDumper>row_dumper).cdumper.cdump(
-                    item, out, pos + sizeof(besize))
-                # Also add the size of the item, before the item
-                besize = endian.htobe32(<int32_t>size)
-                target = PyByteArray_AS_STRING(out)  # might have been moved by cdump
-                memcpy(target + pos, <void *>&besize, sizeof(besize))
-            else:
-                # A Python dumper, gotta call it and extract its juices
-                b = PyObject_CallFunctionObjArgs(
-                    (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL)
-                _buffer_as_string_and_size(b, &buf, &size)
-                target = CDumper.ensure_size(out, pos, size + sizeof(besize))
-                besize = endian.htobe32(<int32_t>size)
-                memcpy(target, <void *>&besize, sizeof(besize))
-                memcpy(target + sizeof(besize), buf, size)
+    if not tx._row_dumpers:
+        tx._row_dumpers = PyList_New(rowlen)
 
-            pos += size + sizeof(besize)
+    dumpers = tx._row_dumpers
 
-        else:
+    for i in range(rowlen):
+        item = row[i]
+        if item is None:
             target = CDumper.ensure_size(out, pos, sizeof(_binary_null))
             memcpy(target, <void *>&_binary_null, sizeof(_binary_null))
             pos += sizeof(_binary_null)
+            continue
+
+        row_dumper = PyList_GET_ITEM(dumpers, i)
+        if not row_dumper:
+            row_dumper = tx.get_row_dumper(<PyObject *>item, fmt)
+            Py_INCREF(<object>row_dumper)
+            PyList_SET_ITEM(dumpers, i, <object>row_dumper)
+
+        if (<RowDumper>row_dumper).cdumper is not None:
+            # A cdumper can resize if necessary and copy in place
+            size = (<RowDumper>row_dumper).cdumper.cdump(
+                item, out, pos + sizeof(besize))
+            # Also add the size of the item, before the item
+            besize = endian.htobe32(<int32_t>size)
+            target = PyByteArray_AS_STRING(out)  # might have been moved by cdump
+            memcpy(target + pos, <void *>&besize, sizeof(besize))
+        else:
+            # A Python dumper, gotta call it and extract its juices
+            b = PyObject_CallFunctionObjArgs(
+                (<RowDumper>row_dumper).dumpfunc, <PyObject *>item, NULL)
+            _buffer_as_string_and_size(b, &buf, &size)
+            target = CDumper.ensure_size(out, pos, size + sizeof(besize))
+            besize = endian.htobe32(<int32_t>size)
+            memcpy(target, <void *>&besize, sizeof(besize))
+            memcpy(target + sizeof(besize), buf, size)
+
+        pos += size + sizeof(besize)
 
     # Resize to the final size
     PyByteArray_Resize(out, pos)
index ab1fa453c186062e742c370b594995e7eaceb117..976239f896b62d71ccb4edf5db5336c8004223c5 100644 (file)
@@ -72,6 +72,8 @@ cdef class Transformer:
 
     cdef readonly object connection
     cdef readonly object adapters
+    cdef readonly object types
+    cdef readonly object formats
 
     # mapping class -> Dumper instance (auto, text, binary)
     cdef dict _auto_dumpers
@@ -100,6 +102,8 @@ cdef class Transformer:
             self.adapters = postgres.adapters
             self.connection = None
 
+        self.types = self.formats = None
+
     @property
     def pgresult(self) -> Optional[PGresult]:
         return self._pgresult
@@ -122,26 +126,48 @@ cdef class Transformer:
         self._nfields = libpq.PQnfields(res)
         self._ntuples = libpq.PQntuples(res)
 
-        cdef int i
-        cdef object tmp
-        cdef list types
+        if not set_loaders:
+            return
+
+        if not self._nfields:
+            self._row_loaders = []
+            return
+
         if format is None:
             format = libpq.PQfformat(res, 0)
 
-        if set_loaders:
-            types = PyList_New(self._nfields)
-            for i in range(self._nfields):
-                tmp = libpq.PQftype(res, i)
-                Py_INCREF(tmp)
-                PyList_SET_ITEM(types, i, tmp)
+        cdef list loaders = PyList_New(self._nfields)
+        cdef PyObject *row_loader
+        cdef object oid
 
-            self._c_loader_types(self._nfields, types, format)
+        cdef int i
+        for i in range(self._nfields):
+            oid = libpq.PQftype(res, i)
+            row_loader = self._c_get_loader(<PyObject *>oid, <PyObject *>format)
+            Py_INCREF(<object>row_loader)
+            PyList_SET_ITEM(loaders, i, <object>row_loader)
+
+        self._row_loaders = loaders
+
+    def set_dumper_types(self, types: Sequence[int], format: Format) -> None:
+        cdef int ntypes = len(types)
+        dumpers = PyList_New(ntypes)
+        cdef int i
+        for i in range(ntypes):
+            oid = types[i]
+            dumper_ptr = self.get_dumper_by_oid(
+                <PyObject *>oid, <PyObject *>format)
+            Py_INCREF(<object>dumper_ptr)
+            PyList_SET_ITEM(dumpers, i, <object>dumper_ptr)
 
-    def set_loader_types(self,
-            types: Sequence[int], format: Format) -> None:
+        self._row_dumpers = dumpers
+        self.types = tuple(types)
+        self.formats = [format] * ntypes
+
+    def set_loader_types(self, types: Sequence[int], format: Format) -> None:
         self._c_loader_types(len(types), types, format)
 
-    cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, int format):
+    cdef void _c_loader_types(self, Py_ssize_t ntypes, list types, object format):
         cdef list loaders = PyList_New(ntypes)
 
         # these are used more as Python object than C
@@ -228,68 +254,82 @@ cdef class Transformer:
         PyDict_SetItem(<object>cache, key1, row_dumper)
         return <PyObject *>row_dumper
 
-    def get_dumper_by_oid(self, oid, format) -> "Dumper":
+    cdef PyObject *get_dumper_by_oid(self, PyObject *oid, PyObject *fmt) except NULL:
+        """
+        Return a borrowed reference to the RowDumper for the given oid/fmt.
+        """
         cdef PyObject *ptr
         cdef PyObject *cache
         cdef RowDumper row_dumper
 
         # Establish where would the dumper be cached
-        if format == PQ_TEXT:
+        cdef int cfmt = <object>fmt
+        if cfmt == 0:
             if self._oid_text_dumpers is None:
                 self._oid_text_dumpers = {}
             cache = <PyObject *>self._oid_text_dumpers
-        elif format == PQ_BINARY:
+        elif cfmt == 1:
             if self._oid_binary_dumpers is None:
                 self._oid_binary_dumpers = {}
             cache = <PyObject *>self._oid_binary_dumpers
         else:
             raise ValueError(
-                f"format should be a psycopg.pq.Format, not {format}")
+                f"format should be a psycopg.pq.Format, not {<object>fmt}")
 
         # Reuse an existing Dumper class for objects of the same type
-        ptr = PyDict_GetItem(<object>cache, oid)
+        ptr = PyDict_GetItem(<object>cache, <object>oid)
         if ptr == NULL:
             dcls = PyObject_CallFunctionObjArgs(
-                self.adapters.get_dumper_by_oid,
-                <PyObject *>oid, <PyObject *>format, NULL)
+                self.adapters.get_dumper_by_oid, oid, fmt, NULL)
             dumper = PyObject_CallFunctionObjArgs(
                 dcls, <PyObject *>NoneType, <PyObject *>self, NULL)
 
             row_dumper = _as_row_dumper(dumper)
-            PyDict_SetItem(<object>cache, oid, row_dumper)
+            PyDict_SetItem(<object>cache, <object>oid, row_dumper)
             ptr = <PyObject *>row_dumper
 
-        return (<RowDumper>ptr).pydumper
+        return ptr
 
     cpdef dump_sequence(self, object params, object formats):
         # Verify that they are not none and that PyList_GET_ITEM won't blow up
         cdef Py_ssize_t nparams = len(params)
-        cdef list ps = PyList_New(nparams)
-        cdef tuple ts = PyTuple_New(nparams)
-        cdef list fs = PyList_New(nparams)
-        cdef object dumped, oid
-        cdef Py_ssize_t size
-        cdef PyObject *dumper_ptr  # borrowed pointer to row dumper
+        cdef list out = PyList_New(nparams)
 
-        if self._row_dumpers is None:
-            self._row_dumpers = PyList_New(nparams)
+        cdef int change_state = 0
 
         dumpers = self._row_dumpers
+        cdef object types = None
+        cdef object pqformats = None
+
+        if not dumpers:
+            change_state = 1
+            dumpers = PyList_New(nparams)
+            types = [oids.INVALID_OID] * nparams
+            pqformats = [PQ_TEXT] * nparams
 
         cdef int i
+        cdef PyObject *dumper_ptr  # borrowed pointer to row dumper
+        cdef object dumped
+        cdef Py_ssize_t size
         for i in range(nparams):
             param = params[i]
             if param is not None:
                 dumper_ptr = PyList_GET_ITEM(dumpers, i)
                 if dumper_ptr == NULL:
-                    format = formats[i]
+                    change_state = 1
                     dumper_ptr = self.get_row_dumper(
-                        <PyObject *>param, <PyObject *>format)
+                        <PyObject *>param, <PyObject *>formats[i])
                     Py_INCREF(<object>dumper_ptr)
                     PyList_SET_ITEM(dumpers, i, <object>dumper_ptr)
 
-                oid = (<RowDumper>dumper_ptr).oid
-                dfmt = (<RowDumper>dumper_ptr).format
+                    if types is None:
+                        types = list(self.types)
+                    types[i] = (<RowDumper>dumper_ptr).oid
+
+                    if pqformats is None:
+                        pqformats = list(self.formats)
+                    pqformats[i] = (<RowDumper>dumper_ptr).format
+
                 if (<RowDumper>dumper_ptr).cdumper is not None:
                     dumped = PyByteArray_FromStringAndSize("", 0)
                     size = (<RowDumper>dumper_ptr).cdumper.cdump(
@@ -301,17 +341,16 @@ cdef class Transformer:
                         <PyObject *>param, NULL)
             else:
                 dumped = None
-                oid = oids.INVALID_OID
-                dfmt = PQ_TEXT
 
             Py_INCREF(dumped)
-            PyList_SET_ITEM(ps, i, dumped)
-            Py_INCREF(oid)
-            PyTuple_SET_ITEM(ts, i, oid)
-            Py_INCREF(dfmt)
-            PyList_SET_ITEM(fs, i, dfmt)
+            PyList_SET_ITEM(out, i, dumped)
+
+        if change_state:
+            self._row_dumpers = dumpers
+            self.types = tuple(types)
+            self.formats = pqformats
 
-        return ps, ts, fs
+        return out
 
     def load_rows(self, int row0, int row1, object make_row) -> List[Row]:
         if self._pgresult is None:
index 9765ea9d3170a6aff31fe214b03c1c04296ead67..0314dd8cf64f179b03b458d8cd6a66031cfce4a6 100644 (file)
@@ -68,7 +68,7 @@ def test_split_query_bad(input):
 @pytest.mark.parametrize(
     "query, params, want, wformats, wparams",
     [
-        (b"", None, b"", (), ()),
+        (b"", None, b"", None, None),
         (b"", [], b"", [], []),
         (b"%%", [], b"%", [], []),
         (b"select %t", (1,), b"select $1", [pq.Format.TEXT], [b"1"]),