"""
import re
+import struct
from typing import Any, Generator, Optional, Tuple
from ..pq import Format
_re_undouble = re.compile(br'(["\\])\1')
-@TypeCaster.text(builtins["record"].oid)
-class RecordCaster(TypeCaster):
+class BaseCompositeCaster(TypeCaster):
def __init__(self, oid: int, context: AdaptContext = None):
super().__init__(oid, context)
self.tx = Transformer(context)
+
+@TypeCaster.text(builtins["record"].oid)
+class RecordCaster(BaseCompositeCaster):
def cast(self, data: bytes) -> Tuple[Any, ...]:
cast = self.tx.get_cast_function(TEXT_OID, format=Format.TEXT)
return tuple(
cast(item) if item is not None else None
- for item in self.parse_record(data)
+ for item in self._parse_record(data)
)
- def parse_record(
+ def _parse_record(
self, data: bytes
) -> Generator[Optional[bytes], None, None]:
if data == b"()":
yield _re_undouble.sub(br"\1", m.group(2))
else:
yield m.group(3)
+
+
+_struct_len = struct.Struct("!i")
+_struct_oidlen = struct.Struct("!Ii")
+
+
+@TypeCaster.binary(builtins["record"].oid)
+class BinaryRecordCaster(BaseCompositeCaster):
+ _types_set = False
+
+ def cast(self, data: bytes) -> Tuple[Any, ...]:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ return tuple(
+ self.tx.cast_sequence(
+ data[offset : offset + length] if length != -1 else None
+ for _, offset, length in self._walk_record(data)
+ )
+ )
+
+ def _walk_record(
+ self, data: bytes
+ ) -> Generator[Tuple[int, int, int], None, None]:
+ """
+ Yield a sequence of (oid, offset, length) for the content of the record
+ """
+ nfields = _struct_len.unpack_from(data, 0)[0]
+ i = 4
+ for _ in range(nfields):
+ oid, length = _struct_oidlen.unpack_from(data, i)
+ yield oid, i + 8, length
+ i += (8 + length) if length > 0 else 8
+
+ def _config_types(self, data: bytes) -> None:
+ self.tx.set_row_types(
+ (oid, Format.BINARY) for oid, _, _ in self._walk_record(data)
+ )
import pytest
+from psycopg3.adapt import Format
+
@pytest.mark.parametrize(
"rec, want",
assert res == want
-def test_cast_all_chars(conn):
- cur = conn.cursor()
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_all_chars(conn, fmt_out):
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
for i in range(1, 256):
res = cur.execute("select row(chr(%s::int))", (i,)).fetchone()[0]
assert res == (chr(i),)
s = "".join(map(chr, range(1, 256)))
res = cur.execute("select row(%s)", [s]).fetchone()[0]
assert res == (s,)
+
+
+@pytest.mark.parametrize(
+ "rec, want",
+ [
+ ("", ()),
+ ("null", (None,)), # Unlike text format, this is a thing
+ ("null,null", (None, None)),
+ ("null, ''", (None, b"")),
+ (
+ "42,'foo','ba,r','ba''z','qu\"x'",
+ (42, b"foo", b"ba,r", b"ba'z", b'qu"x'),
+ ),
+ (
+ "'foo''', '''foo', '\"bar', 'bar\"' ",
+ (b"foo'", b"'foo", b'"bar', b'bar"'),
+ ),
+ (
+ "10::int, null::text, 20::float,"
+ " null::text, 'foo'::text, 'bar'::bytea ",
+ (10, None, 20.0, None, "foo", b"bar"),
+ ),
+ ],
+)
+def test_cast_record_binary(conn, want, rec):
+ cur = conn.cursor(binary=True)
+ res = cur.execute(f"select row({rec})").fetchone()[0]
+ assert res == want
+ for o1, o2 in zip(res, want):
+ assert type(o1) is type(o2)