]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Fetch/register composite made methods of a CompositeInfo class.
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 6 Dec 2020 01:35:16 +0000 (01:35 +0000)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 10 Dec 2020 01:43:50 +0000 (02:43 +0100)
psycopg3/psycopg3/types/composite.py
tests/types/test_composite.py

index f7a79f844c7c029305f77be03d03bb84bd0c14b0..3e5c42d476f0e73718d6a7547130a817eca2eecb 100644 (file)
@@ -7,8 +7,8 @@ Support for composite types adaptation.
 import re
 import struct
 from collections import namedtuple
-from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type, Union
-from typing import Optional, TYPE_CHECKING
+from typing import Any, Callable, Iterator, List, NamedTuple, Optional
+from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
 
 from .. import sql
 from .. import errors as e
@@ -24,21 +24,91 @@ if TYPE_CHECKING:
 TEXT_OID = builtins["text"].oid
 
 
-class FieldInfo:
-    def __init__(self, name: str, type_oid: int):
-        self.name = name
-        self.type_oid = type_oid
+class CompositeInfo(TypeInfo):
+    """Manage information about a composite type.
 
+    The class allows to:
+
+    - read information about a composite type using `fetch()` and `fetch_async()`
+    - configure a composite type adaptation using `register()`
+    """
 
-class CompositeTypeInfo(TypeInfo):
     def __init__(
-        self, name: str, oid: int, array_oid: int, fields: Sequence[FieldInfo]
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        fields: Sequence["CompositeInfo.FieldInfo"],
     ):
         super().__init__(name, oid, array_oid)
         self.fields = list(fields)
 
+    class FieldInfo(NamedTuple):
+        """Information about a single field in a composite type."""
+
+        name: str
+        type_oid: int
+
+    @classmethod
+    def fetch(
+        cls, conn: "Connection", name: Union[str, sql.Identifier]
+    ) -> Optional["CompositeInfo"]:
+        if isinstance(name, sql.Composable):
+            name = name.as_string(conn)
+        cur = conn.cursor(format=Format.BINARY)
+        cur.execute(cls._info_query, {"name": name})
+        recs = cur.fetchall()
+        return cls._from_records(recs)
+
+    @classmethod
+    async def fetch_async(
+        cls, conn: "AsyncConnection", name: Union[str, sql.Identifier]
+    ) -> Optional["CompositeInfo"]:
+        if isinstance(name, sql.Composable):
+            name = name.as_string(conn)
+        cur = await conn.cursor(format=Format.BINARY)
+        await cur.execute(cls._info_query, {"name": name})
+        recs = await cur.fetchall()
+        return cls._from_records(recs)
+
+    def register(
+        self,
+        context: AdaptContext = None,
+        factory: Optional[Callable[..., Any]] = None,
+    ) -> None:
+        if not factory:
+            factory = namedtuple(  # type: ignore
+                self.name, [f.name for f in self.fields]
+            )
+
+        loader: Type[Loader]
+
+        # generate and register a customized text loader
+        loader = type(
+            f"{self.name.title()}Loader",
+            (CompositeLoader,),
+            {
+                "factory": factory,
+                "fields_types": tuple(f.type_oid for f in self.fields),
+            },
+        )
+        loader.register(self.oid, context=context, format=Format.TEXT)
+
+        # generate and register a customized binary loader
+        loader = type(
+            f"{self.name.title()}BinaryLoader",
+            (CompositeBinaryLoader,),
+            {"factory": factory},
+        )
+        loader.register(self.oid, context=context, format=Format.BINARY)
+
+        if self.array_oid:
+            array.register(
+                self.array_oid, self.oid, context=context, name=self.name
+            )
+
     @classmethod
-    def _from_records(cls, recs: List[Any]) -> Optional["CompositeTypeInfo"]:
+    def _from_records(cls, recs: List[Any]) -> Optional["CompositeInfo"]:
         if not recs:
             return None
         if len(recs) > 1:
@@ -47,70 +117,10 @@ class CompositeTypeInfo(TypeInfo):
             )
 
         name, oid, array_oid, fnames, ftypes = recs[0]
-        fields = [FieldInfo(*p) for p in zip(fnames, ftypes)]
-        return CompositeTypeInfo(name, oid, array_oid, fields)
-
-
-def fetch_info(
-    conn: "Connection", name: Union[str, sql.Identifier]
-) -> Optional[CompositeTypeInfo]:
-    if isinstance(name, sql.Composable):
-        name = name.as_string(conn)
-    cur = conn.cursor(format=Format.BINARY)
-    cur.execute(_type_info_query, {"name": name})
-    recs = cur.fetchall()
-    return CompositeTypeInfo._from_records(recs)
-
-
-async def fetch_info_async(
-    conn: "AsyncConnection", name: Union[str, sql.Identifier]
-) -> Optional[CompositeTypeInfo]:
-    if isinstance(name, sql.Composable):
-        name = name.as_string(conn)
-    cur = await conn.cursor(format=Format.BINARY)
-    await cur.execute(_type_info_query, {"name": name})
-    recs = await cur.fetchall()
-    return CompositeTypeInfo._from_records(recs)
-
-
-def register(
-    info: CompositeTypeInfo,
-    context: AdaptContext = None,
-    factory: Optional[Callable[..., Any]] = None,
-) -> None:
-    if not factory:
-        factory = namedtuple(  # type: ignore
-            info.name, [f.name for f in info.fields]
-        )
-
-    loader: Type[Loader]
-
-    # generate and register a customized text loader
-    loader = type(
-        f"{info.name.title()}Loader",
-        (CompositeLoader,),
-        {
-            "factory": factory,
-            "fields_types": tuple(f.type_oid for f in info.fields),
-        },
-    )
-    loader.register(info.oid, context=context, format=Format.TEXT)
-
-    # generate and register a customized binary loader
-    loader = type(
-        f"{info.name.title()}BinaryLoader",
-        (CompositeBinaryLoader,),
-        {"factory": factory},
-    )
-    loader.register(info.oid, context=context, format=Format.BINARY)
-
-    if info.array_oid:
-        array.register(
-            info.array_oid, info.oid, context=context, name=info.name
-        )
-
+        fields = [cls.FieldInfo(*p) for p in zip(fnames, ftypes)]
+        return cls(name, oid, array_oid, fields)
 
-_type_info_query = """\
+    _info_query = """\
 select
     t.typname as name, t.oid as oid, t.typarray as array_oid,
     coalesce(a.fnames, '{}') as fnames,
index a35f065631df621beaa1133226013c115b47e1b1..94721eda65d7869d230416d1b8c00f22dbc75ee6 100644 (file)
@@ -3,7 +3,7 @@ import pytest
 from psycopg3.sql import Identifier
 from psycopg3.oids import builtins
 from psycopg3.adapt import Format, Loader
-from psycopg3.types import composite
+from psycopg3.types.composite import CompositeInfo
 
 
 tests_str = [
@@ -37,8 +37,8 @@ def test_dump_tuple(conn, rec, obj):
         create type tmptype as ({', '.join(fields)});
         """
     )
-    info = composite.fetch_info(conn, "tmptype")
-    composite.register(info, context=conn)
+    info = CompositeInfo.fetch(conn, "tmptype")
+    info.register(context=conn)
 
     res = cur.execute("select %s::tmptype", [obj]).fetchone()[0]
     assert res == obj
@@ -108,29 +108,29 @@ def testcomp(svcconn):
     )
 
 
-@pytest.mark.parametrize(
-    "name, fields",
-    [
-        (
-            "testcomp",
-            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
-        ),
-        (
-            "testschema.testcomp",
-            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
-        ),
-        (
-            Identifier("testcomp"),
-            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
-        ),
-        (
-            Identifier("testschema", "testcomp"),
-            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
-        ),
-    ],
-)
+fetch_cases = [
+    (
+        "testcomp",
+        [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+    ),
+    (
+        "testschema.testcomp",
+        [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+    ),
+    (
+        Identifier("testcomp"),
+        [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+    ),
+    (
+        Identifier("testschema", "testcomp"),
+        [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+    ),
+]
+
+
+@pytest.mark.parametrize("name, fields", fetch_cases)
 def test_fetch_info(conn, testcomp, name, fields):
-    info = composite.fetch_info(conn, name)
+    info = CompositeInfo.fetch(conn, name)
     assert info.name == "testcomp"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
@@ -140,30 +140,10 @@ def test_fetch_info(conn, testcomp, name, fields):
         assert info.fields[i].type_oid == builtins[t].oid
 
 
-@pytest.mark.parametrize(
-    "name, fields",
-    [
-        (
-            "testcomp",
-            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
-        ),
-        (
-            "testschema.testcomp",
-            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
-        ),
-        (
-            Identifier("testcomp"),
-            [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
-        ),
-        (
-            Identifier("testschema", "testcomp"),
-            [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
-        ),
-    ],
-)
 @pytest.mark.asyncio
+@pytest.mark.parametrize("name, fields", fetch_cases)
 async def test_fetch_info_async(aconn, testcomp, name, fields):
-    info = await composite.fetch_info_async(aconn, name)
+    info = await CompositeInfo.fetch_async(aconn, name)
     assert info.name == "testcomp"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
@@ -190,8 +170,8 @@ def test_dump_composite_all_chars(conn, fmt_in, testcomp):
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_load_composite(conn, testcomp, fmt_out):
     cur = conn.cursor(format=fmt_out)
-    info = composite.fetch_info(conn, "testcomp")
-    composite.register(info, conn)
+    info = CompositeInfo.fetch(conn, "testcomp")
+    info.register(conn)
 
     res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
     assert res.foo == "hello"
@@ -210,13 +190,13 @@ def test_load_composite(conn, testcomp, fmt_out):
 @pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
 def test_load_composite_factory(conn, testcomp, fmt_out):
     cur = conn.cursor(format=fmt_out)
-    info = composite.fetch_info(conn, "testcomp")
+    info = CompositeInfo.fetch(conn, "testcomp")
 
     class MyThing:
         def __init__(self, *args):
             self.foo, self.bar, self.baz = args
 
-    composite.register(info, conn, factory=MyThing)
+    info.register(conn, factory=MyThing)
 
     res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
     assert isinstance(res, MyThing)
@@ -232,15 +212,14 @@ def test_load_composite_factory(conn, testcomp, fmt_out):
 
 
 def test_register_scope(conn):
-    info = composite.fetch_info(conn, "testcomp")
-
-    composite.register(info)
+    info = CompositeInfo.fetch(conn, "testcomp")
+    info.register()
     for fmt in (Format.TEXT, Format.BINARY):
         for oid in (info.oid, info.array_oid):
             assert Loader.globals.pop((oid, fmt))
 
     cur = conn.cursor()
-    composite.register(info, cur)
+    info.register(cur)
     for fmt in (Format.TEXT, Format.BINARY):
         for oid in (info.oid, info.array_oid):
             key = oid, fmt
@@ -248,7 +227,7 @@ def test_register_scope(conn):
             assert key not in conn.loaders
             assert key in cur.loaders
 
-    composite.register(info, conn)
+    info.register(conn)
     for fmt in (Format.TEXT, Format.BINARY):
         for oid in (info.oid, info.array_oid):
             key = oid, fmt