import re
import struct
from collections import namedtuple
-from typing import Any, Callable, Iterator, List, Sequence, Tuple, Type, Union
-from typing import Optional, TYPE_CHECKING
+from typing import Any, Callable, Iterator, List, NamedTuple, Optional
+from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
from .. import sql
from .. import errors as e
TEXT_OID = builtins["text"].oid
-class FieldInfo:
- def __init__(self, name: str, type_oid: int):
- self.name = name
- self.type_oid = type_oid
+class CompositeInfo(TypeInfo):
+ """Manage information about a composite type.
+ The class allows to:
+
+ - read information about a composite type using `fetch()` and `fetch_async()`
+ - configure a composite type adaptation using `register()`
+ """
-class CompositeTypeInfo(TypeInfo):
def __init__(
- self, name: str, oid: int, array_oid: int, fields: Sequence[FieldInfo]
+ self,
+ name: str,
+ oid: int,
+ array_oid: int,
+ fields: Sequence["CompositeInfo.FieldInfo"],
):
super().__init__(name, oid, array_oid)
self.fields = list(fields)
+ class FieldInfo(NamedTuple):
+ """Information about a single field in a composite type."""
+
+ name: str
+ type_oid: int
+
+ @classmethod
+ def fetch(
+ cls, conn: "Connection", name: Union[str, sql.Identifier]
+ ) -> Optional["CompositeInfo"]:
+ if isinstance(name, sql.Composable):
+ name = name.as_string(conn)
+ cur = conn.cursor(format=Format.BINARY)
+ cur.execute(cls._info_query, {"name": name})
+ recs = cur.fetchall()
+ return cls._from_records(recs)
+
+ @classmethod
+ async def fetch_async(
+ cls, conn: "AsyncConnection", name: Union[str, sql.Identifier]
+ ) -> Optional["CompositeInfo"]:
+ if isinstance(name, sql.Composable):
+ name = name.as_string(conn)
+ cur = await conn.cursor(format=Format.BINARY)
+ await cur.execute(cls._info_query, {"name": name})
+ recs = await cur.fetchall()
+ return cls._from_records(recs)
+
+ def register(
+ self,
+ context: AdaptContext = None,
+ factory: Optional[Callable[..., Any]] = None,
+ ) -> None:
+ if not factory:
+ factory = namedtuple( # type: ignore
+ self.name, [f.name for f in self.fields]
+ )
+
+ loader: Type[Loader]
+
+ # generate and register a customized text loader
+ loader = type(
+ f"{self.name.title()}Loader",
+ (CompositeLoader,),
+ {
+ "factory": factory,
+ "fields_types": tuple(f.type_oid for f in self.fields),
+ },
+ )
+ loader.register(self.oid, context=context, format=Format.TEXT)
+
+ # generate and register a customized binary loader
+ loader = type(
+ f"{self.name.title()}BinaryLoader",
+ (CompositeBinaryLoader,),
+ {"factory": factory},
+ )
+ loader.register(self.oid, context=context, format=Format.BINARY)
+
+ if self.array_oid:
+ array.register(
+ self.array_oid, self.oid, context=context, name=self.name
+ )
+
@classmethod
- def _from_records(cls, recs: List[Any]) -> Optional["CompositeTypeInfo"]:
+ def _from_records(cls, recs: List[Any]) -> Optional["CompositeInfo"]:
if not recs:
return None
if len(recs) > 1:
)
name, oid, array_oid, fnames, ftypes = recs[0]
- fields = [FieldInfo(*p) for p in zip(fnames, ftypes)]
- return CompositeTypeInfo(name, oid, array_oid, fields)
-
-
-def fetch_info(
- conn: "Connection", name: Union[str, sql.Identifier]
-) -> Optional[CompositeTypeInfo]:
- if isinstance(name, sql.Composable):
- name = name.as_string(conn)
- cur = conn.cursor(format=Format.BINARY)
- cur.execute(_type_info_query, {"name": name})
- recs = cur.fetchall()
- return CompositeTypeInfo._from_records(recs)
-
-
-async def fetch_info_async(
- conn: "AsyncConnection", name: Union[str, sql.Identifier]
-) -> Optional[CompositeTypeInfo]:
- if isinstance(name, sql.Composable):
- name = name.as_string(conn)
- cur = await conn.cursor(format=Format.BINARY)
- await cur.execute(_type_info_query, {"name": name})
- recs = await cur.fetchall()
- return CompositeTypeInfo._from_records(recs)
-
-
-def register(
- info: CompositeTypeInfo,
- context: AdaptContext = None,
- factory: Optional[Callable[..., Any]] = None,
-) -> None:
- if not factory:
- factory = namedtuple( # type: ignore
- info.name, [f.name for f in info.fields]
- )
-
- loader: Type[Loader]
-
- # generate and register a customized text loader
- loader = type(
- f"{info.name.title()}Loader",
- (CompositeLoader,),
- {
- "factory": factory,
- "fields_types": tuple(f.type_oid for f in info.fields),
- },
- )
- loader.register(info.oid, context=context, format=Format.TEXT)
-
- # generate and register a customized binary loader
- loader = type(
- f"{info.name.title()}BinaryLoader",
- (CompositeBinaryLoader,),
- {"factory": factory},
- )
- loader.register(info.oid, context=context, format=Format.BINARY)
-
- if info.array_oid:
- array.register(
- info.array_oid, info.oid, context=context, name=info.name
- )
-
+ fields = [cls.FieldInfo(*p) for p in zip(fnames, ftypes)]
+ return cls(name, oid, array_oid, fields)
-_type_info_query = """\
+ _info_query = """\
select
t.typname as name, t.oid as oid, t.typarray as array_oid,
coalesce(a.fnames, '{}') as fnames,
from psycopg3.sql import Identifier
from psycopg3.oids import builtins
from psycopg3.adapt import Format, Loader
-from psycopg3.types import composite
+from psycopg3.types.composite import CompositeInfo
tests_str = [
create type tmptype as ({', '.join(fields)});
"""
)
- info = composite.fetch_info(conn, "tmptype")
- composite.register(info, context=conn)
+ info = CompositeInfo.fetch(conn, "tmptype")
+ info.register(context=conn)
res = cur.execute("select %s::tmptype", [obj]).fetchone()[0]
assert res == obj
)
-@pytest.mark.parametrize(
- "name, fields",
- [
- (
- "testcomp",
- [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
- ),
- (
- "testschema.testcomp",
- [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
- ),
- (
- Identifier("testcomp"),
- [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
- ),
- (
- Identifier("testschema", "testcomp"),
- [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
- ),
- ],
-)
+fetch_cases = [
+ (
+ "testcomp",
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ "testschema.testcomp",
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+ (
+ Identifier("testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
+ ),
+ (
+ Identifier("testschema", "testcomp"),
+ [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
+ ),
+]
+
+
+@pytest.mark.parametrize("name, fields", fetch_cases)
def test_fetch_info(conn, testcomp, name, fields):
- info = composite.fetch_info(conn, name)
+ info = CompositeInfo.fetch(conn, name)
assert info.name == "testcomp"
assert info.oid > 0
assert info.oid != info.array_oid > 0
assert info.fields[i].type_oid == builtins[t].oid
-@pytest.mark.parametrize(
- "name, fields",
- [
- (
- "testcomp",
- [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
- ),
- (
- "testschema.testcomp",
- [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
- ),
- (
- Identifier("testcomp"),
- [("foo", "text"), ("bar", "int8"), ("baz", "float8")],
- ),
- (
- Identifier("testschema", "testcomp"),
- [("foo", "text"), ("bar", "int8"), ("qux", "bool")],
- ),
- ],
-)
@pytest.mark.asyncio
+@pytest.mark.parametrize("name, fields", fetch_cases)
async def test_fetch_info_async(aconn, testcomp, name, fields):
- info = await composite.fetch_info_async(aconn, name)
+ info = await CompositeInfo.fetch_async(aconn, name)
assert info.name == "testcomp"
assert info.oid > 0
assert info.oid != info.array_oid > 0
@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
def test_load_composite(conn, testcomp, fmt_out):
cur = conn.cursor(format=fmt_out)
- info = composite.fetch_info(conn, "testcomp")
- composite.register(info, conn)
+ info = CompositeInfo.fetch(conn, "testcomp")
+ info.register(conn)
res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
assert res.foo == "hello"
@pytest.mark.parametrize("fmt_out", [Format.TEXT, Format.BINARY])
def test_load_composite_factory(conn, testcomp, fmt_out):
cur = conn.cursor(format=fmt_out)
- info = composite.fetch_info(conn, "testcomp")
+ info = CompositeInfo.fetch(conn, "testcomp")
class MyThing:
def __init__(self, *args):
self.foo, self.bar, self.baz = args
- composite.register(info, conn, factory=MyThing)
+ info.register(conn, factory=MyThing)
res = cur.execute("select row('hello', 10, 20)::testcomp").fetchone()[0]
assert isinstance(res, MyThing)
def test_register_scope(conn):
- info = composite.fetch_info(conn, "testcomp")
-
- composite.register(info)
+ info = CompositeInfo.fetch(conn, "testcomp")
+ info.register()
for fmt in (Format.TEXT, Format.BINARY):
for oid in (info.oid, info.array_oid):
assert Loader.globals.pop((oid, fmt))
cur = conn.cursor()
- composite.register(info, cur)
+ info.register(cur)
for fmt in (Format.TEXT, Format.BINARY):
for oid in (info.oid, info.array_oid):
key = oid, fmt
assert key not in conn.loaders
assert key in cur.loaders
- composite.register(info, conn)
+ info.register(conn)
for fmt in (Format.TEXT, Format.BINARY):
for oid in (info.oid, info.array_oid):
key = oid, fmt