From: Daniele Varrazzo Date: Thu, 26 Aug 2021 21:39:33 +0000 (+0200) Subject: Add binary composite dumper X-Git-Tag: 3.0.beta1~30^2 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=03b0529903d7b24b0f90072c61b594ea2c09fbc4;p=thirdparty%2Fpsycopg.git Add binary composite dumper --- diff --git a/psycopg/psycopg/types/composite.py b/psycopg/psycopg/types/composite.py index ac223c757..0eedd2432 100644 --- a/psycopg/psycopg/types/composite.py +++ b/psycopg/psycopg/types/composite.py @@ -11,14 +11,16 @@ from typing import Any, Callable, cast, Iterator, List, Optional from typing import Sequence, Tuple, Type from .. import pq +from .. import errors as e from .. import postgres from ..abc import AdaptContext, Buffer from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader -from .._struct import unpack_len +from .._struct import pack_len, unpack_len from ..postgres import TEXT_OID from .._typeinfo import CompositeInfo as CompositeInfo # exported here _struct_oidlen = struct.Struct("!Ii") +_pack_oidlen = cast(Callable[[int, int], bytes], _struct_oidlen.pack) _unpack_oidlen = cast( Callable[[bytes, int], Tuple[int, int]], _struct_oidlen.unpack_from ) @@ -68,6 +70,36 @@ class TupleDumper(SequenceDumper): return self._dump_sequence(obj, b"(", b")", b",") +class TupleBinaryDumper(RecursiveDumper): + + format = pq.Format.BINARY + + # Subclasses must set an info + info: CompositeInfo + + def dump(self, obj: Tuple[Any, ...]) -> bytearray: + + if len(obj) != len(self.info.field_types): + raise e.DataError( + f"expected a sequence of {len(self.info.field_types)} items," + f" got {len(obj)}" + ) + + out = bytearray(pack_len(len(obj))) + get_dumper = self._tx.get_dumper + for i in range(len(obj)): + item = obj[i] + if item is not None: + dumper = get_dumper(item, PyFormat.BINARY) + b = dumper.dump(item) + out += _pack_oidlen(dumper.oid, len(b)) + out += b + else: + out += _pack_oidlen(self.info.field_types[i], -1) + + return out + + class BaseCompositeLoader(RecursiveLoader): format = pq.Format.TEXT @@ -216,12 +248,21 @@ def register_composite( ) adapters.register_loader(info.oid, loader) - # If the factory is a type, register a dumper for it + # If the factory is a type, create and register dumpers for it if isinstance(factory, type): + dumper = type( + f"{info.name.title()}BinaryDumper", + (TupleBinaryDumper,), + {"_oid": info.oid, "info": info}, + ) + adapters.register_dumper(factory, dumper) + + # Default to the text dumper because it is more flexible dumper = type( f"{info.name.title()}Dumper", (TupleDumper,), {"_oid": info.oid} ) adapters.register_dumper(factory, dumper) + info.python_type = factory diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 5d3504f0e..172b45dd8 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -5,7 +5,8 @@ from psycopg.sql import Identifier from psycopg.adapt import PyFormat as Format from psycopg.postgres import types as builtins from psycopg.types.composite import CompositeInfo, register_composite -from psycopg.types.composite import TupleDumper +from psycopg.types.composite import TupleDumper, TupleBinaryDumper +from psycopg.types.numeric import Int8, Float8 tests_str = [ ("", ()), @@ -107,6 +108,7 @@ def testcomp(svcconn): create type testschema.testcomp as (foo text, bar int8, qux bool); """ ) + return CompositeInfo.fetch(svcconn, "testcomp") fetch_cases = [ @@ -156,10 +158,8 @@ async def test_fetch_info_async(aconn, testcomp, name, fields): assert info.field_types[i] == builtins[t].oid -@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) -def test_dump_composite_all_chars(conn, fmt_in, testcomp): - if fmt_in == Format.BINARY: - pytest.xfail("binary composite dumper not implemented") +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT]) +def test_dump_tuple_all_chars(conn, fmt_in, testcomp): cur = conn.cursor() for i in range(1, 256): (res,) = cur.execute( @@ -169,6 +169,40 @@ def test_dump_composite_all_chars(conn, fmt_in, testcomp): assert res is True +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_composite_all_chars(conn, fmt_in, testcomp): + cur = conn.cursor() + register_composite(testcomp, cur) + factory = testcomp.python_type + for i in range(1, 256): + if fmt_in == Format.BINARY: + obj = factory(chr(i), Int8(1), Float8(1.0)) + else: + obj = factory(chr(i), 1, 1.0) + + (res,) = cur.execute( + f"select row(chr(%s::int), 1, 1.0)::testcomp = %{fmt_in}", (i, obj) + ).fetchone() + assert res is True + + +@pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) +def test_dump_composite_null(conn, fmt_in, testcomp): + cur = conn.cursor() + register_composite(testcomp, cur) + factory = testcomp.python_type + + if fmt_in == Format.BINARY: + obj = factory("foo", Int8(1), None) + else: + obj = factory("foo", 1, None) + + (res,) = cur.execute( + f"select row('foo', 1, NULL)::testcomp = %{fmt_in}", (obj,) + ).fetchone() + assert res is True + + @pytest.mark.parametrize("fmt_out", [pq.Format.TEXT, pq.Format.BINARY]) def test_load_composite(conn, testcomp, fmt_out): info = CompositeInfo.fetch(conn, "testcomp") @@ -221,11 +255,9 @@ def test_register_scope(conn, testcomp): for oid in (info.oid, info.array_oid): assert postgres.adapters._loaders[fmt].pop(oid) - for fmt in (Format.AUTO, Format.TEXT): + for fmt in Format: assert postgres.adapters._dumpers[fmt].pop(info.python_type) - assert info.python_type not in postgres.adapters._dumpers[Format.BINARY] - cur = conn.cursor() register_composite(info, cur) for fmt in (pq.Format.TEXT, pq.Format.BINARY): @@ -255,6 +287,20 @@ def test_type_dumper_registered(conn, testcomp): assert cur.fetchone()[0] == "testcomp" +def test_type_dumper_registered_binary(conn, testcomp): + info = CompositeInfo.fetch(conn, "testcomp") + register_composite(info, conn) + assert issubclass(info.python_type, tuple) + assert info.python_type.__name__ == "testcomp" + d = conn.adapters.get_dumper(info.python_type, "b") + assert issubclass(d, TupleBinaryDumper) + assert d is not TupleBinaryDumper + + tc = info.python_type("foo", Int8(42), Float8(3.14)) + cur = conn.execute("select pg_typeof(%b)", [tc]) + assert cur.fetchone()[0] == "testcomp" + + def test_callable_dumper_not_registered(conn, testcomp): info = CompositeInfo.fetch(conn, "testcomp")