]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add hstore adapters
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 16 Jul 2021 01:40:48 +0000 (03:40 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 23 Jul 2021 15:56:04 +0000 (17:56 +0200)
psycopg/psycopg/types/hstore.py [new file with mode: 0644]
tests/types/test_hstore.py [new file with mode: 0644]

diff --git a/psycopg/psycopg/types/hstore.py b/psycopg/psycopg/types/hstore.py
new file mode 100644 (file)
index 0000000..3367bc0
--- /dev/null
@@ -0,0 +1,123 @@
+"""
+Dict to hstore adaptation
+"""
+
+# Copyright (C) 2021 The Psycopg Team
+
+import re
+from typing import Dict, List, Optional, Type
+
+from .. import pq
+from .. import errors as e
+from .. import postgres
+from ..abc import Buffer, AdaptContext
+from ..adapt import PyFormat, RecursiveDumper, RecursiveLoader
+from ..postgres import TEXT_OID
+from .._typeinfo import TypeInfo
+
+_re_escape = re.compile(r'(["\\])')
+_re_unescape = re.compile(r"\\(.)")
+
+_re_hstore = re.compile(
+    r"""
+    # hstore key:
+    # a string of normal or escaped chars
+    "((?: [^"\\] | \\. )*)"
+    \s*=>\s* # hstore value
+    (?:
+        NULL # the value can be null - not caught
+        # or a quoted string like the key
+        | "((?: [^"\\] | \\. )*)"
+    )
+    (?:\s*,\s*|$) # pairs separated by comma or end of string.
+""",
+    re.VERBOSE,
+)
+
+
+Hstore = Dict[str, Optional[str]]
+
+
+class BaseHstoreDumper(RecursiveDumper):
+
+    format = pq.Format.TEXT
+
+    def dump(self, obj: Hstore) -> bytes:
+        if not obj:
+            return b""
+
+        tokens: List[str] = []
+
+        def add_token(s: str) -> None:
+            tokens.append('"')
+            tokens.append(_re_escape.sub(r"\\\1", s))
+            tokens.append('"')
+
+        for k, v in obj.items():
+
+            if not isinstance(k, str):
+                raise e.DataError("hstore keys can only be strings")
+            add_token(k)
+
+            tokens.append("=>")
+
+            if v is None:
+                tokens.append("NULL")
+            elif not isinstance(v, str):
+                raise e.DataError("hstore keys can only be strings")
+            else:
+                add_token(v)
+
+            tokens.append(",")
+
+        del tokens[-1]
+        data = "".join(tokens)
+        dumper = self._tx.get_dumper(data, PyFormat.TEXT)
+        return dumper.dump(data)
+
+
+class HstoreLoader(RecursiveLoader):
+
+    format = pq.Format.TEXT
+
+    def load(self, data: Buffer) -> Hstore:
+        loader = self._tx.get_loader(TEXT_OID, self.format)
+        s: str = loader.load(data)
+
+        rv: Hstore = {}
+        start = 0
+        for m in _re_hstore.finditer(s):
+            if m is None or m.start() != start:
+                raise e.DataError(f"error parsing hstore pair at char {start}")
+            k = _re_unescape.sub(r"\1", m.group(1))
+            v = m.group(2)
+            if v is not None:
+                v = _re_unescape.sub(r"\1", v)
+
+            rv[k] = v
+            start = m.end()
+
+        if start < len(s):
+            raise e.DataError(
+                f"error parsing hstore: unparsed data after char {start}"
+            )
+
+        return rv
+
+
+def register_adapters(
+    info: TypeInfo, context: Optional[AdaptContext] = None
+) -> None:
+
+    info.register(context)
+
+    adapters = context.adapters if context else postgres.adapters
+
+    # Generate and register a customized text dumper
+    dumper: Type[BaseHstoreDumper] = type(
+        "HstoreDumper", (BaseHstoreDumper,), {"_oid": info.oid}
+    )
+    adapters.register_dumper(dict, dumper)
+
+    # register the text loader on the oid
+    adapters.register_loader(info.oid, HstoreLoader)
diff --git a/tests/types/test_hstore.py b/tests/types/test_hstore.py
new file mode 100644 (file)
index 0000000..e52aae7
--- /dev/null
@@ -0,0 +1,108 @@
+import pytest
+
+import psycopg
+from psycopg.types import TypeInfo
+from psycopg.types.hstore import HstoreLoader, register_adapters
+
+
+@pytest.mark.parametrize(
+    "s, d",
+    [
+        ("", {}),
+        ('"a"=>"1", "b"=>"2"', {"a": "1", "b": "2"}),
+        ('"a"  => "1" , "b"  =>  "2"', {"a": "1", "b": "2"}),
+        ('"a"=>NULL, "b"=>"2"', {"a": None, "b": "2"}),
+        (r'"a"=>"\"", "\""=>"2"', {"a": '"', '"': "2"}),
+        ('"a"=>"\'", "\'"=>"2"', {"a": "'", "'": "2"}),
+        ('"a"=>"1", "b"=>NULL', {"a": "1", "b": None}),
+        (r'"a\\"=>"1"', {"a\\": "1"}),
+        (r'"a\""=>"1"', {'a"': "1"}),
+        (r'"a\\\""=>"1"', {r"a\"": "1"}),
+        (r'"a\\\\\""=>"1"', {r'a\\"': "1"}),
+        ('"\xe8"=>"\xe0"', {"\xe8": "\xe0"}),
+    ],
+)
+def test_parse_ok(s, d):
+    loader = HstoreLoader(dict, None)
+    assert loader.load(s.encode("utf8")) == d
+
+
+@pytest.mark.parametrize(
+    "s",
+    [
+        "a",
+        '"a"',
+        r'"a\\""=>"1"',
+        r'"a\\\\""=>"1"',
+        '"a=>"1"',
+        '"a"=>"1", "b"=>NUL',
+    ],
+)
+def test_parse_bad(s):
+    with pytest.raises(psycopg.DataError):
+        loader = HstoreLoader(dict, None)
+        loader.load(s.encode("utf8"))
+
+
+def test_register_conn(hstore, conn):
+    info = TypeInfo.fetch(conn, "hstore")
+    register_adapters(info, conn)
+    assert conn.adapters.types[info.oid].name == "hstore"
+
+    cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+    assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_curs(hstore, conn):
+    info = TypeInfo.fetch(conn, "hstore")
+    cur = conn.cursor()
+    register_adapters(info, cur)
+    assert conn.adapters.types.get(info.oid) is None
+    assert cur.adapters.types[info.oid].name == "hstore"
+
+    cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+    assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+def test_register_globally(hstore, dsn, svcconn, global_adapters):
+    info = TypeInfo.fetch(svcconn, "hstore")
+    register_adapters(info)
+    assert psycopg.adapters.types[info.oid].name == "hstore"
+
+    assert svcconn.adapters.types.get(info.oid) is None
+    conn = psycopg.connect(dsn)
+    assert conn.adapters.types[info.oid].name == "hstore"
+
+    cur = conn.execute("select null::hstore, ''::hstore, 'a => b'::hstore")
+    assert cur.fetchone() == (None, {}, {"a": "b"})
+
+
+ab = list(map(chr, range(32, 128)))
+samp = [
+    {},
+    {"a": "b", "c": None},
+    dict(zip(ab, ab)),
+    {"".join(ab): "".join(ab)},
+]
+
+
+@pytest.mark.parametrize("d", samp)
+def test_roundtrip(hstore, conn, d):
+    register_adapters(TypeInfo.fetch(conn, "hstore"), conn)
+    d1 = conn.execute("select %s", [d]).fetchone()[0]
+    assert d == d1
+
+
+def test_roundtrip_array(hstore, conn):
+    register_adapters(TypeInfo.fetch(conn, "hstore"), conn)
+    samp1 = conn.execute("select %s", (samp,)).fetchone()[0]
+    assert samp1 == samp
+
+
+@pytest.fixture
+def hstore(svcconn):
+    try:
+        with svcconn.transaction():
+            svcconn.execute("create extension if not exists hstore")
+    except psycopg.Error as e:
+        pytest.skip(str(e))