]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use the composite type info to choose dumpers, instead of the Python type
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 01:27:59 +0000 (03:27 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 27 Aug 2021 22:23:47 +0000 (00:23 +0200)
This allows to dump composites in binary format without the need to use
object wrappers to specify e.g. the type of integer.

psycopg/psycopg/_adapters_map.py
psycopg/psycopg/_transform.py
psycopg/psycopg/abc.py
psycopg/psycopg/types/composite.py
psycopg_c/psycopg_c/_psycopg.pyi
psycopg_c/psycopg_c/_psycopg/transform.pyx
tests/types/test_composite.py

index d0ddb3cafc68dd8dbf2ca5aba3b3f7c3c6edbf2e..4e301bec8b40b7662bedce14a9c2990efe78f11b 100644 (file)
@@ -208,6 +208,33 @@ class AdaptersMap:
             f" to format {PyFormat(format).name}"
         )
 
+    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Type["Dumper"]:
+        """
+        Return the dumper class for the given oid and format.
+
+        Raise ProgrammingError if a class is not available.
+        """
+        try:
+            dmap = self._dumpers_by_oid[format]
+        except KeyError:
+            raise ValueError(f"bad dumper format: {format}")
+
+        try:
+            return dmap[oid]
+        except KeyError:
+            info = self.types.get(oid)
+            if info:
+                msg = (
+                    f"cannot find a dumper for type {info.name} (oid {oid})"
+                    f" format {pq.Format(format).name}"
+                )
+            else:
+                msg = (
+                    f"cannot find a dumper for unknown type with oid {oid}"
+                    f" format {pq.Format(format).name}"
+                )
+            raise e.ProgrammingError(msg)
+
     def get_loader(
         self, oid: int, format: pq.Format
     ) -> Optional[Type["Loader"]]:
index e51ca083bc1fa2b77bb88a66910e879afbe14b39..c4dbbaf52343745fce2c21b2b78327bb50018934 100644 (file)
@@ -5,7 +5,7 @@ Helper object to transform values between Python and PostgreSQL
 # Copyright (C) 2020-2021 The Psycopg Team
 
 from typing import Any, Dict, List, Optional, Sequence, Tuple
-from typing import DefaultDict, TYPE_CHECKING
+from typing import DefaultDict, Type, TYPE_CHECKING
 from collections import defaultdict
 
 from . import pq
@@ -21,7 +21,9 @@ if TYPE_CHECKING:
     from .pq.abc import PGresult
     from .connection import BaseConnection
 
+NoneType: Type[None] = type(None)
 DumperCache = Dict[DumperKey, "Dumper"]
+OidDumperCache = Dict[int, "Dumper"]
 LoaderCache = Dict[int, "Loader"]
 
 
@@ -50,13 +52,17 @@ class Transformer(AdaptContext):
             self._adapters = postgres.adapters
             self._conn = None
 
-        # mapping class, fmt -> Dumper instance
-        self._dumpers_cache: DefaultDict[PyFormat, DumperCache] = defaultdict(
-            dict
-        )
+        # mapping fmt, class -> Dumper instance
+        self._dumpers: DefaultDict[PyFormat, DumperCache]
+        self._dumpers = defaultdict(dict)
+
+        # mapping fmt, oid -> Dumper instance
+        # Not often used, so create it only if needed.
+        self._oid_dumpers: Optional[Tuple[OidDumperCache, OidDumperCache]]
+        self._oid_dumpers = None
 
-        # mapping oid, fmt -> Loader instance
-        self._loaders_cache: Tuple[LoaderCache, LoaderCache] = ({}, {})
+        # mapping fmt, oid -> Loader instance
+        self._loaders: Tuple[LoaderCache, LoaderCache] = ({}, {})
 
         self._row_dumpers: List[Optional["Dumper"]] = []
 
@@ -143,7 +149,7 @@ class Transformer(AdaptContext):
         key = type(obj)
 
         # Reuse an existing Dumper class for objects of the same type
-        cache = self._dumpers_cache[format]
+        cache = self._dumpers[format]
         try:
             dumper = cache[key]
         except KeyError:
@@ -164,6 +170,25 @@ class Transformer(AdaptContext):
             dumper = cache[key1] = dumper.upgrade(obj, format)
             return dumper
 
+    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> "Dumper":
+        """
+        Return a Dumper to dump an object to the type with given oid.
+        """
+        if not self._oid_dumpers:
+            self._oid_dumpers = ({}, {})
+
+        # Reuse an existing Dumper class for objects of the same type
+        cache = self._oid_dumpers[format]
+        try:
+            return cache[oid]
+        except KeyError:
+            # If it's the first time we see this type, look for a dumper
+            # configured for it.
+            dcls = self.adapters.get_dumper_by_oid(oid, format)
+            cache[oid] = dumper = dcls(NoneType, self)
+
+        return dumper
+
     def load_rows(
         self, row0: int, row1: int, make_row: RowMaker[Row]
     ) -> List[Row]:
@@ -219,7 +244,7 @@ class Transformer(AdaptContext):
 
     def get_loader(self, oid: int, format: pq.Format) -> "Loader":
         try:
-            return self._loaders_cache[format][oid]
+            return self._loaders[format][oid]
         except KeyError:
             pass
 
@@ -228,5 +253,5 @@ class Transformer(AdaptContext):
             loader_cls = self._adapters.get_loader(INVALID_OID, format)
             if not loader_cls:
                 raise e.InterfaceError("unknown oid loader not found")
-        loader = self._loaders_cache[format][oid] = loader_cls(oid, self)
+        loader = self._loaders[format][oid] = loader_cls(oid, self)
         return loader
index cc7d72572234c4c4f6e026799e925d10299bfb49..584b2ae037d3dd918c227ce0e1ff5c76e6e0170f 100644 (file)
@@ -206,6 +206,9 @@ class Transformer(Protocol):
     def get_dumper(self, obj: Any, format: PyFormat) -> Dumper:
         ...
 
+    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> Dumper:
+        ...
+
     def load_rows(
         self, row0: int, row1: int, make_row: "RowMaker[Row]"
     ) -> List["Row"]:
index 5e77d1c4f8488ee9007c3f2b362eca1a2649dce6..950142d76471498c9324b091d9b38bb56e59cf1c 100644 (file)
@@ -86,16 +86,17 @@ class TupleBinaryDumper(RecursiveDumper):
             )
 
         out = bytearray(pack_len(len(obj)))
-        get_dumper = self._tx.get_dumper
+        get_dumper = self._tx.get_dumper_by_oid
         for i in range(len(obj)):
             item = obj[i]
+            oid = self.info.field_types[i]
             if item is not None:
-                dumper = get_dumper(item, PyFormat.BINARY)
+                dumper = get_dumper(oid, self.format)
                 b = dumper.dump(item)
-                out += _pack_oidlen(dumper.oid, len(b))
+                out += _pack_oidlen(oid, len(b))
                 out += b
             else:
-                out += _pack_oidlen(self.info.field_types[i], -1)
+                out += _pack_oidlen(oid, -1)
 
         return out
 
index 04fe7b990c0f4d5284ef471642b7e431fb16117e..5c9fcef6853cc26ad1acbe7e09149efdc45807f7 100644 (file)
@@ -38,6 +38,7 @@ class Transformer(abc.AdaptContext):
         self, params: Sequence[Any], formats: Sequence[PyFormat]
     ) -> Tuple[List[Any], Tuple[int, ...], Sequence[pq.Format]]: ...
     def get_dumper(self, obj: Any, format: PyFormat) -> abc.Dumper: ...
+    def get_dumper_by_oid(self, oid: int, format: pq.Format) -> abc.Dumper: ...
     def load_rows(
         self, row0: int, row1: int, make_row: RowMaker[Row]
     ) -> List[Row]: ...
index 3029bbc6ff8ad89d605fc2266eb673e03afabdd9..a68886fc96cb5cd3fb5c7ebe3d267b904bc4771e 100644 (file)
@@ -25,6 +25,8 @@ from psycopg import errors as e
 from psycopg.pq import Format as PqFormat
 from psycopg.rows import Row, RowMaker
 
+NoneType = type(None)
+
 # internal structure: you are not supposed to know this. But it's worth some
 # 10% of the innermost loop, so I'm willing to ask for forgiveness later...
 
@@ -80,6 +82,10 @@ cdef class Transformer:
     cdef dict _text_loaders
     cdef dict _binary_loaders
 
+    # mapping oid -> Dumper instance (text, binary)
+    cdef dict _oid_text_dumpers
+    cdef dict _oid_binary_dumpers
+
     cdef pq.PGresult _pgresult
     cdef int _nfields, _ntuples
     cdef list _row_dumpers
@@ -227,6 +233,39 @@ cdef class Transformer:
         PyDict_SetItem(<object>cache, key1, row_dumper)
         return <PyObject *>row_dumper
 
+    def get_dumper_by_oid(self, oid, format) -> "Dumper":
+        cdef PyObject *ptr
+        cdef PyObject *cache
+        cdef RowDumper row_dumper
+
+        # Establish where would the dumper be cached
+        if format == PQ_TEXT:
+            if self._oid_text_dumpers is None:
+                self._oid_text_dumpers = {}
+            cache = <PyObject *>self._oid_text_dumpers
+        elif format == PQ_BINARY:
+            if self._oid_binary_dumpers is None:
+                self._oid_binary_dumpers = {}
+            cache = <PyObject *>self._oid_binary_dumpers
+        else:
+            raise ValueError(
+                f"format should be a psycopg.pq.Format, not {format}")
+
+        # Reuse an existing Dumper class for objects of the same type
+        ptr = PyDict_GetItem(<object>cache, oid)
+        if ptr == NULL:
+            dcls = PyObject_CallFunctionObjArgs(
+                self.adapters.get_dumper_by_oid,
+                <PyObject *>oid, <PyObject *>format, NULL)
+            dumper = PyObject_CallFunctionObjArgs(
+                dcls, <PyObject *>NoneType, <PyObject *>self, NULL)
+
+            row_dumper = _as_row_dumper(dumper)
+            PyDict_SetItem(<object>cache, oid, row_dumper)
+            ptr = <PyObject *>row_dumper
+
+        return (<RowDumper>ptr).pydumper
+
     cpdef dump_sequence(self, object params, object formats):
         # Verify that they are not none and that PyList_GET_ITEM won't blow up
         cdef Py_ssize_t nparams = len(params)
@@ -432,7 +471,6 @@ cdef class Transformer:
             <PyObject *>oid, <PyObject *>format)
         return (<RowLoader>row_loader).pyloader
 
-
     cdef PyObject *_c_get_loader(self, PyObject *oid, PyObject *fmt) except NULL:
         """
         Return a borrowed reference to the RowLoader instance for given oid/fmt
index 172b45dd841bbf9a3ae644ef5fea36ded9994b5d..8b0e97d5c0de7c1b296cf04040a8b34ee2491db7 100644 (file)
@@ -6,7 +6,6 @@ from psycopg.adapt import PyFormat as Format
 from psycopg.postgres import types as builtins
 from psycopg.types.composite import CompositeInfo, register_composite
 from psycopg.types.composite import TupleDumper, TupleBinaryDumper
-from psycopg.types.numeric import Int8, Float8
 
 tests_str = [
     ("", ()),
@@ -175,11 +174,7 @@ def test_dump_composite_all_chars(conn, fmt_in, testcomp):
     register_composite(testcomp, cur)
     factory = testcomp.python_type
     for i in range(1, 256):
-        if fmt_in == Format.BINARY:
-            obj = factory(chr(i), Int8(1), Float8(1.0))
-        else:
-            obj = factory(chr(i), 1, 1.0)
-
+        obj = factory(chr(i), 1, 1.0)
         (res,) = cur.execute(
             f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in}", (i, obj)
         ).fetchone()
@@ -192,15 +187,14 @@ def test_dump_composite_null(conn, fmt_in, testcomp):
     register_composite(testcomp, cur)
     factory = testcomp.python_type
 
-    if fmt_in == Format.BINARY:
-        obj = factory("foo", Int8(1), None)
-    else:
-        obj = factory("foo", 1, None)
-
-    (res,) = cur.execute(
-        f"select row('foo', 1, NULL)::testcomp = %{fmt_in}", (obj,)
+    obj = factory("foo", 1, None)
+    rec = cur.execute(
+        f"""
+        select row('foo', 1, NULL)::testcomp = %(obj){fmt_in}, %(obj){fmt_in}::text
+        """,
+        {"obj": obj},
     ).fetchone()
-    assert res is True
+    assert rec[0] is True, rec[1]
 
 
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
@@ -296,7 +290,7 @@ def test_type_dumper_registered_binary(conn, testcomp):
     assert issubclass(d, TupleBinaryDumper)
     assert d is not TupleBinaryDumper
 
-    tc = info.python_type("foo", Int8(42), Float8(3.14))
+    tc = info.python_type("foo", 42, 3.14)
     cur = conn.execute("select pg_typeof(%b)", [tc])
     assert cur.fetchone()[0] == "testcomp"