From: Daniele Varrazzo Date: Wed, 8 Apr 2020 09:00:16 +0000 (+1200) Subject: Added typecasting of composite types X-Git-Tag: 3.0.dev0~592 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=ae6adc37c26f7aeab7b6f5ae85da52a4dff91097;p=thirdparty%2Fpsycopg.git Added typecasting of composite types --- diff --git a/psycopg3/types/__init__.py b/psycopg3/types/__init__.py index 7b2556dfd..f578c651c 100644 --- a/psycopg3/types/__init__.py +++ b/psycopg3/types/__init__.py @@ -8,9 +8,10 @@ psycopg3 types package from .oids import builtins # Register default adapters -from . import composite, numeric, text # noqa +from . import array, composite, numeric, text # noqa # Register associations with array oids -from . import array # noqa +array.register_all_arrays() + __all__ = ["builtins"] diff --git a/psycopg3/types/array.py b/psycopg3/types/array.py index 5da427192..13318cd17 100644 --- a/psycopg3/types/array.py +++ b/psycopg3/types/array.py @@ -292,10 +292,16 @@ def register_array( TypeCaster.register(array_oid, t, context=context, format=format) -# Register associations between array and base oids -for t in builtins: - if t.array_oid and ( - (t.oid, Format.TEXT) in TypeCaster.globals - or (t.oid, Format.BINARY) in TypeCaster.globals - ): - register_array(t.array_oid, t.oid, name=t.name) +def register_all_arrays() -> None: + """ + Associate the array oid of all the types in TypeCaster.globals. + + This function is designed to be called once at import time, after having + registered all the base casters. + """ + for t in builtins: + if t.array_oid and ( + (t.oid, Format.TEXT) in TypeCaster.globals + or (t.oid, Format.BINARY) in TypeCaster.globals + ): + register_array(t.array_oid, t.oid, name=t.name) diff --git a/psycopg3/types/composite.py b/psycopg3/types/composite.py index f663d4e57..a40498cba 100644 --- a/psycopg3/types/composite.py +++ b/psycopg3/types/composite.py @@ -4,39 +4,125 @@ Support for composite types adaptation. import re import struct -from typing import Any, Generator, Optional, Tuple +from collections import namedtuple +from typing import Any, Callable, Generator, List, Sequence, Tuple, Union +from typing import Optional, TYPE_CHECKING from ..adapt import Format, TypeCaster, Transformer, AdaptContext -from .oids import builtins +from .oids import builtins, TypeInfo +from .array import register_array + +if TYPE_CHECKING: + from ..connection import Connection TEXT_OID = builtins["text"].oid -_re_tokenize = re.compile( - br"""(?x) - \(? ([,)]) # an empty token, representing NULL - | \(? " ((?: [^"] | "")*) " [,)] # or a quoted string - | \(? ([^",)]+) [,)] # or an unquoted string - """ -) +class FieldInfo: + def __init__(self, name: str, type_oid: int): + self.name = name + self.type_oid = type_oid + + +class CompositeTypeInfo(TypeInfo): + def __init__( + self, + name: str, + oid: int, + array_oid: int, + fields: Sequence[Union[FieldInfo, Tuple[str, int]]], + ): + super().__init__(name, oid, array_oid) + self.fields: List[FieldInfo] = [] + for f in fields: + if isinstance(f, FieldInfo): + self.fields.append(f) + elif isinstance(f, tuple): + self.fields.append(FieldInfo(f[0], f[1])) + else: + raise TypeError(f"bad field info: {f}") + + +def fetch_info(conn: "Connection", name: str) -> Optional[CompositeTypeInfo]: + cur = conn.cursor(binary=True) + cur.execute(_type_info_query, (name,)) + rec = cur.fetchone() + if rec is not None: + return CompositeTypeInfo(*rec) + else: + return None + + +def register( + info: CompositeTypeInfo, + context: AdaptContext = None, + factory: Optional[Callable[..., Any]] = None, +) -> None: + if factory is None: + factory = namedtuple( # type: ignore + info.name, [f.name for f in info.fields] + ) -_re_undouble = re.compile(br'(["\\])\1') + # generate and register a customized text typecaster + caster = type( + f"{info.name.title()}Caster", + (CompositeCaster,), + { + "factory": factory, + "fields_types": tuple(f.type_oid for f in info.fields), + }, + ) + TypeCaster.register(info.oid, caster, context=context, format=Format.TEXT) + + # generate and register a customized binary typecaster + caster = type( + f"{info.name.title()}BinaryCaster", + (CompositeBinaryCaster,), + {"factory": factory}, + ) + TypeCaster.register( + info.oid, caster, context=context, format=Format.BINARY + ) + + if info.array_oid: + register_array(info.array_oid, info.oid) + + +_type_info_query = """\ +select + name, oid, array_oid, + array_agg(row(field_name, field_type)) as fields +from ( + select + typname as name, + t.oid as oid, + t.typarray as array_oid, + a.attname as field_name, + a.atttypid as field_type + from pg_type t + left join pg_attribute a on a.attrelid = t.typrelid + where t.typname = %s + and a.attnum > 0 + order by a.attnum +) x +group by name, oid, array_oid +""" class BaseCompositeCaster(TypeCaster): def __init__(self, oid: int, context: AdaptContext = None): super().__init__(oid, context) - self.tx = Transformer(context) + self._tx = Transformer(context) @TypeCaster.text(builtins["record"].oid) class RecordCaster(BaseCompositeCaster): def cast(self, data: bytes) -> Tuple[Any, ...]: - cast = self.tx.get_cast_function(TEXT_OID, format=Format.TEXT) + cast = self._tx.get_cast_function(TEXT_OID, format=Format.TEXT) return tuple( - cast(item) if item is not None else None - for item in self._parse_record(data) + cast(token) if token is not None else None + for token in self._parse_record(data) ) def _parse_record( @@ -45,21 +131,31 @@ class RecordCaster(BaseCompositeCaster): if data == b"()": return - for m in _re_tokenize.finditer(data): + for m in self._re_tokenize.finditer(data): if m.group(1) is not None: yield None elif m.group(2) is not None: - yield _re_undouble.sub(br"\1", m.group(2)) + yield self._re_undouble.sub(br"\1", m.group(2)) else: yield m.group(3) + _re_tokenize = re.compile( + br"""(?x) + \(? ([,)]) # an empty token, representing NULL + | \(? " ((?: [^"] | "")*) " [,)] # or a quoted string + | \(? ([^",)]+) [,)] # or an unquoted string + """ + ) + + _re_undouble = re.compile(br'(["\\])\1') + _struct_len = struct.Struct("!i") _struct_oidlen = struct.Struct("!Ii") @TypeCaster.binary(builtins["record"].oid) -class BinaryRecordCaster(BaseCompositeCaster): +class RecordBinaryCaster(BaseCompositeCaster): _types_set = False def cast(self, data: bytes) -> Tuple[Any, ...]: @@ -68,7 +164,7 @@ class BinaryRecordCaster(BaseCompositeCaster): self._types_set = True return tuple( - self.tx.cast_sequence( + self._tx.cast_sequence( data[offset : offset + length] if length != -1 else None for _, offset, length in self._walk_record(data) ) @@ -88,6 +184,32 @@ class BinaryRecordCaster(BaseCompositeCaster): i += (8 + length) if length > 0 else 8 def _config_types(self, data: bytes) -> None: - self.tx.set_row_types( + self._tx.set_row_types( (oid, Format.BINARY) for oid, _, _ in self._walk_record(data) ) + + +class CompositeCaster(RecordCaster): + factory: Callable[..., Any] + fields_types: Tuple[int, ...] + _types_set = False + + def cast(self, data: bytes) -> Any: + if not self._types_set: + self._config_types(data) + self._types_set = True + + return type(self).factory( + *self._tx.cast_sequence(self._parse_record(data)) + ) + + def _config_types(self, data: bytes) -> None: + self._tx.set_row_types((oid, Format.TEXT) for oid in self.fields_types) + + +class CompositeBinaryCaster(RecordBinaryCaster): + factory: Callable[..., Any] + + def cast(self, data: bytes) -> Any: + r = super().cast(data) + return type(self).factory(*r) diff --git a/psycopg3/types/oids.py b/psycopg3/types/oids.py index 1b211fd65..39a9d4206 100644 --- a/psycopg3/types/oids.py +++ b/psycopg3/types/oids.py @@ -8,17 +8,30 @@ to a Postgres server. # Copyright (C) 2020 The Psycopg Team import re -from typing import Dict, Generator, Optional, NamedTuple, Union +from typing import Dict, Generator, Optional, Union INVALID_OID = 0 -class TypeInfo(NamedTuple): - name: str - oid: int - array_oid: int - alt_name: str - delimiter: str +class TypeInfo: + def __init__(self, name: str, oid: int, array_oid: int): + self.name = name + self.oid = oid + self.array_oid = array_oid + + +class BuiltinTypeInfo(TypeInfo): + def __init__( + self, + name: str, + oid: int, + array_oid: int, + alt_name: str, + delimiter: str, + ): + super().__init__(name, oid, array_oid) + self.alt_name = alt_name + self.delimiter = delimiter class TypesRegistry: @@ -35,8 +48,10 @@ class TypesRegistry: if info.array_oid: self._by_oid[info.array_oid] = info self._by_name[info.name] = info - if info.alt_name not in self._by_name: - self._by_name[info.alt_name] = info + + if isinstance(info, BuiltinTypeInfo): + if info.alt_name not in self._by_name: + self._by_name[info.alt_name] = info def __iter__(self) -> Generator[TypeInfo, None, None]: seen = set() @@ -143,7 +158,7 @@ for r in [ # autogenerated: end # fmt: on ]: - builtins.add(TypeInfo(*r)) + builtins.add(BuiltinTypeInfo(*r)) def self_update() -> None: diff --git a/tests/fix_db.py b/tests/fix_db.py index 8487fedb8..05743e382 100644 --- a/tests/fix_db.py +++ b/tests/fix_db.py @@ -87,7 +87,7 @@ def check_libpq_version(got, want): ) -@pytest.fixture +@pytest.fixture(scope="session") def dsn(request): """Return the dsn used to connect to the `--test-dsn` database.""" dsn = request.config.getoption("--test-dsn") @@ -113,3 +113,13 @@ def conn(dsn): from psycopg3 import Connection return Connection.connect(dsn) + + +@pytest.fixture(scope="session") +def svcconn(dsn): + """ + Return a session `Connection` connected to the ``--test-dsn`` database. + """ + from psycopg3 import Connection + + return Connection.connect(dsn) diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 9406e175c..cf04e38e7 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -1,6 +1,8 @@ import pytest from psycopg3.adapt import Format +from psycopg3.types import builtins +from psycopg3.types import composite @pytest.mark.parametrize( @@ -73,3 +75,71 @@ def test_cast_record_binary(conn, want, rec): assert res == want for o1, o2 in zip(res, want): assert type(o1) is type(o2) + + +@pytest.fixture(scope="session") +def testcomp(svcconn): + cur = svcconn.cursor() + cur.execute( + """ + drop type if exists testcomp cascade; + create type testcomp as (foo text, bar int8, baz float8); + """ + ) + + +def test_fetch_info(conn, testcomp): + info = composite.fetch_info(conn, "testcomp") + assert info.name == "testcomp" + assert info.oid > 0 + assert info.oid != info.array_oid > 0 + assert len(info.fields) == 3 + for i, (name, t) in enumerate( + [("foo", "text"), ("bar", "int8"), ("baz", "float8")] + ): + assert info.fields[i].name == name + assert info.fields[i].type_oid == builtins[t].oid + + +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_composite(conn, testcomp, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + info = composite.fetch_info(conn, "testcomp") + composite.register(info) + + res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] + assert res.foo == "hello" + assert res.bar == 10 + assert res.baz == 20.0 + assert isinstance(res.baz, float) + + res = cur.execute( + "select array[row('hello', 10, 30)::testcomp]" + ).fetchone()[0] + assert len(res) == 1 + assert res[0].baz == 30.0 + assert isinstance(res[0].baz, float) + + +@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) +def test_cast_composite_factory(conn, testcomp, fmt_out): + cur = conn.cursor(binary=fmt_out == Format.BINARY) + info = composite.fetch_info(conn, "testcomp") + + class MyThing: + def __init__(self, *args): + self.foo, self.bar, self.baz = args + + composite.register(info, factory=MyThing) + + res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] + assert isinstance(res, MyThing) + assert res.baz == 20.0 + assert isinstance(res.baz, float) + + res = cur.execute( + "select array[row('hello', 10, 30)::testcomp]" + ).fetchone()[0] + assert len(res) == 1 + assert res[0].baz == 30.0 + assert isinstance(res[0].baz, float)