From: Daniele Varrazzo Date: Sat, 6 Feb 2021 03:14:43 +0000 (+0100) Subject: Register arrays using the TypeInfo object X-Git-Tag: 3.0.dev0~123 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=57860e0591103c971f65366538fd6adc27498de8;p=thirdparty%2Fpsycopg.git Register arrays using the TypeInfo object TypeInfo and subclasses implementations moved together to a dedicated module, sharing more uniform interface and implementation. --- diff --git a/psycopg3/psycopg3/adapt.py b/psycopg3/psycopg3/adapt.py index 0b4ffe43d..789593d3c 100644 --- a/psycopg3/psycopg3/adapt.py +++ b/psycopg3/psycopg3/adapt.py @@ -12,8 +12,9 @@ from . import pq from . import proto from . import errors as e from ._enums import Format as Format -from .oids import TypesRegistry, postgres_types +from .oids import postgres_types from .proto import AdaptContext, Buffer as Buffer +from .typeinfo import TypesRegistry if TYPE_CHECKING: from .connection import BaseConnection diff --git a/psycopg3/psycopg3/oids.py b/psycopg3/psycopg3/oids.py index 2942f3b56..48c2ee900 100644 --- a/psycopg3/psycopg3/oids.py +++ b/psycopg3/psycopg3/oids.py @@ -4,243 +4,88 @@ Maps of builtin types and names # Copyright (C) 2020-2021 The Psycopg Team -from typing import Dict, Iterator, Optional, Union - - -class TypeInfo: - def __init__( - self, name: str, oid: int, array_oid: int, range_subtype: int = 0 - ): - self.name = name - self.oid = oid - self.array_oid = array_oid - self.range_subtype = range_subtype - - def __repr__(self) -> str: - return ( - f"<{self.__class__.__qualname__}:" - f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>" - ) - - -class BuiltinTypeInfo(TypeInfo): - def __init__( - self, - name: str, - oid: int, - array_oid: int, - range_subtype: int, - alt_name: str, - delimiter: str, - ): - super().__init__(name, oid, array_oid, range_subtype) - self.alt_name = alt_name - self.delimiter = delimiter - - -class TypesRegistry: - """ - Container for the information about types in a database. - """ - - def __init__(self, template: Optional["TypesRegistry"] = None): - self._by_oid: Dict[int, TypeInfo] - self._by_name: Dict[str, TypeInfo] - self._by_range_subtype: Dict[int, TypeInfo] - - if template: - self._by_oid = template._by_oid - self._by_name = template._by_name - self._by_range_subtype = template._by_range_subtype - self._own_state = False - else: - self._by_oid = {} - self._by_name = {} - self._by_range_subtype = {} - self._own_state = True - - def add(self, info: TypeInfo) -> None: - self._ensure_own_state() - self._by_oid[info.oid] = info - if info.array_oid: - self._by_oid[info.array_oid] = info - self._by_name[info.name] = info - if info.range_subtype: - # Note: actually not unique (e.g. range of different collation?) - self._by_range_subtype[info.range_subtype] = info - - if isinstance(info, BuiltinTypeInfo): - if info.alt_name not in self._by_name: - self._by_name[info.alt_name] = info - - def __iter__(self) -> Iterator[TypeInfo]: - seen = set() - for t in self._by_oid.values(): - if t.oid not in seen: - seen.add(t.oid) - yield t - - def __getitem__(self, key: Union[str, int]) -> TypeInfo: - """ - Return info about a type, specified by name or oid - - The type name or oid may refer to the array too. - - Raise KeyError if not found. - """ - if isinstance(key, str): - if key.endswith("[]"): - key = key[:-2] - return self._by_name[key] - elif isinstance(key, int): - return self._by_oid[key] - else: - raise TypeError( - f"the key must be an oid or a name, got {type(key)}" - ) - - def get(self, key: Union[str, int]) -> Optional[TypeInfo]: - """ - Return info about a type, specified by name or oid - - The type name or oid may refer to the array too. - - Return None if not found. - """ - try: - return self[key] - except KeyError: - return None - - def get_oid(self, name: str) -> int: - """ - Return the oid of a PostgreSQL type by name. - - Return the array oid if the type ends with "[]" - - Raise KeyError if the name is unknown. - """ - t = self[name] - if name.endswith("[]"): - return t.array_oid - else: - return t.oid - - def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]: - """ - Return info about a range by its element name or oid - - Return None if the element or its range are not found. - """ - try: - info = self[key] - except KeyError: - return None - return self._by_range_subtype.get(info.oid) - - def _ensure_own_state(self) -> None: - # Time to write! so, copy. - if not self._own_state: - self._by_oid = self._by_oid.copy() - self._by_name = self._by_name.copy() - self._by_range_subtype = self._by_range_subtype.copy() - self._own_state = True - +from .typeinfo import TypeInfo, RangeInfo, TypesRegistry +# Global objects with PostgreSQL builtins and globally registered user types. postgres_types = TypesRegistry() + # Use tools/update_oids.py to update this data. -for r in [ +for t in [ # autogenerated: start - # Generated from PostgreSQL 13.1 - ("aclitem", 1033, 1034, 0, "aclitem", ","), - ("any", 2276, 0, 0, '"any"', ","), - ("anyarray", 2277, 0, 0, "anyarray", ","), - ("anycompatible", 5077, 0, 0, "anycompatible", ","), - ("anycompatiblearray", 5078, 0, 0, "anycompatiblearray", ","), - ("anycompatiblenonarray", 5079, 0, 0, "anycompatiblenonarray", ","), - ("anycompatiblerange", 5080, 0, 0, "anycompatiblerange", ","), - ("anyelement", 2283, 0, 0, "anyelement", ","), - ("anyenum", 3500, 0, 0, "anyenum", ","), - ("anynonarray", 2776, 0, 0, "anynonarray", ","), - ("anyrange", 3831, 0, 0, "anyrange", ","), - ("bit", 1560, 1561, 0, "bit", ","), - ("bool", 16, 1000, 0, "boolean", ","), - ("box", 603, 1020, 0, "box", ";"), - ("bpchar", 1042, 1014, 0, "character", ","), - ("bytea", 17, 1001, 0, "bytea", ","), - ("char", 18, 1002, 0, '"char"', ","), - ("cid", 29, 1012, 0, "cid", ","), - ("cidr", 650, 651, 0, "cidr", ","), - ("circle", 718, 719, 0, "circle", ","), - ("cstring", 2275, 1263, 0, "cstring", ","), - ("date", 1082, 1182, 0, "date", ","), - ("daterange", 3912, 3913, 1082, "daterange", ","), - ("event_trigger", 3838, 0, 0, "event_trigger", ","), - ("float4", 700, 1021, 0, "real", ","), - ("float8", 701, 1022, 0, "double precision", ","), - ("gtsvector", 3642, 3644, 0, "gtsvector", ","), - ("inet", 869, 1041, 0, "inet", ","), - ("int2", 21, 1005, 0, "smallint", ","), - ("int2vector", 22, 1006, 0, "int2vector", ","), - ("int4", 23, 1007, 0, "integer", ","), - ("int4range", 3904, 3905, 23, "int4range", ","), - ("int8", 20, 1016, 0, "bigint", ","), - ("int8range", 3926, 3927, 20, "int8range", ","), - ("internal", 2281, 0, 0, "internal", ","), - ("interval", 1186, 1187, 0, "interval", ","), - ("json", 114, 199, 0, "json", ","), - ("jsonb", 3802, 3807, 0, "jsonb", ","), - ("jsonpath", 4072, 4073, 0, "jsonpath", ","), - ("line", 628, 629, 0, "line", ","), - ("lseg", 601, 1018, 0, "lseg", ","), - ("macaddr", 829, 1040, 0, "macaddr", ","), - ("macaddr8", 774, 775, 0, "macaddr8", ","), - ("money", 790, 791, 0, "money", ","), - ("name", 19, 1003, 0, "name", ","), - ("numeric", 1700, 1231, 0, "numeric", ","), - ("numrange", 3906, 3907, 1700, "numrange", ","), - ("oid", 26, 1028, 0, "oid", ","), - ("oidvector", 30, 1013, 0, "oidvector", ","), - ("path", 602, 1019, 0, "path", ","), - ("point", 600, 1017, 0, "point", ","), - ("polygon", 604, 1027, 0, "polygon", ","), - ("record", 2249, 2287, 0, "record", ","), - ("refcursor", 1790, 2201, 0, "refcursor", ","), - ("regclass", 2205, 2210, 0, "regclass", ","), - ("regcollation", 4191, 4192, 0, "regcollation", ","), - ("regconfig", 3734, 3735, 0, "regconfig", ","), - ("regdictionary", 3769, 3770, 0, "regdictionary", ","), - ("regnamespace", 4089, 4090, 0, "regnamespace", ","), - ("regoper", 2203, 2208, 0, "regoper", ","), - ("regoperator", 2204, 2209, 0, "regoperator", ","), - ("regproc", 24, 1008, 0, "regproc", ","), - ("regprocedure", 2202, 2207, 0, "regprocedure", ","), - ("regrole", 4096, 4097, 0, "regrole", ","), - ("regtype", 2206, 2211, 0, "regtype", ","), - ("text", 25, 1009, 0, "text", ","), - ("tid", 27, 1010, 0, "tid", ","), - ("time", 1083, 1183, 0, "time without time zone", ","), - ("timestamp", 1114, 1115, 0, "timestamp without time zone", ","), - ("timestamptz", 1184, 1185, 0, "timestamp with time zone", ","), - ("timetz", 1266, 1270, 0, "time with time zone", ","), - ("trigger", 2279, 0, 0, "trigger", ","), - ("tsquery", 3615, 3645, 0, "tsquery", ","), - ("tsrange", 3908, 3909, 1114, "tsrange", ","), - ("tstzrange", 3910, 3911, 1184, "tstzrange", ","), - ("tsvector", 3614, 3643, 0, "tsvector", ","), - ("txid_snapshot", 2970, 2949, 0, "txid_snapshot", ","), - ("unknown", 705, 0, 0, "unknown", ","), - ("uuid", 2950, 2951, 0, "uuid", ","), - ("varbit", 1562, 1563, 0, "bit varying", ","), - ("varchar", 1043, 1015, 0, "character varying", ","), - ("void", 2278, 0, 0, "void", ","), - ("xid", 28, 1011, 0, "xid", ","), - ("xid8", 5069, 271, 0, "xid8", ","), - ("xml", 142, 143, 0, "xml", ","), + # Generated from PostgreSQL 13.0 + TypeInfo("aclitem", 1033, 1034), + TypeInfo("bit", 1560, 1561), + TypeInfo("bool", 16, 1000, alt_name="boolean"), + TypeInfo("box", 603, 1020, delimiter=";"), + TypeInfo("bpchar", 1042, 1014, alt_name="character"), + TypeInfo("bytea", 17, 1001), + TypeInfo("char", 18, 1002, alt_name='"char"'), + TypeInfo("cid", 29, 1012), + TypeInfo("cidr", 650, 651), + TypeInfo("circle", 718, 719), + TypeInfo("date", 1082, 1182), + TypeInfo("float4", 700, 1021, alt_name="real"), + TypeInfo("float8", 701, 1022, alt_name="double precision"), + TypeInfo("gtsvector", 3642, 3644), + TypeInfo("inet", 869, 1041), + TypeInfo("int2", 21, 1005, alt_name="smallint"), + TypeInfo("int2vector", 22, 1006), + TypeInfo("int4", 23, 1007, alt_name="integer"), + TypeInfo("int8", 20, 1016, alt_name="bigint"), + TypeInfo("interval", 1186, 1187), + TypeInfo("json", 114, 199), + TypeInfo("jsonb", 3802, 3807), + TypeInfo("jsonpath", 4072, 4073), + TypeInfo("line", 628, 629), + TypeInfo("lseg", 601, 1018), + TypeInfo("macaddr", 829, 1040), + TypeInfo("macaddr8", 774, 775), + TypeInfo("money", 790, 791), + TypeInfo("name", 19, 1003), + TypeInfo("numeric", 1700, 1231), + TypeInfo("oid", 26, 1028), + TypeInfo("oidvector", 30, 1013), + TypeInfo("path", 602, 1019), + TypeInfo("point", 600, 1017), + TypeInfo("polygon", 604, 1027), + TypeInfo("record", 2249, 2287), + TypeInfo("refcursor", 1790, 2201), + TypeInfo("regclass", 2205, 2210), + TypeInfo("regcollation", 4191, 4192), + TypeInfo("regconfig", 3734, 3735), + TypeInfo("regdictionary", 3769, 3770), + TypeInfo("regnamespace", 4089, 4090), + TypeInfo("regoper", 2203, 2208), + TypeInfo("regoperator", 2204, 2209), + TypeInfo("regproc", 24, 1008), + TypeInfo("regprocedure", 2202, 2207), + TypeInfo("regrole", 4096, 4097), + TypeInfo("regtype", 2206, 2211), + TypeInfo("text", 25, 1009), + TypeInfo("tid", 27, 1010), + TypeInfo("time", 1083, 1183, alt_name="time without time zone"), + TypeInfo("timestamp", 1114, 1115, alt_name="timestamp without time zone"), + TypeInfo("timestamptz", 1184, 1185, alt_name="timestamp with time zone"), + TypeInfo("timetz", 1266, 1270, alt_name="time with time zone"), + TypeInfo("tsquery", 3615, 3645), + TypeInfo("tsvector", 3614, 3643), + TypeInfo("txid_snapshot", 2970, 2949), + TypeInfo("uuid", 2950, 2951), + TypeInfo("varbit", 1562, 1563, alt_name="bit varying"), + TypeInfo("varchar", 1043, 1015, alt_name="character varying"), + TypeInfo("xid", 28, 1011), + TypeInfo("xid8", 5069, 271), + TypeInfo("xml", 142, 143), + RangeInfo("daterange", 3912, 3913, subtype_oid=1082), + RangeInfo("int4range", 3904, 3905, subtype_oid=23), + RangeInfo("int8range", 3926, 3927, subtype_oid=20), + RangeInfo("numrange", 3906, 3907, subtype_oid=1700), + RangeInfo("tsrange", 3908, 3909, subtype_oid=1114), + RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184), # autogenerated: end ]: - postgres_types.add(BuiltinTypeInfo(*r)) + postgres_types.add(t) # A few oids used a bit everywhere diff --git a/psycopg3/psycopg3/typeinfo.py b/psycopg3/psycopg3/typeinfo.py new file mode 100644 index 000000000..67f0bd7bb --- /dev/null +++ b/psycopg3/psycopg3/typeinfo.py @@ -0,0 +1,315 @@ +""" +Information about PostgreSQL types + +These types allow to read information from the system catalog and provide +information to the adapters if needed. +""" + +# Copyright (C) 2020-2021 The Psycopg Team + +from typing import Any, Callable, Dict, Iterator, Optional +from typing import Sequence, Type, TypeVar, Union, TYPE_CHECKING + +from . import errors as e +from .proto import AdaptContext + +if TYPE_CHECKING: + from .connection import Connection, AsyncConnection + from .sql import Identifier + +T = TypeVar("T", bound="TypeInfo") + + +class TypeInfo: + """ + Hold information about a PostgreSQL base type. + + The class allows to: + + - read information about a range type using `fetch()` and `fetch_async()` + - configure a composite type adaptation using `register()` + """ + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + alt_name: str = "", + delimiter: str = ",", + ): + self.name = name + self.oid = oid + self.array_oid = array_oid + self.alt_name = alt_name + self.delimiter = delimiter + + def __repr__(self) -> str: + return ( + f"<{self.__class__.__qualname__}:" + f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>" + ) + + @classmethod + def fetch( + cls: Type[T], conn: "Connection", name: Union[str, "Identifier"] + ) -> Optional[T]: + from .sql import Composable + + if isinstance(name, Composable): + name = name.as_string(conn) + cur = conn.cursor(binary=True) + cur.execute(cls._info_query, {"name": name}) + recs = cur.fetchall() + fields = [d[0] for d in cur.description or ()] + return cls._fetch(name, fields, recs) + + @classmethod + async def fetch_async( + cls: Type[T], conn: "AsyncConnection", name: Union[str, "Identifier"] + ) -> Optional[T]: + from .sql import Composable + + if isinstance(name, Composable): + name = name.as_string(conn) + cur = await conn.cursor(binary=True) + await cur.execute(cls._info_query, {"name": name}) + recs = await cur.fetchall() + fields = [d[0] for d in cur.description or ()] + return cls._fetch(name, fields, recs) + + @classmethod + def _fetch( + cls: Type[T], + name: str, + fields: Sequence[str], + recs: Sequence[Sequence[Any]], + ) -> Optional[T]: + if len(recs) == 1: + return cls(**dict(zip(fields, recs[0]))) + elif not recs: + return None + else: + raise e.ProgrammingError( + f"found {len(recs)} different types named {name}" + ) + + def register( + self, + context: Optional["AdaptContext"] = None, + ) -> None: + + if context: + types = context.adapters.types + else: + from .oids import postgres_types + + types = postgres_types + + types.add(self) + + if self.array_oid: + from .types.array import register_adapters + + register_adapters(self, context) + + _info_query = """\ +select + typname as name, oid, typarray as array_oid, + oid::regtype as alt_name, typdelim as delimiter +from pg_type t +where t.oid = %(name)s::regtype +order by t.oid +""" + + +class RangeInfo(TypeInfo): + """Manage information about a range type.""" + + def __init__(self, name: str, oid: int, array_oid: int, subtype_oid: int): + super().__init__(name, oid, array_oid) + self.subtype_oid = subtype_oid + + def register( + self, + context: Optional[AdaptContext] = None, + ) -> None: + super().register(context) + + from .types.range import register_adapters + + register_adapters(self, context) + + _info_query = """\ +select t.typname as name, t.oid as oid, t.typarray as array_oid, + r.rngsubtype as subtype_oid +from pg_type t +join pg_range r on t.oid = r.rngtypid +where t.oid = %(name)s::regtype +""" + + +class CompositeInfo(TypeInfo): + """Manage information about a composite type.""" + + def __init__( + self, + name: str, + oid: int, + array_oid: int, + field_names: Sequence[str], + field_types: Sequence[int], + ): + super().__init__(name, oid, array_oid) + self.field_names = field_names + self.field_types = field_types + + def register( + self, + context: Optional[AdaptContext] = None, + factory: Optional[Callable[..., Any]] = None, + ) -> None: + super().register(context) + + from .types.composite import register_adapters + + register_adapters(self, context, factory) + + _info_query = """\ +select + t.typname as name, t.oid as oid, t.typarray as array_oid, + coalesce(a.fnames, '{}') as field_names, + coalesce(a.ftypes, '{}') as field_types +from pg_type t +left join ( + select + attrelid, + array_agg(attname) as fnames, + array_agg(atttypid) as ftypes + from ( + select a.attrelid, a.attname, a.atttypid + from pg_attribute a + join pg_type t on t.typrelid = a.attrelid + where t.oid = %(name)s::regtype + and a.attnum > 0 + and not a.attisdropped + order by a.attnum + ) x + group by attrelid +) a on a.attrelid = t.typrelid +where t.oid = %(name)s::regtype +""" + + +class TypesRegistry: + """ + Container for the information about types in a database. + """ + + def __init__(self, template: Optional["TypesRegistry"] = None): + self._by_oid: Dict[int, TypeInfo] + self._by_name: Dict[str, TypeInfo] + self._by_range_subtype: Dict[int, TypeInfo] + + # Make a shallow copy: it will become a proper copy if the registry + # is edited (note the BUG: a child will get shallow-copied, but changing + # the parent will change children who weren't copied yet. It can be + # probably fixed by setting _own_state to False on the parent on copy, + # but needs testing and for the moment I'll leave it there TODO). + if template: + self._by_oid = template._by_oid + self._by_name = template._by_name + self._by_range_subtype = template._by_range_subtype + self._own_state = False + else: + self._by_oid = {} + self._by_name = {} + self._by_range_subtype = {} + self._own_state = True + + def add(self, info: TypeInfo) -> None: + self._ensure_own_state() + self._by_oid[info.oid] = info + if info.array_oid: + self._by_oid[info.array_oid] = info + self._by_name[info.name] = info + + if info.alt_name and info.alt_name not in self._by_name: + self._by_name[info.alt_name] = info + + # Map ranges subtypes to info + if isinstance(info, RangeInfo): + self._by_range_subtype[info.subtype_oid] = info + + def __iter__(self) -> Iterator[TypeInfo]: + seen = set() + for t in self._by_oid.values(): + if t.oid not in seen: + seen.add(t.oid) + yield t + + def __getitem__(self, key: Union[str, int]) -> TypeInfo: + """ + Return info about a type, specified by name or oid + + The type name or oid may refer to the array too. + + Raise KeyError if not found. + """ + if isinstance(key, str): + if key.endswith("[]"): + key = key[:-2] + return self._by_name[key] + elif isinstance(key, int): + return self._by_oid[key] + else: + raise TypeError( + f"the key must be an oid or a name, got {type(key)}" + ) + + def get(self, key: Union[str, int]) -> Optional[TypeInfo]: + """ + Return info about a type, specified by name or oid + + The type name or oid may refer to the array too. + + Return None if not found. + """ + try: + return self[key] + except KeyError: + return None + + def get_oid(self, name: str) -> int: + """ + Return the oid of a PostgreSQL type by name. + + Return the array oid if the type ends with "[]" + + Raise KeyError if the name is unknown. + """ + t = self[name] + if name.endswith("[]"): + return t.array_oid + else: + return t.oid + + def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]: + """ + Return info about a range by its element name or oid + + Return None if the element or its range are not found. + """ + try: + info = self[key] + except KeyError: + return None + return self._by_range_subtype.get(info.oid) + + def _ensure_own_state(self) -> None: + # Time to write! so, copy. + if not self._own_state: + self._by_oid = self._by_oid.copy() + self._by_name = self._by_name.copy() + self._by_range_subtype = self._by_range_subtype.copy() + self._own_state = True diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index dc70082b9..2b7bda463 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -8,17 +8,15 @@ from ..oids import INVALID_OID from ..proto import AdaptContext # Register default adapters -from . import array, composite -from . import range +from . import array, composite, range # Wrapper objects from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric, Oid from .json import Json, Jsonb from .range import Range -# Supper objects -from .range import RangeInfo -from .composite import CompositeInfo +# Database types descriptors +from ..typeinfo import TypeInfo, RangeInfo, CompositeInfo # Adapter objects from .text import ( diff --git a/psycopg3/psycopg3/types/array.py b/psycopg3/psycopg3/types/array.py index b68bf77c3..b282d8583 100644 --- a/psycopg3/psycopg3/types/array.py +++ b/psycopg3/psycopg3/types/array.py @@ -14,6 +14,7 @@ from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID from ..adapt import Buffer, Dumper, Loader, Transformer from ..adapt import Format as Pg3Format from ..proto import AdaptContext +from ..typeinfo import TypeInfo class BaseListDumper(Dumper): @@ -315,20 +316,15 @@ class ArrayBinaryLoader(BaseArrayLoader): return agg(dims) -def register( - array_oid: int, - base_oid: int, - context: Optional[AdaptContext] = None, - name: Optional[str] = None, +def register_adapters( + info: TypeInfo, context: Optional["AdaptContext"] ) -> None: - if not name: - name = f"oid{base_oid}" - for base in (ArrayLoader, ArrayBinaryLoader): - fmt = "Binary" if base.format == pq.Format.BINARY else "" - lname = f"{name.title()}Array{fmt}Loader" - loader: Type[Loader] = type(lname, (base,), {"base_oid": base_oid}) - loader.register(array_oid, context=context) + lname = f"{info.name.title()}{base.__name__}" + loader: Type[BaseArrayLoader] = type( + lname, (base,), {"base_oid": info.oid} + ) + loader.register(info.array_oid, context=context) def register_all_arrays(ctx: AdaptContext) -> None: @@ -341,4 +337,4 @@ def register_all_arrays(ctx: AdaptContext) -> None: for t in ctx.adapters.types: # TODO: handle different delimiters (box) if t.array_oid and getattr(t, "delimiter", None) == ",": - register(t.array_oid, t.oid, name=t.name) + t.register(ctx) diff --git a/psycopg3/psycopg3/types/composite.py b/psycopg3/psycopg3/types/composite.py index c43a68b1c..f91f29b1b 100644 --- a/psycopg3/psycopg3/types/composite.py +++ b/psycopg3/psycopg3/types/composite.py @@ -7,141 +7,14 @@ Support for composite types adaptation. import re import struct from collections import namedtuple -from typing import Any, Callable, Iterator, List, NamedTuple, Optional -from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING +from typing import Any, Callable, Iterator, List, Optional +from typing import Sequence, Tuple, Type from .. import pq -from .. import sql -from .. import errors as e -from ..oids import TypeInfo, TEXT_OID +from ..oids import TEXT_OID from ..adapt import Buffer, Format, Dumper, Loader, Transformer from ..proto import AdaptContext -from . import array - -if TYPE_CHECKING: - from ..connection import Connection, AsyncConnection - - -class CompositeInfo(TypeInfo): - """Manage information about a composite type. - - The class allows to: - - - read information about a composite type using `fetch()` and `fetch_async()` - - configure a composite type adaptation using `register()` - """ - - def __init__( - self, - name: str, - oid: int, - array_oid: int, - fields: Sequence["CompositeInfo.FieldInfo"], - ): - super().__init__(name, oid, array_oid) - self.fields = list(fields) - - class FieldInfo(NamedTuple): - """Information about a single field in a composite type.""" - - name: str - type_oid: int - - @classmethod - def fetch( - cls, conn: "Connection", name: Union[str, sql.Identifier] - ) -> Optional["CompositeInfo"]: - if isinstance(name, sql.Composable): - name = name.as_string(conn) - cur = conn.cursor() - cur.execute(cls._info_query, {"name": name}) - recs = cur.fetchall() - return cls._from_records(recs) - - @classmethod - async def fetch_async( - cls, conn: "AsyncConnection", name: Union[str, sql.Identifier] - ) -> Optional["CompositeInfo"]: - if isinstance(name, sql.Composable): - name = name.as_string(conn) - cur = await conn.cursor() - await cur.execute(cls._info_query, {"name": name}) - recs = await cur.fetchall() - return cls._from_records(recs) - - def register( - self, - context: Optional[AdaptContext] = None, - factory: Optional[Callable[..., Any]] = None, - ) -> None: - if not factory: - factory = namedtuple( # type: ignore - self.name, [f.name for f in self.fields] - ) - - loader: Type[Loader] - - # generate and register a customized text loader - loader = type( - f"{self.name.title()}Loader", - (CompositeLoader,), - { - "factory": factory, - "fields_types": [f.type_oid for f in self.fields], - }, - ) - loader.register(self.oid, context=context) - - # generate and register a customized binary loader - loader = type( - f"{self.name.title()}BinaryLoader", - (CompositeBinaryLoader,), - {"factory": factory}, - ) - loader.register(self.oid, context=context) - - if self.array_oid: - array.register( - self.array_oid, self.oid, context=context, name=self.name - ) - - @classmethod - def _from_records(cls, recs: Sequence[Any]) -> Optional["CompositeInfo"]: - if not recs: - return None - if len(recs) > 1: - raise e.ProgrammingError( - f"found {len(recs)} different types named {recs[0][0]}" - ) - - name, oid, array_oid, fnames, ftypes = recs[0] - fields = [cls.FieldInfo(*p) for p in zip(fnames, ftypes)] - return cls(name, oid, array_oid, fields) - - _info_query = """\ -select - t.typname as name, t.oid as oid, t.typarray as array_oid, - coalesce(a.fnames, '{}') as fnames, - coalesce(a.ftypes, '{}') as ftypes -from pg_type t -left join ( - select - attrelid, - array_agg(attname) as fnames, - array_agg(atttypid) as ftypes - from ( - select a.attrelid, a.attname, a.atttypid - from pg_attribute a - join pg_type t on t.typrelid = a.attrelid - where t.oid = %(name)s::regtype - and a.attnum > 0 - and not a.attisdropped - order by a.attnum - ) x - group by attrelid -) a on a.attrelid = t.typrelid -where t.oid = %(name)s::regtype -""" +from ..typeinfo import CompositeInfo class SequenceDumper(Dumper): @@ -317,3 +190,31 @@ class CompositeBinaryLoader(RecordBinaryLoader): def load(self, data: Buffer) -> Any: r = super().load(data) return type(self).factory(*r) + + +def register_adapters( + info: CompositeInfo, + context: Optional["AdaptContext"], + factory: Optional[Callable[..., Any]] = None, +) -> None: + if not factory: + factory = namedtuple(info.name, info.field_names) # type: ignore + + # generate and register a customized text loader + loader: Type[BaseCompositeLoader] = type( + f"{info.name.title()}Loader", + (CompositeLoader,), + { + "factory": factory, + "fields_types": info.field_types, + }, + ) + loader.register(info.oid, context=context) + + # generate and register a customized binary loader + loader = type( + f"{info.name.title()}BinaryLoader", + (CompositeBinaryLoader,), + {"factory": factory}, + ) + loader.register(info.oid, context=context) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index 69ac4ee44..43b9ccbd6 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -5,24 +5,18 @@ Support for range types adaptation. # Copyright (C) 2020-2021 The Psycopg Team import re -from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Type, Union -from typing import cast, Tuple, TYPE_CHECKING +from typing import cast, Any, Dict, Generic, Optional, Tuple, TypeVar, Type from decimal import Decimal from datetime import date, datetime -from .. import sql -from .. import errors as e from ..pq import Format -from ..oids import postgres_types as builtins, TypeInfo, INVALID_OID -from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format +from ..oids import postgres_types as builtins, INVALID_OID +from ..adapt import Buffer, Dumper, Format as Pg3Format from ..proto import AdaptContext +from ..typeinfo import RangeInfo -from . import array from .composite import SequenceDumper, BaseCompositeLoader -if TYPE_CHECKING: - from ..connection import Connection, AsyncConnection - T = TypeVar("T") @@ -312,6 +306,18 @@ class RangeLoader(BaseCompositeLoader, Generic[T]): _int2parens = {ord(c): c for c in "[]()"} +def register_adapters( + info: RangeInfo, context: Optional["AdaptContext"] +) -> None: + # generate and register a customized text loader + loader: Type[RangeLoader[Any]] = type( + f"{info.name.title()}Loader", + (RangeLoader,), + {"subtype_oid": info.subtype_oid}, + ) + loader.register(info.oid, context=context) + + # Loaders for builtin range types @@ -337,76 +343,3 @@ class TimestampRangeLoader(RangeLoader[datetime]): class TimestampTZRangeLoader(RangeLoader[datetime]): subtype_oid = builtins["timestamptz"].oid - - -class RangeInfo(TypeInfo): - """Manage information about a range type. - - The class allows to: - - - read information about a range type using `fetch()` and `fetch_async()` - - configure a composite type adaptation using `register()` - """ - - @classmethod - def fetch( - cls, conn: "Connection", name: Union[str, sql.Identifier] - ) -> Optional["RangeInfo"]: - if isinstance(name, sql.Composable): - name = name.as_string(conn) - cur = conn.cursor(binary=True) - cur.execute(cls._info_query, {"name": name}) - recs = cur.fetchall() - return cls._from_records(recs) - - @classmethod - async def fetch_async( - cls, conn: "AsyncConnection", name: Union[str, sql.Identifier] - ) -> Optional["RangeInfo"]: - if isinstance(name, sql.Composable): - name = name.as_string(conn) - cur = await conn.cursor(binary=True) - await cur.execute(cls._info_query, {"name": name}) - recs = await cur.fetchall() - return cls._from_records(recs) - - def register( - self, - context: Optional[AdaptContext] = None, - ) -> None: - - if context: - context.adapters.types.add(self) - - # generate and register a customized text loader - loader: Type[Loader] = type( - f"{self.name.title()}Loader", - (RangeLoader,), - {"subtype_oid": self.range_subtype}, - ) - loader.register(self.oid, context=context) - - if self.array_oid: - array.register( - self.array_oid, self.oid, context=context, name=self.name - ) - - @classmethod - def _from_records(cls, recs: Sequence[Any]) -> Optional["RangeInfo"]: - if not recs: - return None - if len(recs) > 1: - raise e.ProgrammingError( - f"found {len(recs)} different ranges named {recs[0][0]}" - ) - - name, oid, array_oid, subtype = recs[0] - return cls(name, oid, array_oid, subtype) - - _info_query = """\ -select t.typname as name, t.oid as oid, t.typarray as array_oid, - r.rngsubtype as range_subtype -from pg_type t -join pg_range r on t.oid = r.rngtypid -where t.oid = %(name)s::regtype -""" diff --git a/psycopg3_c/psycopg3_c/_psycopg3/oids.pxd b/psycopg3_c/psycopg3_c/_psycopg3/oids.pxd index 7e82198e6..0249bb35b 100644 --- a/psycopg3_c/psycopg3_c/_psycopg3/oids.pxd +++ b/psycopg3_c/psycopg3_c/_psycopg3/oids.pxd @@ -11,19 +11,9 @@ cdef enum: # autogenerated: start - # Generated from PostgreSQL 13.1 + # Generated from PostgreSQL 13.0 ACLITEM_OID = 1033 - ANY_OID = 2276 - ANYARRAY_OID = 2277 - ANYCOMPATIBLE_OID = 5077 - ANYCOMPATIBLEARRAY_OID = 5078 - ANYCOMPATIBLENONARRAY_OID = 5079 - ANYCOMPATIBLERANGE_OID = 5080 - ANYELEMENT_OID = 2283 - ANYENUM_OID = 3500 - ANYNONARRAY_OID = 2776 - ANYRANGE_OID = 3831 BIT_OID = 1560 BOOL_OID = 16 BOX_OID = 603 @@ -33,10 +23,8 @@ cdef enum: CID_OID = 29 CIDR_OID = 650 CIRCLE_OID = 718 - CSTRING_OID = 2275 DATE_OID = 1082 DATERANGE_OID = 3912 - EVENT_TRIGGER_OID = 3838 FLOAT4_OID = 700 FLOAT8_OID = 701 GTSVECTOR_OID = 3642 @@ -47,7 +35,6 @@ cdef enum: INT4RANGE_OID = 3904 INT8_OID = 20 INT8RANGE_OID = 3926 - INTERNAL_OID = 2281 INTERVAL_OID = 1186 JSON_OID = 114 JSONB_OID = 3802 @@ -84,17 +71,14 @@ cdef enum: TIMESTAMP_OID = 1114 TIMESTAMPTZ_OID = 1184 TIMETZ_OID = 1266 - TRIGGER_OID = 2279 TSQUERY_OID = 3615 TSRANGE_OID = 3908 TSTZRANGE_OID = 3910 TSVECTOR_OID = 3614 TXID_SNAPSHOT_OID = 2970 - UNKNOWN_OID = 705 UUID_OID = 2950 VARBIT_OID = 1562 VARCHAR_OID = 1043 - VOID_OID = 2278 XID_OID = 28 XID8_OID = 5069 XML_OID = 142 diff --git a/tests/types/test_array.py b/tests/types/test_array.py index 220876831..ea1805df1 100644 --- a/tests/types/test_array.py +++ b/tests/types/test_array.py @@ -4,7 +4,7 @@ from psycopg3 import pq from psycopg3 import sql from psycopg3.oids import postgres_types as builtins from psycopg3.adapt import Format, Transformer -from psycopg3.types import array +from psycopg3.types import TypeInfo tests_str = [ @@ -118,9 +118,9 @@ def test_array_register(conn): assert res[0] == "(foo)" assert res[1] == "{(foo)}" - array.register( - cur.description[1].type_code, cur.description[0].type_code, context=cur - ) + info = TypeInfo.fetch(conn, "mytype") + info.register(cur) + cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[] -- 2""") res = cur.fetchone() assert res[0] == "(foo)" diff --git a/tests/types/test_composite.py b/tests/types/test_composite.py index 32d2ad75a..078e4f3b4 100644 --- a/tests/types/test_composite.py +++ b/tests/types/test_composite.py @@ -4,7 +4,7 @@ from psycopg3 import pq from psycopg3.sql import Identifier from psycopg3.oids import postgres_types as builtins from psycopg3.adapt import Format, global_adapters -from psycopg3.types.composite import CompositeInfo +from psycopg3.types import CompositeInfo tests_str = [ @@ -135,10 +135,11 @@ def test_fetch_info(conn, testcomp, name, fields): assert info.name == "testcomp" assert info.oid > 0 assert info.oid != info.array_oid > 0 - assert len(info.fields) == 3 + assert len(info.field_names) == 3 + assert len(info.field_types) == 3 for i, (name, t) in enumerate(fields): - assert info.fields[i].name == name - assert info.fields[i].type_oid == builtins[t].oid + assert info.field_names[i] == name + assert info.field_types[i] == builtins[t].oid @pytest.mark.asyncio @@ -148,10 +149,11 @@ async def test_fetch_info_async(aconn, testcomp, name, fields): assert info.name == "testcomp" assert info.oid > 0 assert info.oid != info.array_oid > 0 - assert len(info.fields) == 3 + assert len(info.field_names) == 3 + assert len(info.field_types) == 3 for i, (name, t) in enumerate(fields): - assert info.fields[i].name == name - assert info.fields[i].type_oid == builtins[t].oid + assert info.field_names[i] == name + assert info.field_types[i] == builtins[t].oid @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY]) diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 0ccbc9545..f170cbb9e 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -5,8 +5,7 @@ from decimal import Decimal import pytest from psycopg3.sql import Identifier -from psycopg3.types import range as mrange -from psycopg3.types.range import Range +from psycopg3.types import Range, RangeInfo type2sub = { @@ -159,36 +158,36 @@ fetch_cases = [ @pytest.mark.parametrize("name, subtype", fetch_cases) def test_fetch_info(conn, testrange, name, subtype): - info = mrange.RangeInfo.fetch(conn, name) + info = RangeInfo.fetch(conn, name) assert info.name == "testrange" assert info.oid > 0 assert info.oid != info.array_oid > 0 - assert info.range_subtype == conn.adapters.types[subtype].oid + assert info.subtype_oid == conn.adapters.types[subtype].oid def test_fetch_info_not_found(conn): with pytest.raises(conn.ProgrammingError): - mrange.RangeInfo.fetch(conn, "nosuchrange") + RangeInfo.fetch(conn, "nosuchrange") @pytest.mark.asyncio @pytest.mark.parametrize("name, subtype", fetch_cases) async def test_fetch_info_async(aconn, testrange, name, subtype): - info = await mrange.RangeInfo.fetch_async(aconn, name) + info = await RangeInfo.fetch_async(aconn, name) assert info.name == "testrange" assert info.oid > 0 assert info.oid != info.array_oid > 0 - assert info.range_subtype == aconn.adapters.types[subtype].oid + assert info.subtype_oid == aconn.adapters.types[subtype].oid @pytest.mark.asyncio async def test_fetch_info_not_found_async(aconn): with pytest.raises(aconn.ProgrammingError): - await mrange.RangeInfo.fetch_async(aconn, "nosuchrange") + await RangeInfo.fetch_async(aconn, "nosuchrange") def test_dump_custom_empty(conn, testrange): - info = mrange.RangeInfo.fetch(conn, "testrange") + info = RangeInfo.fetch(conn, "testrange") info.register(conn) r = Range(empty=True) @@ -197,7 +196,7 @@ def test_dump_custom_empty(conn, testrange): def test_dump_quoting(conn, testrange): - info = mrange.RangeInfo.fetch(conn, "testrange") + info = RangeInfo.fetch(conn, "testrange") info.register(conn) cur = conn.cursor() for i in range(1, 254): @@ -212,16 +211,16 @@ def test_dump_quoting(conn, testrange): def test_load_custom_empty(conn, testrange): - info = mrange.RangeInfo.fetch(conn, "testrange") + info = RangeInfo.fetch(conn, "testrange") info.register(conn) (got,) = conn.execute("select 'empty'::testrange").fetchone() - assert isinstance(got, mrange.Range) + assert isinstance(got, Range) assert got.isempty def test_load_quoting(conn, testrange): - info = mrange.RangeInfo.fetch(conn, "testrange") + info = RangeInfo.fetch(conn, "testrange") info.register(conn) cur = conn.cursor() for i in range(1, 254): @@ -230,7 +229,7 @@ def test_load_quoting(conn, testrange): {"low": i, "up": i + 1}, ) got = cur.fetchone()[0] - assert isinstance(got, mrange.Range) + assert isinstance(got, Range) assert ord(got.lower) == i assert ord(got.upper) == i + 1 @@ -483,7 +482,7 @@ class TestRangeObject: string conversion. """ tz = dt.timezone(dt.timedelta(hours=-5)) - r = mrange.Range( + r = Range( dt.datetime(2010, 1, 1, tzinfo=tz), dt.datetime(2011, 1, 1, tzinfo=tz), ) diff --git a/tools/update_oids.py b/tools/update_oids.py index 62c90d395..85349db22 100755 --- a/tools/update_oids.py +++ b/tools/update_oids.py @@ -22,29 +22,57 @@ $$, where name = 'server_version_num' """ -py_oids_sql = """ -select format( - '(%L, %s, %s, %s, %L, %L),', - typname, oid, typarray, coalesce(rngsubtype, 0), oid::regtype, typdelim) - from pg_type t - left join pg_range r on t.oid = rngtypid - where oid < 10000 - and typname !~ all('{^(_|pg_),_handler$}') - order by typname +# Note: "record" is a pseudotype but still a useful one to have. +py_types_sql = """ +select + 'TypeInfo(' + || array_to_string(array_remove(array[ + format('%L', typname), + oid::text, + typarray::text, + case when oid::regtype::text != typname + then format('alt_name=%L', oid::regtype) + end, + case when typdelim != ',' + then format('delimiter=%L', typdelim) + end + ], null), ',') + || '),' +from pg_type t +where + oid < 10000 + and (typtype = 'b' or typname = 'record') + and typname !~ '^(_|pg_)' +order by typname """ +py_ranges_sql = """ +select + format('RangeInfo(%L, %s, %s, subtype_oid=%s),', + typname, oid, typarray, rngsubtype) +from + pg_type t + join pg_range r on t.oid = rngtypid +where + oid < 10000 + and typtype = 'r' + and typname !~ '^(_|pg_)' +order by typname +""" cython_oids_sql = """ select format('%s_OID = %s', upper(typname), oid) - from pg_type - where oid < 10000 - and typname !~ all('{^(_|pg_),_handler$}') - order by typname +from pg_type +where + oid < 10000 + and (typtype = any('{b,r}') or typname = 'record') + and typname !~ '^(_|pg_)' +order by typname """ def update_python_oids() -> None: - queries = [version_sql, py_oids_sql] + queries = [version_sql, py_types_sql, py_ranges_sql] fn = ROOT / "psycopg3/psycopg3/oids.py" update_file(fn, queries) sp.check_call(["black", "-q", fn])