import re
import struct
-from typing import Any, Generator, Optional, Tuple
+from collections import namedtuple
+from typing import Any, Callable, Generator, List, Sequence, Tuple, Union
+from typing import Optional, TYPE_CHECKING
from ..adapt import Format, TypeCaster, Transformer, AdaptContext
-from .oids import builtins
+from .oids import builtins, TypeInfo
+from .array import register_array
+
+if TYPE_CHECKING:
+ from ..connection import Connection
TEXT_OID = builtins["text"].oid
-_re_tokenize = re.compile(
- br"""(?x)
- \(? ([,)]) # an empty token, representing NULL
- | \(? " ((?: [^"] | "")*) " [,)] # or a quoted string
- | \(? ([^",)]+) [,)] # or an unquoted string
- """
-)
+class FieldInfo:
+ def __init__(self, name: str, type_oid: int):
+ self.name = name
+ self.type_oid = type_oid
+
+
+class CompositeTypeInfo(TypeInfo):
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ fields: Sequence[Union[FieldInfo, Tuple[str, int]]],
+ ):
+ super().__init__(name, oid, array_oid)
+ self.fields: List[FieldInfo] = []
+ for f in fields:
+ if isinstance(f, FieldInfo):
+ self.fields.append(f)
+ elif isinstance(f, tuple):
+ self.fields.append(FieldInfo(f[0], f[1]))
+ else:
+ raise TypeError(f"bad field info: {f}")
+
+
+def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]:
+ cur = conn.cursor(binary=True)
+ cur.execute(_type_info_query, (name,))
+ rec = cur.fetchone()
+ if rec is not None:
+ return CompositeTypeInfo(*rec)
+ else:
+ return None
+
+
+def register(
+ info: CompositeTypeInfo,
+ context: AdaptContext = None,
+ factory: Optional[Callable[..., Any]] = None,
+) -> None:
+ if factory is None:
+ factory = namedtuple( # type: ignore
+ info.name, [f.name for f in info.fields]
+ )
-_re_undouble = re.compile(br'(["\\])\1')
+ # generate and register a customized text typecaster
+ caster = type(
+ f"{info.name.title()}Caster",
+ (CompositeCaster,),
+ {
+ "factory": factory,
+ "fields_types": tuple(f.type_oid for f in info.fields),
+ },
+ )
+ TypeCaster.register(info.oid, caster, context=context, format=Format.TEXT)
+
+ # generate and register a customized binary typecaster
+ caster = type(
+ f"{info.name.title()}BinaryCaster",
+ (CompositeBinaryCaster,),
+ {"factory": factory},
+ )
+ TypeCaster.register(
+ info.oid, caster, context=context, format=Format.BINARY
+ )
+
+ if info.array_oid:
+ register_array(info.array_oid, info.oid)
+
+
+_type_info_query = """\
+select
+ name, oid, array_oid,
+ array_agg(row(field_name, field_type)) as fields
+from (
+ select
+ typname as name,
+ t.oid as oid,
+ t.typarray as array_oid,
+ a.attname as field_name,
+ a.atttypid as field_type
+ from pg_type t
+ left join pg_attribute a on a.attrelid = t.typrelid
+ where t.typname = %s
+ and a.attnum > 0
+ order by a.attnum
+) x
+group by name, oid, array_oid
+"""
class BaseCompositeCaster(TypeCaster):
def __init__(self, oid: int, context: AdaptContext = None):
super().__init__(oid, context)
- self.tx = Transformer(context)
+ self._tx = Transformer(context)
@TypeCaster.text(builtins["record"].oid)
class RecordCaster(BaseCompositeCaster):
def cast(self, data: bytes) -> Tuple[Any, ...]:
- cast = self.tx.get_cast_function(TEXT_OID, format=Format.TEXT)
+ cast = self._tx.get_cast_function(TEXT_OID, format=Format.TEXT)
return tuple(
- cast(item) if item is not None else None
- for item in self._parse_record(data)
+ cast(token) if token is not None else None
+ for token in self._parse_record(data)
)
def _parse_record(
if data == b"()":
return
- for m in _re_tokenize.finditer(data):
+ for m in self._re_tokenize.finditer(data):
if m.group(1) is not None:
yield None
elif m.group(2) is not None:
- yield _re_undouble.sub(br"\1", m.group(2))
+ yield self._re_undouble.sub(br"\1", m.group(2))
else:
yield m.group(3)
+ _re_tokenize = re.compile(
+ br"""(?x)
+ \(? ([,)]) # an empty token, representing NULL
+ | \(? " ((?: [^"] | "")*) " [,)] # or a quoted string
+ | \(? ([^",)]+) [,)] # or an unquoted string
+ """
+ )
+
+ _re_undouble = re.compile(br'(["\\])\1')
+
_struct_len = struct.Struct("!i")
_struct_oidlen = struct.Struct("!Ii")
@TypeCaster.binary(builtins["record"].oid)
-class BinaryRecordCaster(BaseCompositeCaster):
+class RecordBinaryCaster(BaseCompositeCaster):
_types_set = False
def cast(self, data: bytes) -> Tuple[Any, ...]:
self._types_set = True
return tuple(
- self.tx.cast_sequence(
+ self._tx.cast_sequence(
data[offset : offset + length] if length != -1 else None
for _, offset, length in self._walk_record(data)
)
i += (8 + length) if length > 0 else 8
def _config_types(self, data: bytes) -> None:
- self.tx.set_row_types(
+ self._tx.set_row_types(
(oid, Format.BINARY) for oid, _, _ in self._walk_record(data)
)
+
+
+class CompositeCaster(RecordCaster):
+ factory: Callable[..., Any]
+ fields_types: Tuple[int, ...]
+ _types_set = False
+
+ def cast(self, data: bytes) -> Any:
+ if not self._types_set:
+ self._config_types(data)
+ self._types_set = True
+
+ return type(self).factory(
+ *self._tx.cast_sequence(self._parse_record(data))
+ )
+
+ def _config_types(self, data: bytes) -> None:
+ self._tx.set_row_types((oid, Format.TEXT) for oid in self.fields_types)
+
+
+class CompositeBinaryCaster(RecordBinaryCaster):
+ factory: Callable[..., Any]
+
+ def cast(self, data: bytes) -> Any:
+ r = super().cast(data)
+ return type(self).factory(*r)
# Copyright (C) 2020 The Psycopg Team
import re
-from typing import Dict, Generator, Optional, NamedTuple, Union
+from typing import Dict, Generator, Optional, Union
INVALID_OID = 0
-class TypeInfo(NamedTuple):
- name: str
- oid: int
- array_oid: int
- alt_name: str
- delimiter: str
+class TypeInfo:
+ def __init__(self, name: str, oid: int, array_oid: int):
+ self.name = name
+ self.oid = oid
+ self.array_oid = array_oid
+
+
+class BuiltinTypeInfo(TypeInfo):
+ def __init__(
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ alt_name: str,
+ delimiter: str,
+ ):
+ super().__init__(name, oid, array_oid)
+ self.alt_name = alt_name
+ self.delimiter = delimiter
class TypesRegistry:
if info.array_oid:
self._by_oid[info.array_oid] = info
self._by_name[info.name] = info
- if info.alt_name not in self._by_name:
- self._by_name[info.alt_name] = info
+
+ if isinstance(info, BuiltinTypeInfo):
+ if info.alt_name not in self._by_name:
+ self._by_name[info.alt_name] = info
def __iter__(self) -> Generator[TypeInfo, None, None]:
seen = set()
# autogenerated: end
# fmt: on
]:
- builtins.add(TypeInfo(*r))
+ builtins.add(BuiltinTypeInfo(*r))
def self_update() -> None:
import pytest
from psycopg3.adapt import Format
+from psycopg3.types import builtins
+from psycopg3.types import composite
@pytest.mark.parametrize(
assert res == want
for o1, o2 in zip(res, want):
assert type(o1) is type(o2)
+
+
+@pytest.fixture(scope="session")
+def testcomp(svcconn):
+ cur = svcconn.cursor()
+ cur.execute(
+ """
+ drop type if exists testcomp cascade;
+ create type testcomp as (foo text, bar int8, baz float8);
+ """
+ )
+
+
+def test_fetch_info(conn, testcomp):
+ info = composite.fetch_info(conn, "testcomp")
+ assert info.name == "testcomp"
+ assert info.oid > 0
+ assert info.oid != info.array_oid > 0
+ assert len(info.fields) == 3
+ for i, (name, t) in enumerate(
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")]
+ ):
+ assert info.fields[i].name == name
+ assert info.fields[i].type_oid == builtins[t].oid
+
+
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_composite(conn, testcomp, fmt_out):
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
+ info = composite.fetch_info(conn, "testcomp")
+ composite.register(info)
+
+ res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+ assert res.foo == "hello"
+ assert res.bar == 10
+ assert res.baz == 20.0
+ assert isinstance(res.baz, float)
+
+ res = cur.execute(
+ "select array[row('hello', 10, 30)::testcomp]"
+ ).fetchone()[0]
+ assert len(res) == 1
+ assert res[0].baz == 30.0
+ assert isinstance(res[0].baz, float)
+
+
+@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
+def test_cast_composite_factory(conn, testcomp, fmt_out):
+ cur = conn.cursor(binary=fmt_out == Format.BINARY)
+ info = composite.fetch_info(conn, "testcomp")
+
+ class MyThing:
+ def __init__(self, *args):
+ self.foo, self.bar, self.baz = args
+
+ composite.register(info, factory=MyThing)
+
+ res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
+ assert isinstance(res, MyThing)
+ assert res.baz == 20.0
+ assert isinstance(res.baz, float)
+
+ res = cur.execute(
+ "select array[row('hello', 10, 30)::testcomp]"
+ ).fetchone()[0]
+ assert len(res) == 1
+ assert res[0].baz == 30.0
+ assert isinstance(res[0].baz, float)