]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Added binary cast of record type
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Apr 2020 09:23:42 +0000 (21:23 +1200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Tue, 7 Apr 2020 09:23:42 +0000 (21:23 +1200)
psycopg3/types/composite.py
tests/types/test_composite.py

index 7a25595e9d161c3ccf484ce8a172b0bbfc5db385..f0779a8f15b2614895a1d1c49e4d218b59ac3384 100644 (file)
@@ -3,6 +3,7 @@ Support for composite types adaptation.
 """
 
 import re
+import struct
 from typing import Any, Generator, Optional, Tuple
 
 from ..pq import Format
@@ -24,20 +25,22 @@ _re_tokenize = re.compile(
 _re_undouble = re.compile(br'(["\\])\1')
 
 
-@TypeCaster.text(builtins["record"].oid)
-class RecordCaster(TypeCaster):
+class BaseCompositeCaster(TypeCaster):
     def __init__(self, oid: int, context: AdaptContext = None):
         super().__init__(oid, context)
         self.tx = Transformer(context)
 
+
+@TypeCaster.text(builtins["record"].oid)
+class RecordCaster(BaseCompositeCaster):
     def cast(self, data: bytes) -> Tuple[Any, ...]:
         cast = self.tx.get_cast_function(TEXT_OID, format=Format.TEXT)
         return tuple(
             cast(item) if item is not None else None
-            for item in self.parse_record(data)
+            for item in self._parse_record(data)
         )
 
-    def parse_record(
+    def _parse_record(
         self, data: bytes
     ) -> Generator[Optional[bytes], None, None]:
         if data == b"()":
@@ -50,3 +53,42 @@ class RecordCaster(TypeCaster):
                 yield _re_undouble.sub(br"\1", m.group(2))
             else:
                 yield m.group(3)
+
+
+_struct_len = struct.Struct("!i")
+_struct_oidlen = struct.Struct("!Ii")
+
+
+@TypeCaster.binary(builtins["record"].oid)
+class BinaryRecordCaster(BaseCompositeCaster):
+    _types_set = False
+
+    def cast(self, data: bytes) -> Tuple[Any, ...]:
+        if not self._types_set:
+            self._config_types(data)
+            self._types_set = True
+
+        return tuple(
+            self.tx.cast_sequence(
+                data[offset : offset + length] if length != -1 else None
+                for _, offset, length in self._walk_record(data)
+            )
+        )
+
+    def _walk_record(
+        self, data: bytes
+    ) -> Generator[Tuple[int, int, int], None, None]:
+        """
+        Yield a sequence of (oid, offset, length) for the content of the record
+        """
+        nfields = _struct_len.unpack_from(data, 0)[0]
+        i = 4
+        for _ in range(nfields):
+            oid, length = _struct_oidlen.unpack_from(data, i)
+            yield oid, i + 8, length
+            i += (8 + length) if length > 0 else 8
+
+    def _config_types(self, data: bytes) -> None:
+        self.tx.set_row_types(
+            (oid, Format.BINARY) for oid, _, _ in self._walk_record(data)
+        )
index e7d9bd29b5b1053513fde0269e9479ad899b255e..9406e175c2765809aa06a6403e1ce5279e263d2c 100644 (file)
@@ -1,5 +1,7 @@
 import pytest
 
+from psycopg3.adapt import Format
+
 
 @pytest.mark.parametrize(
     "rec, want",
@@ -25,8 +27,9 @@ def test_cast_record(conn, want, rec):
     assert res == want
 
 
-def test_cast_all_chars(conn):
-    cur = conn.cursor()
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_all_chars(conn, fmt_out):
+    cur = conn.cursor(binary=fmt_out == Format.BINARY)
     for i in range(1, 256):
         res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
         assert res == (chr(i),)
@@ -40,3 +43,33 @@ def test_cast_all_chars(conn):
     s = "".join(map(chr, range(1, 256)))
     res = cur.execute("select row(%s)", [s]).fetchone()[0]
     assert res == (s,)
+
+
+@pytest.mark.parametrize(
+    "rec, want",
+    [
+        ("", ()),
+        ("null", (None,)),  # Unlike text format, this is a thing
+        ("null,null", (None, None)),
+        ("null, ''", (None, b"")),
+        (
+            "42,'foo','ba,r','ba''z','qu\"x'",
+            (42, b"foo", b"ba,r", b"ba'z", b'qu"x'),
+        ),
+        (
+            "'foo''', '''foo', '\"bar', 'bar\"' ",
+            (b"foo'", b"'foo", b'"bar', b'bar"'),
+        ),
+        (
+            "10::int, null::text, 20::float,"
+            " null::text, 'foo'::text, 'bar'::bytea ",
+            (10, None, 20.0, None, "foo", b"bar"),
+        ),
+    ],
+)
+def test_cast_record_binary(conn, want, rec):
+    cur = conn.cursor(binary=True)
+    res = cur.execute(f"select row({rec})").fetchone()[0]
+    assert res == want
+    for o1, o2 in zip(res, want):
+        assert type(o1) is type(o2)