From: Daniele Varrazzo Date: Sun, 6 Dec 2020 01:35:16 +0000 (+0000) Subject: Fetch/register composite made methods of a CompositeInfo class. X-Git-Tag: 3.0.dev0~274^2~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=20a16dafd32fe7bd139273ccc790a7cd6061905c;p=thirdparty%2Fpsycopg.git Fetch/register composite made methods of a CompositeInfo class. --- diff --git a/psycopg3/psycopg3/types/composite.py b/psycopg3/psycopg3/types/composite.py index f7a79f844..3e5c42d47 100644 --- a/psycopg3/psycopg3/types/composite.py +++ b/psycopg3/psycopg3/types/composite.py @@ -7,8 +7,8 @@ Support for composite types adaptation. import re import struct from collections import namedtuple -from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type, Union -from typing import Optional, TYPE_CHECKING +from typing import Any, Callable, Iterator, List, NamedTuple, Optional +from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING from .. import sql from .. import errors as e @@ -24,21 +24,91 @@ if TYPE_CHECKING: TEXT_OID = builtins["text"].oid -class FieldInfo: - def __init__(self, name: str, type_oid: int): - self.name = name - self.type_oid = type_oid +class CompositeInfo(TypeInfo): + """Manage information about a composite type. + The class allows to: + + - read information about a composite type using `fetch()` and `fetch_async()` + - configure a composite type adaptation using `register()` + """ -class CompositeTypeInfo(TypeInfo): def __init__( - self, name: str, oid: int, array_oid: int, fields: Sequence[FieldInfo] + self, + name: str, + oid: int, + array_oid: int, + fields: Sequence["CompositeInfo.FieldInfo"], ): super().__init__(name, oid, array_oid) self.fields = list(fields) + class FieldInfo(NamedTuple): + """Information about a single field in a composite type.""" + + name: str + type_oid: int + + @classmethod + def fetch( + cls, conn: "Connection", name: Union[str, sql.Identifier] + ) -> Optional["CompositeInfo"]: + if isinstance(name, sql.Composable): + name = name.as_string(conn) + cur = conn.cursor(format=Format.BINARY) + cur.execute(cls._info_query, {"name": name}) + recs = cur.fetchall() + return cls._from_records(recs) + + @classmethod + async def fetch_async( + cls, conn: "AsyncConnection", name: Union[str, sql.Identifier] + ) -> Optional["CompositeInfo"]: + if isinstance(name, sql.Composable): + name = name.as_string(conn) + cur = await conn.cursor(format=Format.BINARY) + await cur.execute(cls._info_query, {"name": name}) + recs = await cur.fetchall() + return cls._from_records(recs) + + def register( + self, + context: AdaptContext = None, + factory: Optional[Callable[..., Any]] = None, + ) -> None: + if not factory: + factory = namedtuple( # type: ignore + self.name, [f.name for f in self.fields] + ) + + loader: Type[Loader] + + # generate and register a customized text loader + loader = type( + f"{self.name.title()}Loader", + (CompositeLoader,), + { + "factory": factory, + "fields_types": tuple(f.type_oid for f in self.fields), + }, + ) + loader.register(self.oid, context=context, format=Format.TEXT) + + # generate and register a customized binary loader + loader = type( + f"{self.name.title()}BinaryLoader", + (CompositeBinaryLoader,), + {"factory": factory}, + ) + loader.register(self.oid, context=context, format=Format.BINARY) + + if self.array_oid: + array.register( + self.array_oid, self.oid, context=context, name=self.name + ) + @classmethod - def _from_records(cls, recs: List[Any]) -> Optional["CompositeTypeInfo"]: + def _from_records(cls, recs: List[Any]) -> Optional["CompositeInfo"]: if not recs: return None if len(recs) > 1: @@ -47,70 +117,10 @@ class CompositeTypeInfo(TypeInfo): ) name, oid, array_oid, fnames, ftypes = recs[0] - fields = [FieldInfo(*p) for p in zip(fnames, ftypes)] - return CompositeTypeInfo(name, oid, array_oid, fields) - - -def fetch_info( - conn: "Connection", name: Union[str, sql.Identifier] -) -> Optional[CompositeTypeInfo]: - if isinstance(name, sql.Composable): - name = name.as_string(conn) - cur = conn.cursor(format=Format.BINARY) - cur.execute(_type_info_query, {"name": name}) - recs = cur.fetchall() - return CompositeTypeInfo._from_records(recs) - - -async def fetch_info_async( - conn: "AsyncConnection", name: Union[str, sql.Identifier] -) -> Optional[CompositeTypeInfo]: - if isinstance(name, sql.Composable): - name = name.as_string(conn) - cur = await conn.cursor(format=Format.BINARY) - await cur.execute(_type_info_query, {"name": name}) - recs = await cur.fetchall() - return CompositeTypeInfo._from_records(recs) - - -def register( - info: CompositeTypeInfo, - context: AdaptContext = None, - factory: Optional[Callable[..., Any]] = None, -) -> None: - if not factory: - factory = namedtuple( # type: ignore - info.name, [f.name for f in info.fields] - ) - - loader: Type[Loader] - - # generate and register a customized text loader - loader = type( - f"{info.name.title()}Loader", - (CompositeLoader,), - { - "factory": factory, - "fields_types": tuple(f.type_oid for f in info.fields), - }, - ) - loader.register(info.oid, context=context, format=Format.TEXT) - - # generate and register a customized binary loader - loader = type( - f"{info.name.title()}BinaryLoader", - (CompositeBinaryLoader,), - {"factory": factory}, - ) - loader.register(info.oid, context=context, format=Format.BINARY) - - if info.array_oid: - array.register( - info.array_oid, info.oid, context=context, name=info.name - ) - + fields = [cls.FieldInfo(*p) for p in zip(fnames, ftypes)] + return cls(name, oid, array_oid, fields) -_type_info_query = """\ + _info_query = """\ select t.typname as name, t.oid as oid, t.typarray as array_oid, coalesce(a.fnames, '{}') as fnames, diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index a35f06563..94721eda6 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -3,7 +3,7 @@ import pytest from psycopg3.sql import Identifier from psycopg3.oids import builtins from psycopg3.adapt import Format, Loader -from psycopg3.types import composite +from psycopg3.types.composite import CompositeInfo tests_str = [ @@ -37,8 +37,8 @@ def test_dump_tuple(conn, rec, obj): create type tmptype as ({', '.join(fields)}); """ ) - info = composite.fetch_info(conn, "tmptype") - composite.register(info, context=conn) + info = CompositeInfo.fetch(conn, "tmptype") + info.register(context=conn) res = cur.execute("select %s::tmptype", [obj]).fetchone()[0] assert res == obj @@ -108,29 +108,29 @@ def testcomp(svcconn): ) -@pytest.mark.parametrize( - "name, fields", - [ - ( - "testcomp", - [("foo", "text"), ("bar", "int8"), ("baz", "float8")], - ), - ( - "testschema.testcomp", - [("foo", "text"), ("bar", "int8"), ("qux", "bool")], - ), - ( - Identifier("testcomp"), - [("foo", "text"), ("bar", "int8"), ("baz", "float8")], - ), - ( - Identifier("testschema", "testcomp"), - [("foo", "text"), ("bar", "int8"), ("qux", "bool")], - ), - ], -) +fetch_cases = [ + ( + "testcomp", + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + "testschema.testcomp", + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), + ( + Identifier("testcomp"), + [("foo", "text"), ("bar", "int8"), ("baz", "float8")], + ), + ( + Identifier("testschema", "testcomp"), + [("foo", "text"), ("bar", "int8"), ("qux", "bool")], + ), +] + + +@pytest.mark.parametrize("name, fields", fetch_cases) def test_fetch_info(conn, testcomp, name, fields): - info = composite.fetch_info(conn, name) + info = CompositeInfo.fetch(conn, name) assert info.name == "testcomp" assert info.oid > 0 assert info.oid != info.array_oid > 0 @@ -140,30 +140,10 @@ def test_fetch_info(conn, testcomp, name, fields): assert info.fields[i].type_oid == builtins[t].oid -@pytest.mark.parametrize( - "name, fields", - [ - ( - "testcomp", - [("foo", "text"), ("bar", "int8"), ("baz", "float8")], - ), - ( - "testschema.testcomp", - [("foo", "text"), ("bar", "int8"), ("qux", "bool")], - ), - ( - Identifier("testcomp"), - [("foo", "text"), ("bar", "int8"), ("baz", "float8")], - ), - ( - Identifier("testschema", "testcomp"), - [("foo", "text"), ("bar", "int8"), ("qux", "bool")], - ), - ], -) @pytest.mark.asyncio +@pytest.mark.parametrize("name, fields", fetch_cases) async def test_fetch_info_async(aconn, testcomp, name, fields): - info = await composite.fetch_info_async(aconn, name) + info = await CompositeInfo.fetch_async(aconn, name) assert info.name == "testcomp" assert info.oid > 0 assert info.oid != info.array_oid > 0 @@ -190,8 +170,8 @@ def test_dump_composite_all_chars(conn, fmt_in, testcomp): @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) def test_load_composite(conn, testcomp, fmt_out): cur = conn.cursor(format=fmt_out) - info = composite.fetch_info(conn, "testcomp") - composite.register(info, conn) + info = CompositeInfo.fetch(conn, "testcomp") + info.register(conn) res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] assert res.foo == "hello" @@ -210,13 +190,13 @@ def test_load_composite(conn, testcomp, fmt_out): @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY]) def test_load_composite_factory(conn, testcomp, fmt_out): cur = conn.cursor(format=fmt_out) - info = composite.fetch_info(conn, "testcomp") + info = CompositeInfo.fetch(conn, "testcomp") class MyThing: def __init__(self, *args): self.foo, self.bar, self.baz = args - composite.register(info, conn, factory=MyThing) + info.register(conn, factory=MyThing) res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0] assert isinstance(res, MyThing) @@ -232,15 +212,14 @@ def test_load_composite_factory(conn, testcomp, fmt_out): def test_register_scope(conn): - info = composite.fetch_info(conn, "testcomp") - - composite.register(info) + info = CompositeInfo.fetch(conn, "testcomp") + info.register() for fmt in (Format.TEXT, Format.BINARY): for oid in (info.oid, info.array_oid): assert Loader.globals.pop((oid, fmt)) cur = conn.cursor() - composite.register(info, cur) + info.register(cur) for fmt in (Format.TEXT, Format.BINARY): for oid in (info.oid, info.array_oid): key = oid, fmt @@ -248,7 +227,7 @@ def test_register_scope(conn): assert key not in conn.loaders assert key in cur.loaders - composite.register(info, conn) + info.register(conn) for fmt in (Format.TEXT, Format.BINARY): for oid in (info.oid, info.array_oid): key = oid, fmt