]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add binary composite dumper
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Aug 2021 21:39:33 +0000 (23:39 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 26 Aug 2021 21:39:33 +0000 (23:39 +0200)
psycopg/psycopg/types/composite.py
tests/types/test_composite.py

index ac223c7573b0de95e17c883dc53847a18dff5c98..0eedd24321f7f7b33aef6c67e895044ce7b62aa8 100644 (file)
@@ -11,14 +11,16 @@ from typing import Any, Callable, cast, Iterator, List, Optional
 from typing import Sequence, Tuple, Type
 
 from .. import pq
+from .. import errors as e
 from .. import postgres
 from ..abc import AdaptContext, Buffer
 from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
-from .._struct import unpack_len
+from .._struct import pack_len, unpack_len
 from ..postgres import TEXT_OID
 from .._typeinfo import CompositeInfo as CompositeInfo  # exported here
 
 _struct_oidlen = struct.Struct("!Ii")
+_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack)
 _unpack_oidlen = cast(
     Callable[[bytes, int], Tuple[int, int]], _struct_oidlen.unpack_from
 )
@@ -68,6 +70,36 @@ class TupleDumper(SequenceDumper):
         return self._dump_sequence(obj, b"(", b")", b",")
 
 
+class TupleBinaryDumper(RecursiveDumper):
+
+    format = pq.Format.BINARY
+
+    # Subclasses must set an info
+    info: CompositeInfo
+
+    def dump(self, obj: Tuple[Any, ...]) -> bytearray:
+
+        if len(obj) != len(self.info.field_types):
+            raise e.DataError(
+                f"expected a sequence of {len(self.info.field_types)} items,"
+                f" got {len(obj)}"
+            )
+
+        out = bytearray(pack_len(len(obj)))
+        get_dumper = self._tx.get_dumper
+        for i in range(len(obj)):
+            item = obj[i]
+            if item is not None:
+                dumper = get_dumper(item, PyFormat.BINARY)
+                b = dumper.dump(item)
+                out += _pack_oidlen(dumper.oid, len(b))
+                out += b
+            else:
+                out += _pack_oidlen(self.info.field_types[i], -1)
+
+        return out
+
+
 class BaseCompositeLoader(RecursiveLoader):
 
     format = pq.Format.TEXT
@@ -216,12 +248,21 @@ def register_composite(
     )
     adapters.register_loader(info.oid, loader)
 
-    # If the factory is a type, register a dumper for it
+    # If the factory is a type, create and register dumpers for it
     if isinstance(factory, type):
+        dumper = type(
+            f"{info.name.title()}BinaryDumper",
+            (TupleBinaryDumper,),
+            {"_oid": info.oid, "info": info},
+        )
+        adapters.register_dumper(factory, dumper)
+
+        # Default to the text dumper because it is more flexible
         dumper = type(
             f"{info.name.title()}Dumper", (TupleDumper,), {"_oid": info.oid}
         )
         adapters.register_dumper(factory, dumper)
+
         info.python_type = factory
 
 
index 5d3504f0eb38f1a7309986517ae0c0a617cdfb8e..172b45dd841bbf9a3ae644ef5fea36ded9994b5d 100644 (file)
@@ -5,7 +5,8 @@ from psycopg.sql import Identifier
 from psycopg.adapt import PyFormat as Format
 from psycopg.postgres import types as builtins
 from psycopg.types.composite import CompositeInfo, register_composite
-from psycopg.types.composite import TupleDumper
+from psycopg.types.composite import TupleDumper, TupleBinaryDumper
+from psycopg.types.numeric import Int8, Float8
 
 tests_str = [
     ("", ()),
@@ -107,6 +108,7 @@ def testcomp(svcconn):
         create type testschema.testcomp as (foo text, bar int8, qux bool);
         """
     )
+    return CompositeInfo.fetch(svcconn, "testcomp")
 
 
 fetch_cases = [
@@ -156,10 +158,8 @@ async def test_fetch_info_async(aconn, testcomp, name, fields):
         assert info.field_types[i] == builtins[t].oid
 
 
-@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
-def test_dump_composite_all_chars(conn, fmt_in, testcomp):
-    if fmt_in == Format.BINARY:
-        pytest.xfail("binary composite dumper not implemented")
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT])
+def test_dump_tuple_all_chars(conn, fmt_in, testcomp):
     cur = conn.cursor()
     for i in range(1, 256):
         (res,) = cur.execute(
@@ -169,6 +169,40 @@ def test_dump_composite_all_chars(conn, fmt_in, testcomp):
         assert res is True
 
 
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_composite_all_chars(conn, fmt_in, testcomp):
+    cur = conn.cursor()
+    register_composite(testcomp, cur)
+    factory = testcomp.python_type
+    for i in range(1, 256):
+        if fmt_in == Format.BINARY:
+            obj = factory(chr(i), Int8(1), Float8(1.0))
+        else:
+            obj = factory(chr(i), 1, 1.0)
+
+        (res,) = cur.execute(
+            f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in}", (i, obj)
+        ).fetchone()
+        assert res is True
+
+
+@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
+def test_dump_composite_null(conn, fmt_in, testcomp):
+    cur = conn.cursor()
+    register_composite(testcomp, cur)
+    factory = testcomp.python_type
+
+    if fmt_in == Format.BINARY:
+        obj = factory("foo", Int8(1), None)
+    else:
+        obj = factory("foo", 1, None)
+
+    (res,) = cur.execute(
+        f"select row('foo', 1, NULL)::testcomp = %{fmt_in}", (obj,)
+    ).fetchone()
+    assert res is True
+
+
 @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY])
 def test_load_composite(conn, testcomp, fmt_out):
     info = CompositeInfo.fetch(conn, "testcomp")
@@ -221,11 +255,9 @@ def test_register_scope(conn, testcomp):
         for oid in (info.oid, info.array_oid):
             assert postgres.adapters._loaders[fmt].pop(oid)
 
-    for fmt in (Format.AUTO, Format.TEXT):
+    for fmt in Format:
         assert postgres.adapters._dumpers[fmt].pop(info.python_type)
 
-    assert info.python_type not in postgres.adapters._dumpers[Format.BINARY]
-
     cur = conn.cursor()
     register_composite(info, cur)
     for fmt in (pq.Format.TEXT, pq.Format.BINARY):
@@ -255,6 +287,20 @@ def test_type_dumper_registered(conn, testcomp):
     assert cur.fetchone()[0] == "testcomp"
 
 
+def test_type_dumper_registered_binary(conn, testcomp):
+    info = CompositeInfo.fetch(conn, "testcomp")
+    register_composite(info, conn)
+    assert issubclass(info.python_type, tuple)
+    assert info.python_type.__name__ == "testcomp"
+    d = conn.adapters.get_dumper(info.python_type, "b")
+    assert issubclass(d, TupleBinaryDumper)
+    assert d is not TupleBinaryDumper
+
+    tc = info.python_type("foo", Int8(42), Float8(3.14))
+    cur = conn.execute("select pg_typeof(%b)", [tc])
+    assert cur.fetchone()[0] == "testcomp"
+
+
 def test_callable_dumper_not_registered(conn, testcomp):
     info = CompositeInfo.fetch(conn, "testcomp")