]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
feat: implement binary hstore protocol 1031/head
authorKamil Monicz <kamil@monicz.dev>
Fri, 21 Mar 2025 16:47:02 +0000 (16:47 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 9 Apr 2025 14:14:42 +0000 (15:14 +0100)
docs/news.rst
psycopg/psycopg/types/hstore.py
tests/types/test_hstore.py

index 9f0321848ddd6bb93029263d55ce9b7b55ee13eb..06f8864ade14ec2afc05c6e6bc3bd569f5470728 100644 (file)
@@ -20,6 +20,7 @@ Psycopg 3.2.7 (unreleased)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 - Add SRID support to shapely dumpers/loaders (:ticket:`#1028`).
+- Add support for binary hstore (:ticket:`#1030`).
 
 
 Current release
index ea4a6e61350fca064d5270e4cdd608408804f80c..6ed323fbdb3ce11185e1fae2911d05c631eb6677 100644 (file)
@@ -7,15 +7,18 @@ dict to hstore adaptation
 from __future__ import annotations
 
 import re
+from struct import Struct
 from functools import cache
 
 from .. import errors as e
 from .. import postgres
+from ..pq import Format
 from ..abc import AdaptContext, Buffer
 from .._oids import TEXT_OID
-from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..adapt import Loader, PyFormat, RecursiveDumper, RecursiveLoader
 from .._compat import TypeAlias
 from .._typeinfo import TypeInfo
+from .._encodings import conn_encoding
 
 _re_escape = re.compile(r'(["\\])')
 _re_unescape = re.compile(r"\\(.)")
@@ -36,6 +39,12 @@ _re_hstore = re.compile(
     re.VERBOSE,
 )
 
+_U32_STRUCT = Struct("!I")
+"""Simple struct representing an unsigned 32-bit big-endian integer."""
+
+_I2B = [i.to_bytes(4, "big") for i in range(64)]
+"""Lookup list for small ints to bytes conversions."""
+
 
 Hstore: TypeAlias = "dict[str, str | None]"
 
@@ -74,6 +83,43 @@ class BaseHstoreDumper(RecursiveDumper):
         return dumper.dump(data)
 
 
+class BaseHstoreBinaryDumper(RecursiveDumper):
+    format = Format.BINARY
+
+    def __init__(self, cls: type, context: AdaptContext | None = None):
+        super().__init__(cls, context)
+        enc = conn_encoding(self.connection)
+        self.encoding = enc if enc != "ascii" else "utf-8"
+
+    def dump(self, obj: Hstore) -> Buffer:
+        if not obj:
+            hstore_empty = b"\x00\x00\x00\x00"
+            return hstore_empty
+
+        hstore_null_marker = b"\xff\xff\xff\xff"
+        i2b = _I2B
+        encoding = self.encoding
+        buffer: list[bytes] = [i2b[i] if (i := len(obj)) < 64 else i.to_bytes(4, "big")]
+
+        for key, value in obj.items():
+            key_bytes = key.encode(encoding)
+            buffer.append(
+                i2b[i] if (i := len(key_bytes)) < 64 else i.to_bytes(4, "big")
+            )
+            buffer.append(key_bytes)
+
+            if value is None:
+                buffer.append(hstore_null_marker)
+            else:
+                value_bytes = value.encode(encoding)
+                buffer.append(
+                    i2b[i] if (i := len(value_bytes)) < 64 else i.to_bytes(4, "big")
+                )
+                buffer.append(value_bytes)
+
+        return b"".join(buffer)
+
+
 class HstoreLoader(RecursiveLoader):
     def load(self, data: Buffer) -> Hstore:
         loader = self._tx.get_loader(TEXT_OID, self.format)
@@ -97,6 +143,48 @@ class HstoreLoader(RecursiveLoader):
         return rv
 
 
+class HstoreBinaryLoader(Loader):
+    format = Format.BINARY
+
+    def __init__(self, oid: int, context: AdaptContext | None = None):
+        super().__init__(oid, context)
+        enc = conn_encoding(self.connection)
+        self.encoding = enc if enc != "ascii" else "utf-8"
+
+    def load(self, data: Buffer) -> Hstore:
+        if len(data) < 12:  # Fast-path if too small to contain any data.
+            return {}
+
+        hstore_null_marker = 0xFFFFFFFF
+        unpack_from = _U32_STRUCT.unpack_from
+        encoding = self.encoding
+        result = {}
+
+        view = bytes(data)
+        (size,) = unpack_from(view)
+        pos = 4
+
+        for _ in range(size):
+            (key_size,) = unpack_from(view, pos)
+            pos += 4
+
+            key = view[pos : pos + key_size].decode(encoding)
+            pos += key_size
+
+            (value_size,) = unpack_from(view, pos)
+            pos += 4
+
+            if value_size == hstore_null_marker:
+                value = None
+            else:
+                value = view[pos : pos + value_size].decode(encoding)
+                pos += value_size
+
+            result[key] = value
+
+        return result
+
+
 def register_hstore(info: TypeInfo, context: AdaptContext | None = None) -> None:
     """Register the adapters to load and dump hstore.
 
@@ -121,11 +209,13 @@ def register_hstore(info: TypeInfo, context: AdaptContext | None = None) -> None
 
     adapters = context.adapters if context else postgres.adapters
 
-    # Generate and register a customized text dumper
+    # Generate and register customized dumpers
     adapters.register_dumper(dict, _make_hstore_dumper(info.oid))
+    adapters.register_dumper(dict, _make_hstore_binary_dumper(info.oid))
 
-    # register the text loader on the oid
+    # Register the loaders on the oid
     adapters.register_loader(info.oid, HstoreLoader)
+    adapters.register_loader(info.oid, HstoreBinaryLoader)
 
 
 # Cache all dynamically-generated types to avoid leaks in case the types
@@ -144,3 +234,11 @@ def _make_hstore_dumper(oid_in: int) -> type[BaseHstoreDumper]:
         oid = oid_in
 
     return HstoreDumper
+
+
+@cache
+def _make_hstore_binary_dumper(oid_in: int) -> type[BaseHstoreBinaryDumper]:
+    class HstoreBinaryDumper(BaseHstoreBinaryDumper):
+        oid = oid_in
+
+    return HstoreBinaryDumper
index 1648e26375b522989a4948e73906e0e86ef1bbd7..954687e4f9ed9246a7c044ed4edbe41c9b2869c3 100644 (file)
@@ -1,8 +1,10 @@
 import pytest
 
 import psycopg
+from psycopg.pq import Format
 from psycopg.types import TypeInfo
-from psycopg.types.hstore import HstoreLoader, register_hstore
+from psycopg.types.hstore import HstoreBinaryLoader, HstoreLoader
+from psycopg.types.hstore import _make_hstore_binary_dumper, register_hstore
 
 pytestmark = pytest.mark.crdb_skip("hstore")
 
@@ -29,6 +31,41 @@ def test_parse_ok(s, d):
     assert loader.load(s.encode()) == d
 
 
+@pytest.mark.parametrize(
+    "d, b",
+    [
+        ({}, b"\x00\x00\x00\x00"),
+        (
+            {"a": "1", "b": "2"},
+            b"\x00\x00\x00\x02"
+            b"\x00\x00\x00\x01a\x00\x00\x00\x011"
+            b"\x00\x00\x00\x01b\x00\x00\x00\x012",
+        ),
+        (
+            {"a": None, "b": "2"},
+            b"\x00\x00\x00\x02"
+            b"\x00\x00\x00\x01a\xff\xff\xff\xff"
+            b"\x00\x00\x00\x01b\x00\x00\x00\x012",
+        ),
+        (
+            {"\xe8": "\xe0"},
+            b"\x00\x00\x00\x01\x00\x00\x00\x02\xc3\xa8\x00\x00\x00\x02\xc3\xa0",
+        ),
+        (
+            {"a": None, "b": "1" * 300},
+            b"\x00\x00\x00\x02"
+            b"\x00\x00\x00\x01a\xff\xff\xff\xff"
+            b"\x00\x00\x00\x01b\x00\x00\x01," + b"1" * 300,
+        ),
+    ],
+)
+def test_binary(d, b):
+    dumper = _make_hstore_binary_dumper(0)(dict)
+    assert dumper.dump(d) == b
+    loader = HstoreBinaryLoader(0)
+    assert loader.load(b) == d
+
+
 @pytest.mark.parametrize(
     "s",
     [
@@ -83,24 +120,21 @@ def test_register_globally(conn_cls, hstore, dsn, svcconn, global_adapters):
 
 
 ab = list(map(chr, range(32, 128)))
-samp = [
-    {},
-    {"a": "b", "c": None},
-    dict(zip(ab, ab)),
-    {"".join(ab): "".join(ab)},
-]
+samp = [{}, {"a": "b", "c": None}, dict(zip(ab, ab)), {"".join(ab): "".join(ab)}]
 
 
 @pytest.mark.parametrize("d", samp)
-def test_roundtrip(hstore, conn, d):
+@pytest.mark.parametrize("fmt_out", Format)
+def test_roundtrip(hstore, conn, d, fmt_out):
     register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
-    d1 = conn.execute("select %s", [d]).fetchone()[0]
+    d1 = conn.cursor(binary=fmt_out).execute("select %s", [d]).fetchone()[0]
     assert d == d1
 
 
-def test_roundtrip_array(hstore, conn):
+@pytest.mark.parametrize("fmt_out", Format)
+def test_roundtrip_array(hstore, conn, fmt_out):
     register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
-    samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
+    samp1 = conn.cursor(binary=fmt_out).execute("select %s", (samp,)).fetchone()[0]
     assert samp1 == samp