]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added text cast of records
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Apr 2020 06:44:04 +0000 (18:44 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Apr 2020 06:44:04 +0000 (18:44 +1200)
psycopg3/types/__init__.py
psycopg3/types/composite.py [new file with mode: 0644]
tests/types/test_composite.py [new file with mode: 0644]

index 8379dfd225dcb209ffdffdb2268cce18a82e0634..b07d7f770023f89d492133af9298e7e208997b30 100644 (file)
@@ -8,6 +8,6 @@ psycopg3 types package
 from .oids import builtins
 
 # Register default adapters
-from . import array, numeric, text  # noqa
+from . import array, composite, numeric, text  # noqa
 
 __all__ = ["builtins"]
diff --git a/psycopg3/types/composite.py b/psycopg3/types/composite.py
new file mode 100644 (file)
index 0000000..7a25595
--- /dev/null
@@ -0,0 +1,52 @@
+"""
+Support for composite types adaptation.
+"""
+
+import re
+from typing import Any, Generator, Optional, Tuple
+
+from ..pq import Format
+from ..adapt import TypeCaster, Transformer, AdaptContext
+from .oids import builtins
+
+
+TEXT_OID = builtins["text"].oid
+
+
+_re_tokenize = re.compile(
+    br"""(?x)
+      \(? ([,)])                        # an empty token, representing NULL
+    | \(? " ((?: [^"] | "")*) " [,)]    # or a quoted string
+    | \(? ([^",)]+) [,)]                # or an unquoted string
+    """
+)
+
+_re_undouble = re.compile(br'(["\\])\1')
+
+
+@TypeCaster.text(builtins["record"].oid)
+class RecordCaster(TypeCaster):
+    def __init__(self, oid: int, context: AdaptContext = None):
+        super().__init__(oid, context)
+        self.tx = Transformer(context)
+
+    def cast(self, data: bytes) -> Tuple[Any, ...]:
+        cast = self.tx.get_cast_function(TEXT_OID, format=Format.TEXT)
+        return tuple(
+            cast(item) if item is not None else None
+            for item in self.parse_record(data)
+        )
+
+    def parse_record(
+        self, data: bytes
+    ) -> Generator[Optional[bytes], None, None]:
+        if data == b"()":
+            return
+
+        for m in _re_tokenize.finditer(data):
+            if m.group(1) is not None:
+                yield None
+            elif m.group(2) is not None:
+                yield _re_undouble.sub(br"\1", m.group(2))
+            else:
+                yield m.group(3)
diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py
new file mode 100644 (file)
index 0000000..e7d9bd2
--- /dev/null
@@ -0,0 +1,42 @@
+import pytest
+
+
+@pytest.mark.parametrize(
+    "rec, want",
+    [
+        ("", ()),
+        # Funnily enough there's no way to represent (None,) in Postgres
+        ("null", ()),
+        ("null,null", (None, None)),
+        ("null, ''", (None, "")),
+        (
+            "42,'foo','ba,r','ba''z','qu\"x'",
+            ("42", "foo", "ba,r", "ba'z", 'qu"x'),
+        ),
+        (
+            "'foo''', '''foo', '\"bar', 'bar\"' ",
+            ("foo'", "'foo", '"bar', 'bar"'),
+        ),
+    ],
+)
+def test_cast_record(conn, want, rec):
+    cur = conn.cursor()
+    res = cur.execute(f"select row({rec})").fetchone()[0]
+    assert res == want
+
+
+def test_cast_all_chars(conn):
+    cur = conn.cursor()
+    for i in range(1, 256):
+        res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
+        assert res == (chr(i),)
+
+    cur.execute(
+        "select row(%s)" % ",".join(f"chr({i}::int)" for i in range(1, 256))
+    )
+    res = cur.fetchone()[0]
+    assert res == tuple(map(chr, range(1, 256)))
+
+    s = "".join(map(chr, range(1, 256)))
+    res = cur.execute("select row(%s)", [s]).fetchone()[0]
+    assert res == (s,)