From: Kamil Monicz Date: Fri, 21 Mar 2025 16:47:02 +0000 (+0000) Subject: feat: implement binary hstore protocol X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=refs%2Fpull%2F1031%2Fhead;p=thirdparty%2Fpsycopg.git feat: implement binary hstore protocol --- diff --git a/docs/news.rst b/docs/news.rst index 9f0321848..06f8864ad 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -20,6 +20,7 @@ Psycopg 3.2.7 (unreleased) ^^^^^^^^^^^^^^^^^^^^^^^^^^ - Add SRID support to shapely dumpers/loaders (:ticket:`#1028`). +- Add support for binary hstore (:ticket:`#1030`). Current release diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py index ea4a6e613..6ed323fbd 100644 --- a/psycopg/psycopg/types/hstore.py +++ b/psycopg/psycopg/types/hstore.py @@ -7,15 +7,18 @@ dict to hstore adaptation from __future__ import annotations import re +from struct import Struct from functools import cache from .. import errors as e from .. import postgres +from ..pq import Format from ..abc import AdaptContext, Buffer from .._oids import TEXT_OID -from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader +from ..adapt import Loader, PyFormat, RecursiveDumper, RecursiveLoader from .._compat import TypeAlias from .._typeinfo import TypeInfo +from .._encodings import conn_encoding _re_escape = re.compile(r'(["\\])') _re_unescape = re.compile(r"\\(.)") @@ -36,6 +39,12 @@ _re_hstore = re.compile( re.VERBOSE, ) +_U32_STRUCT = Struct("!I") +"""Simple struct representing an unsigned 32-bit big-endian integer.""" + +_I2B = [i.to_bytes(4, "big") for i in range(64)] +"""Lookup list for small ints to bytes conversions.""" + Hstore: TypeAlias = "dict[str, str | None]" @@ -74,6 +83,43 @@ class BaseHstoreDumper(RecursiveDumper): return dumper.dump(data) +class BaseHstoreBinaryDumper(RecursiveDumper): + format = Format.BINARY + + def __init__(self, cls: type, context: AdaptContext | None = None): + super().__init__(cls, context) + enc = conn_encoding(self.connection) + self.encoding = enc if enc != "ascii" else "utf-8" + + def dump(self, obj: Hstore) -> Buffer: + if not obj: + hstore_empty = b"\x00\x00\x00\x00" + return hstore_empty + + hstore_null_marker = b"\xff\xff\xff\xff" + i2b = _I2B + encoding = self.encoding + buffer: list[bytes] = [i2b[i] if (i := len(obj)) < 64 else i.to_bytes(4, "big")] + + for key, value in obj.items(): + key_bytes = key.encode(encoding) + buffer.append( + i2b[i] if (i := len(key_bytes)) < 64 else i.to_bytes(4, "big") + ) + buffer.append(key_bytes) + + if value is None: + buffer.append(hstore_null_marker) + else: + value_bytes = value.encode(encoding) + buffer.append( + i2b[i] if (i := len(value_bytes)) < 64 else i.to_bytes(4, "big") + ) + buffer.append(value_bytes) + + return b"".join(buffer) + + class HstoreLoader(RecursiveLoader): def load(self, data: Buffer) -> Hstore: loader = self._tx.get_loader(TEXT_OID, self.format) @@ -97,6 +143,48 @@ class HstoreLoader(RecursiveLoader): return rv +class HstoreBinaryLoader(Loader): + format = Format.BINARY + + def __init__(self, oid: int, context: AdaptContext | None = None): + super().__init__(oid, context) + enc = conn_encoding(self.connection) + self.encoding = enc if enc != "ascii" else "utf-8" + + def load(self, data: Buffer) -> Hstore: + if len(data) < 12: # Fast-path if too small to contain any data. + return {} + + hstore_null_marker = 0xFFFFFFFF + unpack_from = _U32_STRUCT.unpack_from + encoding = self.encoding + result = {} + + view = bytes(data) + (size,) = unpack_from(view) + pos = 4 + + for _ in range(size): + (key_size,) = unpack_from(view, pos) + pos += 4 + + key = view[pos : pos + key_size].decode(encoding) + pos += key_size + + (value_size,) = unpack_from(view, pos) + pos += 4 + + if value_size == hstore_null_marker: + value = None + else: + value = view[pos : pos + value_size].decode(encoding) + pos += value_size + + result[key] = value + + return result + + def register_hstore(info: TypeInfo, context: AdaptContext | None = None) -> None: """Register the adapters to load and dump hstore. @@ -121,11 +209,13 @@ def register_hstore(info: TypeInfo, context: AdaptContext | None = None) -> None adapters = context.adapters if context else postgres.adapters - # Generate and register a customized text dumper + # Generate and register customized dumpers adapters.register_dumper(dict, _make_hstore_dumper(info.oid)) + adapters.register_dumper(dict, _make_hstore_binary_dumper(info.oid)) - # register the text loader on the oid + # Register the loaders on the oid adapters.register_loader(info.oid, HstoreLoader) + adapters.register_loader(info.oid, HstoreBinaryLoader) # Cache all dynamically-generated types to avoid leaks in case the types @@ -144,3 +234,11 @@ def _make_hstore_dumper(oid_in: int) -> type[BaseHstoreDumper]: oid = oid_in return HstoreDumper + + +@cache +def _make_hstore_binary_dumper(oid_in: int) -> type[BaseHstoreBinaryDumper]: + class HstoreBinaryDumper(BaseHstoreBinaryDumper): + oid = oid_in + + return HstoreBinaryDumper diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py index 1648e2637..954687e4f 100644 --- a/tests/types/test_hstore.py +++ b/tests/types/test_hstore.py @@ -1,8 +1,10 @@ import pytest import psycopg +from psycopg.pq import Format from psycopg.types import TypeInfo -from psycopg.types.hstore import HstoreLoader, register_hstore +from psycopg.types.hstore import HstoreBinaryLoader, HstoreLoader +from psycopg.types.hstore import _make_hstore_binary_dumper, register_hstore pytestmark = pytest.mark.crdb_skip("hstore") @@ -29,6 +31,41 @@ def test_parse_ok(s, d): assert loader.load(s.encode()) == d +@pytest.mark.parametrize( + "d, b", + [ + ({}, b"\x00\x00\x00\x00"), + ( + {"a": "1", "b": "2"}, + b"\x00\x00\x00\x02" + b"\x00\x00\x00\x01a\x00\x00\x00\x011" + b"\x00\x00\x00\x01b\x00\x00\x00\x012", + ), + ( + {"a": None, "b": "2"}, + b"\x00\x00\x00\x02" + b"\x00\x00\x00\x01a\xff\xff\xff\xff" + b"\x00\x00\x00\x01b\x00\x00\x00\x012", + ), + ( + {"\xe8": "\xe0"}, + b"\x00\x00\x00\x01\x00\x00\x00\x02\xc3\xa8\x00\x00\x00\x02\xc3\xa0", + ), + ( + {"a": None, "b": "1" * 300}, + b"\x00\x00\x00\x02" + b"\x00\x00\x00\x01a\xff\xff\xff\xff" + b"\x00\x00\x00\x01b\x00\x00\x01," + b"1" * 300, + ), + ], +) +def test_binary(d, b): + dumper = _make_hstore_binary_dumper(0)(dict) + assert dumper.dump(d) == b + loader = HstoreBinaryLoader(0) + assert loader.load(b) == d + + @pytest.mark.parametrize( "s", [ @@ -83,24 +120,21 @@ def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters): ab = list(map(chr, range(32, 128))) -samp = [ - {}, - {"a": "b", "c": None}, - dict(zip(ab, ab)), - {"".join(ab): "".join(ab)}, -] +samp = [{}, {"a": "b", "c": None}, dict(zip(ab, ab)), {"".join(ab): "".join(ab)}] @pytest.mark.parametrize("d", samp) -def test_roundtrip(hstore, conn, d): +@pytest.mark.parametrize("fmt_out", Format) +def test_roundtrip(hstore, conn, d, fmt_out): register_hstore(TypeInfo.fetch(conn, "hstore"), conn) - d1 = conn.execute("select %s", [d]).fetchone()[0] + d1 = conn.cursor(binary=fmt_out).execute("select %s", [d]).fetchone()[0] assert d == d1 -def test_roundtrip_array(hstore, conn): +@pytest.mark.parametrize("fmt_out", Format) +def test_roundtrip_array(hstore, conn, fmt_out): register_hstore(TypeInfo.fetch(conn, "hstore"), conn) - samp1 = conn.execute("select %s", (samp,)).fetchone()[0] + samp1 = conn.cursor(binary=fmt_out).execute("select %s", (samp,)).fetchone()[0] assert samp1 == samp