from __future__ import annotations
import re
+from struct import Struct
from functools import cache
from .. import errors as e
from .. import postgres
+from ..pq import Format
from ..abc import AdaptContext, Buffer
from .._oids import TEXT_OID
-from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..adapt import Loader, PyFormat, RecursiveDumper, RecursiveLoader
from .._compat import TypeAlias
from .._typeinfo import TypeInfo
+from .._encodings import conn_encoding
_re_escape = re.compile(r'(["\\])')
_re_unescape = re.compile(r"\\(.)")
re.VERBOSE,
)
+_U32_STRUCT = Struct("!I")
+"""Simple struct representing an unsigned 32-bit big-endian integer."""
+
+_I2B = [i.to_bytes(4, "big") for i in range(64)]
+"""Lookup list for small ints to bytes conversions."""
+
Hstore: TypeAlias = "dict[str, str | None]"
return dumper.dump(data)
+class BaseHstoreBinaryDumper(RecursiveDumper):
+    """Dump a ``dict[str, str | None]`` in the hstore binary representation.
+
+    The payload written by `dump()` is a 4-byte big-endian pair count
+    followed, for each pair, by a 4-byte key length, the encoded key, a
+    4-byte value length and the encoded value. A `None` value is written as
+    the 4-byte marker ``ff ff ff ff`` with no value bytes after it.
+    """
+
+    format = Format.BINARY
+
+    def __init__(self, cls: type, context: AdaptContext | None = None):
+        super().__init__(cls, context)
+        # Keys and values are encoded with the connection encoding, except
+        # that "ascii" is replaced by "utf-8".
+        conn_enc = conn_encoding(self.connection)
+        self.encoding = "utf-8" if conn_enc == "ascii" else conn_enc
+
+    def dump(self, obj: Hstore) -> Buffer:
+        """Return the binary representation of *obj*.
+
+        An empty (falsy) mapping is dumped as a zero pair count alone.
+        """
+        if not obj:
+            return b"\x00\x00\x00\x00"
+
+        null_marker = b"\xff\xff\xff\xff"
+        small = _I2B  # pre-built 4-byte encodings of the ints 0..63
+        codec = self.encoding
+
+        # Sizes below 64 come from the lookup list, larger ones are
+        # converted on the fly.
+        chunks: list[bytes] = [
+            small[n] if (n := len(obj)) < 64 else n.to_bytes(4, "big")
+        ]
+        push = chunks.append
+
+        for key, value in obj.items():
+            raw = key.encode(codec)
+            push(small[n] if (n := len(raw)) < 64 else n.to_bytes(4, "big"))
+            push(raw)
+
+            if value is None:
+                push(null_marker)
+                continue
+
+            raw = value.encode(codec)
+            push(small[n] if (n := len(raw)) < 64 else n.to_bytes(4, "big"))
+            push(raw)
+
+        return b"".join(chunks)
+
+
class HstoreLoader(RecursiveLoader):
def load(self, data: Buffer) -> Hstore:
loader = self._tx.get_loader(TEXT_OID, self.format)
return rv
+class HstoreBinaryLoader(Loader):
+    """Load an hstore value received in binary format into a `dict`.
+
+    The payload is read as a 4-byte big-endian pair count followed, for each
+    pair, by a 4-byte key length, the key bytes, a 4-byte value length and
+    the value bytes. A value length of ``0xFFFFFFFF`` stands for `None` and
+    is not followed by any value bytes.
+    """
+
+    format = Format.BINARY
+
+    def __init__(self, oid: int, context: AdaptContext | None = None):
+        super().__init__(oid, context)
+        # Keys and values are decoded with the connection encoding, except
+        # that "ascii" is replaced by "utf-8".
+        conn_enc = conn_encoding(self.connection)
+        self.encoding = "utf-8" if conn_enc == "ascii" else conn_enc
+
+    def load(self, data: Buffer) -> Hstore:
+        """Return the mapping encoded in *data*.
+
+        The sizes found in the payload are used as they are: no check is
+        made that they fit within *data*.
+        """
+        # Count + key length + value length take 12 bytes: anything shorter
+        # cannot hold a pair.
+        if len(data) < 12:
+            return {}
+
+        null_marker = 0xFFFFFFFF
+        read_u32 = _U32_STRUCT.unpack_from
+        codec = self.encoding
+        pairs = {}
+
+        # Work on a bytes object so that slices can be decoded directly.
+        raw = bytes(data)
+        (count,) = read_u32(raw)
+        offset = 4
+
+        for _ in range(count):
+            (key_len,) = read_u32(raw, offset)
+            start = offset + 4
+            offset = start + key_len
+            key = raw[start:offset].decode(codec)
+
+            (value_len,) = read_u32(raw, offset)
+            offset += 4
+
+            if value_len == null_marker:
+                pairs[key] = None
+            else:
+                end = offset + value_len
+                pairs[key] = raw[offset:end].decode(codec)
+                offset = end
+
+        return pairs
+
+
def register_hstore(info: TypeInfo, context: AdaptContext | None = None) -> None:
"""Register the adapters to load and dump hstore.
adapters = context.adapters if context else postgres.adapters
- # Generate and register a customized text dumper
+ # Generate and register customized dumpers
adapters.register_dumper(dict, _make_hstore_dumper(info.oid))
+ adapters.register_dumper(dict, _make_hstore_binary_dumper(info.oid))
- # register the text loader on the oid
+ # Register the loaders on the oid
adapters.register_loader(info.oid, HstoreLoader)
+ adapters.register_loader(info.oid, HstoreBinaryLoader)
# Cache all dynamically-generated types to avoid leaks in case the types
oid = oid_in
return HstoreDumper
+
+
+@cache
+def _make_hstore_binary_dumper(oid_in: int) -> type[BaseHstoreBinaryDumper]:
+ """Return a `BaseHstoreBinaryDumper` subclass whose `oid` is *oid_in*.
+
+ The function is wrapped in `functools.cache`, so calling it again with
+ the same oid returns the same class object instead of creating a new one.
+ """
+
+ class HstoreBinaryDumper(BaseHstoreBinaryDumper):
+ oid = oid_in
+
+ return HstoreBinaryDumper
import pytest
import psycopg
+from psycopg.pq import Format
from psycopg.types import TypeInfo
-from psycopg.types.hstore import HstoreLoader, register_hstore
+from psycopg.types.hstore import HstoreBinaryLoader, HstoreLoader
+from psycopg.types.hstore import _make_hstore_binary_dumper, register_hstore
pytestmark = pytest.mark.crdb_skip("hstore")
assert loader.load(s.encode()) == d
+@pytest.mark.parametrize(
+ "d, b",
+ [
+ # Empty dict: only the 4-byte pair count, set to zero.
+ ({}, b"\x00\x00\x00\x00"),
+ # Two pairs: the count, then a length-prefixed key and value per pair.
+ (
+ {"a": "1", "b": "2"},
+ b"\x00\x00\x00\x02"
+ b"\x00\x00\x00\x01a\x00\x00\x00\x011"
+ b"\x00\x00\x00\x01b\x00\x00\x00\x012",
+ ),
+ # None value: the value length slot holds ff ff ff ff and no value
+ # bytes follow it.
+ (
+ {"a": None, "b": "2"},
+ b"\x00\x00\x00\x02"
+ b"\x00\x00\x00\x01a\xff\xff\xff\xff"
+ b"\x00\x00\x00\x01b\x00\x00\x00\x012",
+ ),
+ # Non-ASCII text: lengths count the encoded bytes (2 for each of these
+ # characters in UTF-8), not the characters.
+ (
+ {"\xe8": "\xe0"},
+ b"\x00\x00\x00\x01\x00\x00\x00\x02\xc3\xa8\x00\x00\x00\x02\xc3\xa0",
+ ),
+ # 300-byte value: its length is b"\x00\x00\x01," because 300 == 0x012C
+ # and 0x2C is ",". Being >= 64, it is not served by the dumper's small
+ # ints lookup list.
+ (
+ {"a": None, "b": "1" * 300},
+ b"\x00\x00\x00\x02"
+ b"\x00\x00\x00\x01a\xff\xff\xff\xff"
+ b"\x00\x00\x00\x01b\x00\x00\x01," + b"1" * 300,
+ ),
+ ],
+)
+def test_binary(d, b):
+ """Check that dict *d* dumps to bytes *b* and that *b* loads back to *d*.
+
+ The dumper and the loader are created with oid 0 and without a context.
+ """
+ dumper = _make_hstore_binary_dumper(0)(dict)
+ assert dumper.dump(d) == b
+ loader = HstoreBinaryLoader(0)
+ assert loader.load(b) == d
+
+
@pytest.mark.parametrize(
"s",
[
ab = list(map(chr, range(32, 128)))
-samp = [
- {},
- {"a": "b", "c": None},
- dict(zip(ab, ab)),
- {"".join(ab): "".join(ab)},
-]
+samp = [{}, {"a": "b", "c": None}, dict(zip(ab, ab)), {"".join(ab): "".join(ab)}]
@pytest.mark.parametrize("d", samp)
-def test_roundtrip(hstore, conn, d):
+@pytest.mark.parametrize("fmt_out", Format)
+def test_roundtrip(hstore, conn, d, fmt_out):
+ # Runs for each sample dict combined with each Format member; fmt_out is
+ # passed as the cursor's `binary` argument and the value selected back
+ # must equal the dict that was sent.
register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
- d1 = conn.execute("select %s", [d]).fetchone()[0]
+ d1 = conn.cursor(binary=fmt_out).execute("select %s", [d]).fetchone()[0]
assert d == d1
-def test_roundtrip_array(hstore, conn):
+@pytest.mark.parametrize("fmt_out", Format)
+def test_roundtrip_array(hstore, conn, fmt_out):
+ # Runs once per Format member; the whole `samp` list is sent as a single
+ # query parameter and the value selected back must equal it.
register_hstore(TypeInfo.fetch(conn, "hstore"), conn)
- samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
+ samp1 = conn.cursor(binary=fmt_out).execute("select %s", (samp,)).fetchone()[0]
assert samp1 == samp