--- /dev/null
+"""
+Dict to hstore adaptation
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+from typing import Dict, List, Optional, Type
+
+from .. import pq
+from .. import errors as e
+from .. import postgres
+from ..abc import Buffer, AdaptContext
+from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..postgres import TEXT_OID
+from .._typeinfo import TypeInfo
+
+_re_escape = re.compile(r'(["\\])')
+_re_unescape = re.compile(r"\\(.)")
+
+_re_hstore = re.compile(
+ r"""
+ # hstore key:
+ # a string of normal or escaped chars
+ "((?: [^"\\] | \\. )*)"
+ \s*=>\s* # hstore value
+ (?:
+        NULL    # the value can be NULL, which is not captured
+ # or a quoted string like the key
+ | "((?: [^"\\] | \\. )*)"
+ )
+ (?:\s*,\s*|$) # pairs separated by comma or end of string.
+""",
+ re.VERBOSE,
+)
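+
+# For example, on input such as '"a"=>"1", "b"=>NULL' the pattern yields one
+# match per pair: group 1 holds the still-escaped key and group 2 the value,
+# or None when the value is the unquoted NULL keyword.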
+
+
+Hstore = Dict[str, Optional[str]]
+
+
+class BaseHstoreDumper(RecursiveDumper):
+
+ format = pq.Format.TEXT
+
+ def dump(self, obj: Hstore) -> bytes:
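+        # Render the dict in the hstore text format: '"key"=>"value"' pairs
+        # separated by commas, with NULL standing for None values.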
+ if not obj:
+ return b""
+
+ tokens: List[str] = []
+
+ def add_token(s: str) -> None:
+ tokens.append('"')
+ tokens.append(_re_escape.sub(r"\\\1", s))
+ tokens.append('"')
+
+ for k, v in obj.items():
+
+ if not isinstance(k, str):
+ raise e.DataError("hstore keys can only be strings")
+ add_token(k)
+
+ tokens.append("=>")
+
+ if v is None:
+ tokens.append("NULL")
+ elif not isinstance(v, str):
+                raise e.DataError("hstore values can only be strings")
+ else:
+ add_token(v)
+
+ tokens.append(",")
+
+ del tokens[-1]
+ data = "".join(tokens)
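+        # Hand the assembled string to the str dumper, so it is encoded
+        # consistently with the connection's client encoding.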
+ dumper = self._tx.get_dumper(data, PyFormat.TEXT)
+ return dumper.dump(data)
+
+
+class HstoreLoader(RecursiveLoader):
+
+ format = pq.Format.TEXT
+
+ def load(self, data: Buffer) -> Hstore:
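+        # Decode the buffer as text, then scan adjacent key => value pairs,
+        # unescaping keys and values as they are found.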
+ loader = self._tx.get_loader(TEXT_OID, self.format)
+ s: str = loader.load(data)
+
+ rv: Hstore = {}
+ start = 0
+ for m in _re_hstore.finditer(s):
+ if m is None or m.start() != start:
+ raise e.DataError(f"error parsing hstore pair at char {start}")
+ k = _re_unescape.sub(r"\1", m.group(1))
+ v = m.group(2)
+ if v is not None:
+ v = _re_unescape.sub(r"\1", v)
+
+ rv[k] = v
+ start = m.end()
+
+ if start < len(s):
+ raise e.DataError(
+ f"error parsing hstore: unparsed data after char {start}"
+ )
+
+ return rv
+
+
+def register_adapters(
+ info: TypeInfo, context: Optional[AdaptContext] = None
+) -> None:
+
+ info.register(context)
+
+ adapters = context.adapters if context else postgres.adapters
+
+ # Generate and register a customized text dumper
+ dumper: Type[BaseHstoreDumper] = type(
+ "HstoreDumper", (BaseHstoreDumper,), {"_oid": info.oid}
+ )
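+    # The generated class carries the oid fetched from the database, so
+    # dumped values are labelled with the hstore type.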
+ adapters.register_dumper(dict, dumper)
+
+ # register the text loader on the oid
+ adapters.register_loader(info.oid, HstoreLoader)
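+
+
+# Usage sketch (mirrors the tests; assumes the hstore extension is installed
+# and `conn` is an open psycopg connection):
+#
+#   info = TypeInfo.fetch(conn, "hstore")
+#   register_adapters(info, conn)
+#   conn.execute("select %s", [{"a": "1", "b": None}])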
--- /dev/null
+import pytest
+
+import psycopg
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import HstoreLoader, register_adapters
+
+
+@pytest.mark.parametrize(
+ "s, d",
+ [
+ ("", {}),
+ ('"a"=>"1", "b"=>"2"', {"a": "1", "b": "2"}),
+ ('"a" => "1" , "b" => "2"', {"a": "1", "b": "2"}),
+ ('"a"=>NULL, "b"=>"2"', {"a": None, "b": "2"}),
+ (r'"a"=>"\"", "\""=>"2"', {"a": '"', '"': "2"}),
+ ('"a"=>"\'", "\'"=>"2"', {"a": "'", "'": "2"}),
+ ('"a"=>"1", "b"=>NULL', {"a": "1", "b": None}),
+ (r'"a\\"=>"1"', {"a\\": "1"}),
+ (r'"a\""=>"1"', {'a"': "1"}),
+ (r'"a\\\""=>"1"', {r"a\"": "1"}),
+ (r'"a\\\\\""=>"1"', {r'a\\"': "1"}),
+ ('"\xe8"=>"\xe0"', {"\xe8": "\xe0"}),
+ ],
+)
+def test_parse_ok(s, d):
+    loader = HstoreLoader(0, None)
+ assert loader.load(s.encode("utf8")) == d
+
+
+@pytest.mark.parametrize(
+ "s",
+ [
+ "a",
+ '"a"',
+ r'"a\\""=>"1"',
+ r'"a\\\\""=>"1"',
+ '"a=>"1"',
+ '"a"=>"1", "b"=>NUL',
+ ],
+)
+def test_parse_bad(s):
+ with pytest.raises(psycopg.DataError):
+        loader = HstoreLoader(0, None)
+ loader.load(s.encode("utf8"))
+
+
+def test_register_conn(hstore, conn):
+ info = TypeInfo.fetch(conn, "hstore")
+ register_adapters(info, conn)
+ assert conn.adapters.types[info.oid].name == "hstore"
+
+ cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_curs(hstore, conn):
+ info = TypeInfo.fetch(conn, "hstore")
+ cur = conn.cursor()
+ register_adapters(info, cur)
+ assert conn.adapters.types.get(info.oid) is None
+ assert cur.adapters.types[info.oid].name == "hstore"
+
+ cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_globally(hstore, dsn, svcconn, global_adapters):
+ info = TypeInfo.fetch(svcconn, "hstore")
+ register_adapters(info)
+ assert psycopg.adapters.types[info.oid].name == "hstore"
+
+ assert svcconn.adapters.types.get(info.oid) is None
+ conn = psycopg.connect(dsn)
+ assert conn.adapters.types[info.oid].name == "hstore"
+
+ cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+ assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
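+# ASCII characters 32-127, used to exercise quoting and escaping in the
+# round-trip samples below.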
+ab = list(map(chr, range(32, 128)))
+samp = [
+ {},
+ {"a": "b", "c": None},
+ dict(zip(ab, ab)),
+ {"".join(ab): "".join(ab)},
+]
+
+
+@pytest.mark.parametrize("d", samp)
+def test_roundtrip(hstore, conn, d):
+ register_adapters(TypeInfo.fetch(conn, "hstore"), conn)
+ d1 = conn.execute("select %s", [d]).fetchone()[0]
+ assert d == d1
+
+
+def test_roundtrip_array(hstore, conn):
+ register_adapters(TypeInfo.fetch(conn, "hstore"), conn)
+ samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
+ assert samp1 == samp
+
+
+@pytest.fixture
+def hstore(svcconn):
+ try:
+ with svcconn.transaction():
+ svcconn.execute("create extension if not exists hstore")
+ except psycopg.Error as e:
+ pytest.skip(str(e))