]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Register arrays using the TypeInfo object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 6 Feb 2021 03:14:43 +0000 (04:14 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 7 Feb 2021 01:37:17 +0000 (02:37 +0100)
TypeInfo and subclasses implementations moved together to a dedicated module,
sharing more uniform interface and implementation.

12 files changed:
psycopg3/psycopg3/adapt.py
psycopg3/psycopg3/oids.py
psycopg3/psycopg3/typeinfo.py [new file with mode: 0644]
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/array.py
psycopg3/psycopg3/types/composite.py
psycopg3/psycopg3/types/range.py
psycopg3_c/psycopg3_c/_psycopg3/oids.pxd
tests/types/test_array.py
tests/types/test_composite.py
tests/types/test_range.py
tools/update_oids.py

index 0b4ffe43d52d23424f9ed218ef1690656249cdfb..789593d3cc9c58362124159a186ab03f546f7ae1 100644 (file)
@@ -12,8 +12,9 @@ from . import pq
 from . import proto
 from . import errors as e
 from ._enums import Format as Format
-from .oids import TypesRegistry, postgres_types
+from .oids import postgres_types
 from .proto import AdaptContext, Buffer as Buffer
+from .typeinfo import TypesRegistry
 
 if TYPE_CHECKING:
     from .connection import BaseConnection
index 2942f3b56b0c967b74f9a7e851992df3d704a7e6..48c2ee9007e48a43cab7df8be939baa045263e6b 100644 (file)
@@ -4,243 +4,88 @@ Maps of builtin types and names
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-from typing import Dict, Iterator, Optional, Union
-
-
-class TypeInfo:
-    def __init__(
-        self, name: str, oid: int, array_oid: int, range_subtype: int = 0
-    ):
-        self.name = name
-        self.oid = oid
-        self.array_oid = array_oid
-        self.range_subtype = range_subtype
-
-    def __repr__(self) -> str:
-        return (
-            f"<{self.__class__.__qualname__}:"
-            f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
-        )
-
-
-class BuiltinTypeInfo(TypeInfo):
-    def __init__(
-        self,
-        name: str,
-        oid: int,
-        array_oid: int,
-        range_subtype: int,
-        alt_name: str,
-        delimiter: str,
-    ):
-        super().__init__(name, oid, array_oid, range_subtype)
-        self.alt_name = alt_name
-        self.delimiter = delimiter
-
-
-class TypesRegistry:
-    """
-    Container for the information about types in a database.
-    """
-
-    def __init__(self, template: Optional["TypesRegistry"] = None):
-        self._by_oid: Dict[int, TypeInfo]
-        self._by_name: Dict[str, TypeInfo]
-        self._by_range_subtype: Dict[int, TypeInfo]
-
-        if template:
-            self._by_oid = template._by_oid
-            self._by_name = template._by_name
-            self._by_range_subtype = template._by_range_subtype
-            self._own_state = False
-        else:
-            self._by_oid = {}
-            self._by_name = {}
-            self._by_range_subtype = {}
-            self._own_state = True
-
-    def add(self, info: TypeInfo) -> None:
-        self._ensure_own_state()
-        self._by_oid[info.oid] = info
-        if info.array_oid:
-            self._by_oid[info.array_oid] = info
-        self._by_name[info.name] = info
-        if info.range_subtype:
-            # Note: actually not unique (e.g. range of different collation?)
-            self._by_range_subtype[info.range_subtype] = info
-
-        if isinstance(info, BuiltinTypeInfo):
-            if info.alt_name not in self._by_name:
-                self._by_name[info.alt_name] = info
-
-    def __iter__(self) -> Iterator[TypeInfo]:
-        seen = set()
-        for t in self._by_oid.values():
-            if t.oid not in seen:
-                seen.add(t.oid)
-                yield t
-
-    def __getitem__(self, key: Union[str, int]) -> TypeInfo:
-        """
-        Return info about a type, specified by name or oid
-
-        The type name or oid may refer to the array too.
-
-        Raise KeyError if not found.
-        """
-        if isinstance(key, str):
-            if key.endswith("[]"):
-                key = key[:-2]
-            return self._by_name[key]
-        elif isinstance(key, int):
-            return self._by_oid[key]
-        else:
-            raise TypeError(
-                f"the key must be an oid or a name, got {type(key)}"
-            )
-
-    def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
-        """
-        Return info about a type, specified by name or oid
-
-        The type name or oid may refer to the array too.
-
-        Return None if not found.
-        """
-        try:
-            return self[key]
-        except KeyError:
-            return None
-
-    def get_oid(self, name: str) -> int:
-        """
-        Return the oid of a PostgreSQL type by name.
-
-        Return the array oid if the type ends with "[]"
-
-        Raise KeyError if the name is unknown.
-        """
-        t = self[name]
-        if name.endswith("[]"):
-            return t.array_oid
-        else:
-            return t.oid
-
-    def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]:
-        """
-        Return info about a range by its element name or oid
-
-        Return None if the element or its range are not found.
-        """
-        try:
-            info = self[key]
-        except KeyError:
-            return None
-        return self._by_range_subtype.get(info.oid)
-
-    def _ensure_own_state(self) -> None:
-        # Time to write! so, copy.
-        if not self._own_state:
-            self._by_oid = self._by_oid.copy()
-            self._by_name = self._by_name.copy()
-            self._by_range_subtype = self._by_range_subtype.copy()
-            self._own_state = True
-
+from .typeinfo import TypeInfo, RangeInfo, TypesRegistry
 
+# Global objects with PostgreSQL builtins and globally registered user types.
 postgres_types = TypesRegistry()
 
+
 # Use tools/update_oids.py to update this data.
-for r in [
+for t in [
     # autogenerated: start
-    # Generated from PostgreSQL 13.1
-    ("aclitem", 1033, 1034, 0, "aclitem", ","),
-    ("any", 2276, 0, 0, '"any"', ","),
-    ("anyarray", 2277, 0, 0, "anyarray", ","),
-    ("anycompatible", 5077, 0, 0, "anycompatible", ","),
-    ("anycompatiblearray", 5078, 0, 0, "anycompatiblearray", ","),
-    ("anycompatiblenonarray", 5079, 0, 0, "anycompatiblenonarray", ","),
-    ("anycompatiblerange", 5080, 0, 0, "anycompatiblerange", ","),
-    ("anyelement", 2283, 0, 0, "anyelement", ","),
-    ("anyenum", 3500, 0, 0, "anyenum", ","),
-    ("anynonarray", 2776, 0, 0, "anynonarray", ","),
-    ("anyrange", 3831, 0, 0, "anyrange", ","),
-    ("bit", 1560, 1561, 0, "bit", ","),
-    ("bool", 16, 1000, 0, "boolean", ","),
-    ("box", 603, 1020, 0, "box", ";"),
-    ("bpchar", 1042, 1014, 0, "character", ","),
-    ("bytea", 17, 1001, 0, "bytea", ","),
-    ("char", 18, 1002, 0, '"char"', ","),
-    ("cid", 29, 1012, 0, "cid", ","),
-    ("cidr", 650, 651, 0, "cidr", ","),
-    ("circle", 718, 719, 0, "circle", ","),
-    ("cstring", 2275, 1263, 0, "cstring", ","),
-    ("date", 1082, 1182, 0, "date", ","),
-    ("daterange", 3912, 3913, 1082, "daterange", ","),
-    ("event_trigger", 3838, 0, 0, "event_trigger", ","),
-    ("float4", 700, 1021, 0, "real", ","),
-    ("float8", 701, 1022, 0, "double precision", ","),
-    ("gtsvector", 3642, 3644, 0, "gtsvector", ","),
-    ("inet", 869, 1041, 0, "inet", ","),
-    ("int2", 21, 1005, 0, "smallint", ","),
-    ("int2vector", 22, 1006, 0, "int2vector", ","),
-    ("int4", 23, 1007, 0, "integer", ","),
-    ("int4range", 3904, 3905, 23, "int4range", ","),
-    ("int8", 20, 1016, 0, "bigint", ","),
-    ("int8range", 3926, 3927, 20, "int8range", ","),
-    ("internal", 2281, 0, 0, "internal", ","),
-    ("interval", 1186, 1187, 0, "interval", ","),
-    ("json", 114, 199, 0, "json", ","),
-    ("jsonb", 3802, 3807, 0, "jsonb", ","),
-    ("jsonpath", 4072, 4073, 0, "jsonpath", ","),
-    ("line", 628, 629, 0, "line", ","),
-    ("lseg", 601, 1018, 0, "lseg", ","),
-    ("macaddr", 829, 1040, 0, "macaddr", ","),
-    ("macaddr8", 774, 775, 0, "macaddr8", ","),
-    ("money", 790, 791, 0, "money", ","),
-    ("name", 19, 1003, 0, "name", ","),
-    ("numeric", 1700, 1231, 0, "numeric", ","),
-    ("numrange", 3906, 3907, 1700, "numrange", ","),
-    ("oid", 26, 1028, 0, "oid", ","),
-    ("oidvector", 30, 1013, 0, "oidvector", ","),
-    ("path", 602, 1019, 0, "path", ","),
-    ("point", 600, 1017, 0, "point", ","),
-    ("polygon", 604, 1027, 0, "polygon", ","),
-    ("record", 2249, 2287, 0, "record", ","),
-    ("refcursor", 1790, 2201, 0, "refcursor", ","),
-    ("regclass", 2205, 2210, 0, "regclass", ","),
-    ("regcollation", 4191, 4192, 0, "regcollation", ","),
-    ("regconfig", 3734, 3735, 0, "regconfig", ","),
-    ("regdictionary", 3769, 3770, 0, "regdictionary", ","),
-    ("regnamespace", 4089, 4090, 0, "regnamespace", ","),
-    ("regoper", 2203, 2208, 0, "regoper", ","),
-    ("regoperator", 2204, 2209, 0, "regoperator", ","),
-    ("regproc", 24, 1008, 0, "regproc", ","),
-    ("regprocedure", 2202, 2207, 0, "regprocedure", ","),
-    ("regrole", 4096, 4097, 0, "regrole", ","),
-    ("regtype", 2206, 2211, 0, "regtype", ","),
-    ("text", 25, 1009, 0, "text", ","),
-    ("tid", 27, 1010, 0, "tid", ","),
-    ("time", 1083, 1183, 0, "time without time zone", ","),
-    ("timestamp", 1114, 1115, 0, "timestamp without time zone", ","),
-    ("timestamptz", 1184, 1185, 0, "timestamp with time zone", ","),
-    ("timetz", 1266, 1270, 0, "time with time zone", ","),
-    ("trigger", 2279, 0, 0, "trigger", ","),
-    ("tsquery", 3615, 3645, 0, "tsquery", ","),
-    ("tsrange", 3908, 3909, 1114, "tsrange", ","),
-    ("tstzrange", 3910, 3911, 1184, "tstzrange", ","),
-    ("tsvector", 3614, 3643, 0, "tsvector", ","),
-    ("txid_snapshot", 2970, 2949, 0, "txid_snapshot", ","),
-    ("unknown", 705, 0, 0, "unknown", ","),
-    ("uuid", 2950, 2951, 0, "uuid", ","),
-    ("varbit", 1562, 1563, 0, "bit varying", ","),
-    ("varchar", 1043, 1015, 0, "character varying", ","),
-    ("void", 2278, 0, 0, "void", ","),
-    ("xid", 28, 1011, 0, "xid", ","),
-    ("xid8", 5069, 271, 0, "xid8", ","),
-    ("xml", 142, 143, 0, "xml", ","),
+    # Generated from PostgreSQL 13.0
+    TypeInfo("aclitem", 1033, 1034),
+    TypeInfo("bit", 1560, 1561),
+    TypeInfo("bool", 16, 1000, alt_name="boolean"),
+    TypeInfo("box", 603, 1020, delimiter=";"),
+    TypeInfo("bpchar", 1042, 1014, alt_name="character"),
+    TypeInfo("bytea", 17, 1001),
+    TypeInfo("char", 18, 1002, alt_name='"char"'),
+    TypeInfo("cid", 29, 1012),
+    TypeInfo("cidr", 650, 651),
+    TypeInfo("circle", 718, 719),
+    TypeInfo("date", 1082, 1182),
+    TypeInfo("float4", 700, 1021, alt_name="real"),
+    TypeInfo("float8", 701, 1022, alt_name="double precision"),
+    TypeInfo("gtsvector", 3642, 3644),
+    TypeInfo("inet", 869, 1041),
+    TypeInfo("int2", 21, 1005, alt_name="smallint"),
+    TypeInfo("int2vector", 22, 1006),
+    TypeInfo("int4", 23, 1007, alt_name="integer"),
+    TypeInfo("int8", 20, 1016, alt_name="bigint"),
+    TypeInfo("interval", 1186, 1187),
+    TypeInfo("json", 114, 199),
+    TypeInfo("jsonb", 3802, 3807),
+    TypeInfo("jsonpath", 4072, 4073),
+    TypeInfo("line", 628, 629),
+    TypeInfo("lseg", 601, 1018),
+    TypeInfo("macaddr", 829, 1040),
+    TypeInfo("macaddr8", 774, 775),
+    TypeInfo("money", 790, 791),
+    TypeInfo("name", 19, 1003),
+    TypeInfo("numeric", 1700, 1231),
+    TypeInfo("oid", 26, 1028),
+    TypeInfo("oidvector", 30, 1013),
+    TypeInfo("path", 602, 1019),
+    TypeInfo("point", 600, 1017),
+    TypeInfo("polygon", 604, 1027),
+    TypeInfo("record", 2249, 2287),
+    TypeInfo("refcursor", 1790, 2201),
+    TypeInfo("regclass", 2205, 2210),
+    TypeInfo("regcollation", 4191, 4192),
+    TypeInfo("regconfig", 3734, 3735),
+    TypeInfo("regdictionary", 3769, 3770),
+    TypeInfo("regnamespace", 4089, 4090),
+    TypeInfo("regoper", 2203, 2208),
+    TypeInfo("regoperator", 2204, 2209),
+    TypeInfo("regproc", 24, 1008),
+    TypeInfo("regprocedure", 2202, 2207),
+    TypeInfo("regrole", 4096, 4097),
+    TypeInfo("regtype", 2206, 2211),
+    TypeInfo("text", 25, 1009),
+    TypeInfo("tid", 27, 1010),
+    TypeInfo("time", 1083, 1183, alt_name="time without time zone"),
+    TypeInfo("timestamp", 1114, 1115, alt_name="timestamp without time zone"),
+    TypeInfo("timestamptz", 1184, 1185, alt_name="timestamp with time zone"),
+    TypeInfo("timetz", 1266, 1270, alt_name="time with time zone"),
+    TypeInfo("tsquery", 3615, 3645),
+    TypeInfo("tsvector", 3614, 3643),
+    TypeInfo("txid_snapshot", 2970, 2949),
+    TypeInfo("uuid", 2950, 2951),
+    TypeInfo("varbit", 1562, 1563, alt_name="bit varying"),
+    TypeInfo("varchar", 1043, 1015, alt_name="character varying"),
+    TypeInfo("xid", 28, 1011),
+    TypeInfo("xid8", 5069, 271),
+    TypeInfo("xml", 142, 143),
+    RangeInfo("daterange", 3912, 3913, subtype_oid=1082),
+    RangeInfo("int4range", 3904, 3905, subtype_oid=23),
+    RangeInfo("int8range", 3926, 3927, subtype_oid=20),
+    RangeInfo("numrange", 3906, 3907, subtype_oid=1700),
+    RangeInfo("tsrange", 3908, 3909, subtype_oid=1114),
+    RangeInfo("tstzrange", 3910, 3911, subtype_oid=1184),
     # autogenerated: end
 ]:
-    postgres_types.add(BuiltinTypeInfo(*r))
+    postgres_types.add(t)
 
 
 # A few oids used a bit everywhere
diff --git a/psycopg3/psycopg3/typeinfo.py b/psycopg3/psycopg3/typeinfo.py
new file mode 100644 (file)
index 0000000..67f0bd7
--- /dev/null
@@ -0,0 +1,315 @@
+"""
+Information about PostgreSQL types
+
+These types allow to read information from the system catalog and provide
+information to the adapters if needed.
+"""
+
+# Copyright (C) 2020-2021 The Psycopg Team
+
+from typing import Any, Callable, Dict, Iterator, Optional
+from typing import Sequence, Type, TypeVar, Union, TYPE_CHECKING
+
+from . import errors as e
+from .proto import AdaptContext
+
+if TYPE_CHECKING:
+    from .connection import Connection, AsyncConnection
+    from .sql import Identifier
+
+T = TypeVar("T", bound="TypeInfo")
+
+
+class TypeInfo:
+    """
+    Hold information about a PostgreSQL base type.
+
+    The class allows to:
+
+    - read information about a range type using `fetch()` and `fetch_async()`
+    - configure a composite type adaptation using `register()`
+    """
+
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        alt_name: str = "",
+        delimiter: str = ",",
+    ):
+        self.name = name
+        self.oid = oid
+        self.array_oid = array_oid
+        self.alt_name = alt_name
+        self.delimiter = delimiter
+
+    def __repr__(self) -> str:
+        return (
+            f"<{self.__class__.__qualname__}:"
+            f" {self.name} (oid: {self.oid}, array oid: {self.array_oid})>"
+        )
+
+    @classmethod
+    def fetch(
+        cls: Type[T], conn: "Connection", name: Union[str, "Identifier"]
+    ) -> Optional[T]:
+        from .sql import Composable
+
+        if isinstance(name, Composable):
+            name = name.as_string(conn)
+        cur = conn.cursor(binary=True)
+        cur.execute(cls._info_query, {"name": name})
+        recs = cur.fetchall()
+        fields = [d[0] for d in cur.description or ()]
+        return cls._fetch(name, fields, recs)
+
+    @classmethod
+    async def fetch_async(
+        cls: Type[T], conn: "AsyncConnection", name: Union[str, "Identifier"]
+    ) -> Optional[T]:
+        from .sql import Composable
+
+        if isinstance(name, Composable):
+            name = name.as_string(conn)
+        cur = await conn.cursor(binary=True)
+        await cur.execute(cls._info_query, {"name": name})
+        recs = await cur.fetchall()
+        fields = [d[0] for d in cur.description or ()]
+        return cls._fetch(name, fields, recs)
+
+    @classmethod
+    def _fetch(
+        cls: Type[T],
+        name: str,
+        fields: Sequence[str],
+        recs: Sequence[Sequence[Any]],
+    ) -> Optional[T]:
+        if len(recs) == 1:
+            return cls(**dict(zip(fields, recs[0])))
+        elif not recs:
+            return None
+        else:
+            raise e.ProgrammingError(
+                f"found {len(recs)} different types named {name}"
+            )
+
+    def register(
+        self,
+        context: Optional["AdaptContext"] = None,
+    ) -> None:
+
+        if context:
+            types = context.adapters.types
+        else:
+            from .oids import postgres_types
+
+            types = postgres_types
+
+        types.add(self)
+
+        if self.array_oid:
+            from .types.array import register_adapters
+
+            register_adapters(self, context)
+
+    _info_query = """\
+select
+    typname as name, oid, typarray as array_oid,
+    oid::regtype as alt_name, typdelim as delimiter
+from pg_type t
+where t.oid = %(name)s::regtype
+order by t.oid
+"""
+
+
+class RangeInfo(TypeInfo):
+    """Manage information about a range type."""
+
+    def __init__(self, name: str, oid: int, array_oid: int, subtype_oid: int):
+        super().__init__(name, oid, array_oid)
+        self.subtype_oid = subtype_oid
+
+    def register(
+        self,
+        context: Optional[AdaptContext] = None,
+    ) -> None:
+        super().register(context)
+
+        from .types.range import register_adapters
+
+        register_adapters(self, context)
+
+    _info_query = """\
+select t.typname as name, t.oid as oid, t.typarray as array_oid,
+    r.rngsubtype as subtype_oid
+from pg_type t
+join pg_range r on t.oid = r.rngtypid
+where t.oid = %(name)s::regtype
+"""
+
+
+class CompositeInfo(TypeInfo):
+    """Manage information about a composite type."""
+
+    def __init__(
+        self,
+        name: str,
+        oid: int,
+        array_oid: int,
+        field_names: Sequence[str],
+        field_types: Sequence[int],
+    ):
+        super().__init__(name, oid, array_oid)
+        self.field_names = field_names
+        self.field_types = field_types
+
+    def register(
+        self,
+        context: Optional[AdaptContext] = None,
+        factory: Optional[Callable[..., Any]] = None,
+    ) -> None:
+        super().register(context)
+
+        from .types.composite import register_adapters
+
+        register_adapters(self, context, factory)
+
+    _info_query = """\
+select
+    t.typname as name, t.oid as oid, t.typarray as array_oid,
+    coalesce(a.fnames, '{}') as field_names,
+    coalesce(a.ftypes, '{}') as field_types
+from pg_type t
+left join (
+    select
+        attrelid,
+        array_agg(attname) as fnames,
+        array_agg(atttypid) as ftypes
+    from (
+        select a.attrelid, a.attname, a.atttypid
+        from pg_attribute a
+        join pg_type t on t.typrelid = a.attrelid
+        where t.oid = %(name)s::regtype
+        and a.attnum > 0
+        and not a.attisdropped
+        order by a.attnum
+    ) x
+    group by attrelid
+) a on a.attrelid = t.typrelid
+where t.oid = %(name)s::regtype
+"""
+
+
+class TypesRegistry:
+    """
+    Container for the information about types in a database.
+    """
+
+    def __init__(self, template: Optional["TypesRegistry"] = None):
+        self._by_oid: Dict[int, TypeInfo]
+        self._by_name: Dict[str, TypeInfo]
+        self._by_range_subtype: Dict[int, TypeInfo]
+
+        # Make a shallow copy: it will become a proper copy if the registry
+        # is edited (note the BUG: a child will get shallow-copied, but changing
+        # the parent will change children who weren't copied yet. It can be
+        # probably fixed by setting _own_state to False on the parent on copy,
+        # but needs testing and for the moment I'll leave it there TODO).
+        if template:
+            self._by_oid = template._by_oid
+            self._by_name = template._by_name
+            self._by_range_subtype = template._by_range_subtype
+            self._own_state = False
+        else:
+            self._by_oid = {}
+            self._by_name = {}
+            self._by_range_subtype = {}
+            self._own_state = True
+
+    def add(self, info: TypeInfo) -> None:
+        self._ensure_own_state()
+        self._by_oid[info.oid] = info
+        if info.array_oid:
+            self._by_oid[info.array_oid] = info
+        self._by_name[info.name] = info
+
+        if info.alt_name and info.alt_name not in self._by_name:
+            self._by_name[info.alt_name] = info
+
+        # Map ranges subtypes to info
+        if isinstance(info, RangeInfo):
+            self._by_range_subtype[info.subtype_oid] = info
+
+    def __iter__(self) -> Iterator[TypeInfo]:
+        seen = set()
+        for t in self._by_oid.values():
+            if t.oid not in seen:
+                seen.add(t.oid)
+                yield t
+
+    def __getitem__(self, key: Union[str, int]) -> TypeInfo:
+        """
+        Return info about a type, specified by name or oid
+
+        The type name or oid may refer to the array too.
+
+        Raise KeyError if not found.
+        """
+        if isinstance(key, str):
+            if key.endswith("[]"):
+                key = key[:-2]
+            return self._by_name[key]
+        elif isinstance(key, int):
+            return self._by_oid[key]
+        else:
+            raise TypeError(
+                f"the key must be an oid or a name, got {type(key)}"
+            )
+
+    def get(self, key: Union[str, int]) -> Optional[TypeInfo]:
+        """
+        Return info about a type, specified by name or oid
+
+        The type name or oid may refer to the array too.
+
+        Return None if not found.
+        """
+        try:
+            return self[key]
+        except KeyError:
+            return None
+
+    def get_oid(self, name: str) -> int:
+        """
+        Return the oid of a PostgreSQL type by name.
+
+        Return the array oid if the type ends with "[]"
+
+        Raise KeyError if the name is unknown.
+        """
+        t = self[name]
+        if name.endswith("[]"):
+            return t.array_oid
+        else:
+            return t.oid
+
+    def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]:
+        """
+        Return info about a range by its element name or oid
+
+        Return None if the element or its range are not found.
+        """
+        try:
+            info = self[key]
+        except KeyError:
+            return None
+        return self._by_range_subtype.get(info.oid)
+
+    def _ensure_own_state(self) -> None:
+        # Time to write! so, copy.
+        if not self._own_state:
+            self._by_oid = self._by_oid.copy()
+            self._by_name = self._by_name.copy()
+            self._by_range_subtype = self._by_range_subtype.copy()
+            self._own_state = True
index dc70082b9b3dd89b14683b4c4236ade5b02e160b..2b7bda463947cd9f1b7c469ec1765860999c53ec 100644 (file)
@@ -8,17 +8,15 @@ from ..oids import INVALID_OID
 from ..proto import AdaptContext
 
 # Register default adapters
-from . import array, composite
-from . import range
+from . import array, composite, range
 
 # Wrapper objects
 from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric, Oid
 from .json import Json, Jsonb
 from .range import Range
 
-# Supper objects
-from .range import RangeInfo
-from .composite import CompositeInfo
+# Database types descriptors
+from ..typeinfo import TypeInfo, RangeInfo, CompositeInfo
 
 # Adapter objects
 from .text import (
index b68bf77c35bec2305b985f6ea406ea5f181a0538..b282d8583ee06af5e065bb5dcf77e8399a54d76a 100644 (file)
@@ -14,6 +14,7 @@ from ..oids import postgres_types, TEXT_OID, TEXT_ARRAY_OID, INVALID_OID
 from ..adapt import Buffer, Dumper, Loader, Transformer
 from ..adapt import Format as Pg3Format
 from ..proto import AdaptContext
+from ..typeinfo import TypeInfo
 
 
 class BaseListDumper(Dumper):
@@ -315,20 +316,15 @@ class ArrayBinaryLoader(BaseArrayLoader):
         return agg(dims)
 
 
-def register(
-    array_oid: int,
-    base_oid: int,
-    context: Optional[AdaptContext] = None,
-    name: Optional[str] = None,
+def register_adapters(
+    info: TypeInfo, context: Optional["AdaptContext"]
 ) -> None:
-    if not name:
-        name = f"oid{base_oid}"
-
     for base in (ArrayLoader, ArrayBinaryLoader):
-        fmt = "Binary" if base.format == pq.Format.BINARY else ""
-        lname = f"{name.title()}Array{fmt}Loader"
-        loader: Type[Loader] = type(lname, (base,), {"base_oid": base_oid})
-        loader.register(array_oid, context=context)
+        lname = f"{info.name.title()}{base.__name__}"
+        loader: Type[BaseArrayLoader] = type(
+            lname, (base,), {"base_oid": info.oid}
+        )
+        loader.register(info.array_oid, context=context)
 
 
 def register_all_arrays(ctx: AdaptContext) -> None:
@@ -341,4 +337,4 @@ def register_all_arrays(ctx: AdaptContext) -> None:
     for t in ctx.adapters.types:
         # TODO: handle different delimiters (box)
         if t.array_oid and getattr(t, "delimiter", None) == ",":
-            register(t.array_oid, t.oid, name=t.name)
+            t.register(ctx)
index c43a68b1c7393b33edb6e46ae2e1258043c59032..f91f29b1b2b12e92e7d6cdbc8f74369009c13cd6 100644 (file)
@@ -7,141 +7,14 @@ Support for composite types adaptation.
 import re
 import struct
 from collections import namedtuple
-from typing import Any, Callable, Iterator, List, NamedTuple, Optional
-from typing import Sequence, Tuple, Type, Union, TYPE_CHECKING
+from typing import Any, Callable, Iterator, List, Optional
+from typing import Sequence, Tuple, Type
 
 from .. import pq
-from .. import sql
-from .. import errors as e
-from ..oids import TypeInfo, TEXT_OID
+from ..oids import TEXT_OID
 from ..adapt import Buffer, Format, Dumper, Loader, Transformer
 from ..proto import AdaptContext
-from . import array
-
-if TYPE_CHECKING:
-    from ..connection import Connection, AsyncConnection
-
-
-class CompositeInfo(TypeInfo):
-    """Manage information about a composite type.
-
-    The class allows to:
-
-    - read information about a composite type using `fetch()` and `fetch_async()`
-    - configure a composite type adaptation using `register()`
-    """
-
-    def __init__(
-        self,
-        name: str,
-        oid: int,
-        array_oid: int,
-        fields: Sequence["CompositeInfo.FieldInfo"],
-    ):
-        super().__init__(name, oid, array_oid)
-        self.fields = list(fields)
-
-    class FieldInfo(NamedTuple):
-        """Information about a single field in a composite type."""
-
-        name: str
-        type_oid: int
-
-    @classmethod
-    def fetch(
-        cls, conn: "Connection", name: Union[str, sql.Identifier]
-    ) -> Optional["CompositeInfo"]:
-        if isinstance(name, sql.Composable):
-            name = name.as_string(conn)
-        cur = conn.cursor()
-        cur.execute(cls._info_query, {"name": name})
-        recs = cur.fetchall()
-        return cls._from_records(recs)
-
-    @classmethod
-    async def fetch_async(
-        cls, conn: "AsyncConnection", name: Union[str, sql.Identifier]
-    ) -> Optional["CompositeInfo"]:
-        if isinstance(name, sql.Composable):
-            name = name.as_string(conn)
-        cur = await conn.cursor()
-        await cur.execute(cls._info_query, {"name": name})
-        recs = await cur.fetchall()
-        return cls._from_records(recs)
-
-    def register(
-        self,
-        context: Optional[AdaptContext] = None,
-        factory: Optional[Callable[..., Any]] = None,
-    ) -> None:
-        if not factory:
-            factory = namedtuple(  # type: ignore
-                self.name, [f.name for f in self.fields]
-            )
-
-        loader: Type[Loader]
-
-        # generate and register a customized text loader
-        loader = type(
-            f"{self.name.title()}Loader",
-            (CompositeLoader,),
-            {
-                "factory": factory,
-                "fields_types": [f.type_oid for f in self.fields],
-            },
-        )
-        loader.register(self.oid, context=context)
-
-        # generate and register a customized binary loader
-        loader = type(
-            f"{self.name.title()}BinaryLoader",
-            (CompositeBinaryLoader,),
-            {"factory": factory},
-        )
-        loader.register(self.oid, context=context)
-
-        if self.array_oid:
-            array.register(
-                self.array_oid, self.oid, context=context, name=self.name
-            )
-
-    @classmethod
-    def _from_records(cls, recs: Sequence[Any]) -> Optional["CompositeInfo"]:
-        if not recs:
-            return None
-        if len(recs) > 1:
-            raise e.ProgrammingError(
-                f"found {len(recs)} different types named {recs[0][0]}"
-            )
-
-        name, oid, array_oid, fnames, ftypes = recs[0]
-        fields = [cls.FieldInfo(*p) for p in zip(fnames, ftypes)]
-        return cls(name, oid, array_oid, fields)
-
-    _info_query = """\
-select
-    t.typname as name, t.oid as oid, t.typarray as array_oid,
-    coalesce(a.fnames, '{}') as fnames,
-    coalesce(a.ftypes, '{}') as ftypes
-from pg_type t
-left join (
-    select
-        attrelid,
-        array_agg(attname) as fnames,
-        array_agg(atttypid) as ftypes
-    from (
-        select a.attrelid, a.attname, a.atttypid
-        from pg_attribute a
-        join pg_type t on t.typrelid = a.attrelid
-        where t.oid = %(name)s::regtype
-        and a.attnum > 0
-        and not a.attisdropped
-        order by a.attnum
-    ) x
-    group by attrelid
-) a on a.attrelid = t.typrelid
-where t.oid = %(name)s::regtype
-"""
+from ..typeinfo import CompositeInfo
 
 
 class SequenceDumper(Dumper):
@@ -317,3 +190,31 @@ class CompositeBinaryLoader(RecordBinaryLoader):
     def load(self, data: Buffer) -> Any:
         r = super().load(data)
         return type(self).factory(*r)
+
+
+def register_adapters(
+    info: CompositeInfo,
+    context: Optional["AdaptContext"],
+    factory: Optional[Callable[..., Any]] = None,
+) -> None:
+    if not factory:
+        factory = namedtuple(info.name, info.field_names)  # type: ignore
+
+    # generate and register a customized text loader
+    loader: Type[BaseCompositeLoader] = type(
+        f"{info.name.title()}Loader",
+        (CompositeLoader,),
+        {
+            "factory": factory,
+            "fields_types": info.field_types,
+        },
+    )
+    loader.register(info.oid, context=context)
+
+    # generate and register a customized binary loader
+    loader = type(
+        f"{info.name.title()}BinaryLoader",
+        (CompositeBinaryLoader,),
+        {"factory": factory},
+    )
+    loader.register(info.oid, context=context)
index 69ac4ee441214b9ac7a40996cdc7890728925644..43b9ccbd664979116aacdb27c1d0c86be3819ec6 100644 (file)
@@ -5,24 +5,18 @@ Support for range types adaptation.
 # Copyright (C) 2020-2021 The Psycopg Team
 
 import re
-from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Type, Union
-from typing import cast, Tuple, TYPE_CHECKING
+from typing import cast, Any, Dict, Generic, Optional, Tuple, TypeVar, Type
 from decimal import Decimal
 from datetime import date, datetime
 
-from .. import sql
-from .. import errors as e
 from ..pq import Format
-from ..oids import postgres_types as builtins, TypeInfo, INVALID_OID
-from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
+from ..oids import postgres_types as builtins, INVALID_OID
+from ..adapt import Buffer, Dumper, Format as Pg3Format
 from ..proto import AdaptContext
+from ..typeinfo import RangeInfo
 
-from . import array
 from .composite import SequenceDumper, BaseCompositeLoader
 
-if TYPE_CHECKING:
-    from ..connection import Connection, AsyncConnection
-
 T = TypeVar("T")
 
 
@@ -312,6 +306,18 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
 _int2parens = {ord(c): c for c in "[]()"}
 
 
+def register_adapters(
+    info: RangeInfo, context: Optional["AdaptContext"]
+) -> None:
+    # generate and register a customized text loader
+    loader: Type[RangeLoader[Any]] = type(
+        f"{info.name.title()}Loader",
+        (RangeLoader,),
+        {"subtype_oid": info.subtype_oid},
+    )
+    loader.register(info.oid, context=context)
+
+
 # Loaders for builtin range types
 
 
@@ -337,76 +343,3 @@ class TimestampRangeLoader(RangeLoader[datetime]):
 
 class TimestampTZRangeLoader(RangeLoader[datetime]):
     subtype_oid = builtins["timestamptz"].oid
-
-
-class RangeInfo(TypeInfo):
-    """Manage information about a range type.
-
-    The class allows to:
-
-    - read information about a range type using `fetch()` and `fetch_async()`
-    - configure a composite type adaptation using `register()`
-    """
-
-    @classmethod
-    def fetch(
-        cls, conn: "Connection", name: Union[str, sql.Identifier]
-    ) -> Optional["RangeInfo"]:
-        if isinstance(name, sql.Composable):
-            name = name.as_string(conn)
-        cur = conn.cursor(binary=True)
-        cur.execute(cls._info_query, {"name": name})
-        recs = cur.fetchall()
-        return cls._from_records(recs)
-
-    @classmethod
-    async def fetch_async(
-        cls, conn: "AsyncConnection", name: Union[str, sql.Identifier]
-    ) -> Optional["RangeInfo"]:
-        if isinstance(name, sql.Composable):
-            name = name.as_string(conn)
-        cur = await conn.cursor(binary=True)
-        await cur.execute(cls._info_query, {"name": name})
-        recs = await cur.fetchall()
-        return cls._from_records(recs)
-
-    def register(
-        self,
-        context: Optional[AdaptContext] = None,
-    ) -> None:
-
-        if context:
-            context.adapters.types.add(self)
-
-        # generate and register a customized text loader
-        loader: Type[Loader] = type(
-            f"{self.name.title()}Loader",
-            (RangeLoader,),
-            {"subtype_oid": self.range_subtype},
-        )
-        loader.register(self.oid, context=context)
-
-        if self.array_oid:
-            array.register(
-                self.array_oid, self.oid, context=context, name=self.name
-            )
-
-    @classmethod
-    def _from_records(cls, recs: Sequence[Any]) -> Optional["RangeInfo"]:
-        if not recs:
-            return None
-        if len(recs) > 1:
-            raise e.ProgrammingError(
-                f"found {len(recs)} different ranges named {recs[0][0]}"
-            )
-
-        name, oid, array_oid, subtype = recs[0]
-        return cls(name, oid, array_oid, subtype)
-
-    _info_query = """\
-select t.typname as name, t.oid as oid, t.typarray as array_oid,
-    r.rngsubtype as range_subtype
-from pg_type t
-join pg_range r on t.oid = r.rngtypid
-where t.oid = %(name)s::regtype
-"""
index 7e82198e67cef0a69c7949b7f1e4cf3c6982400c..0249bb35bb1ba18f029f047a416609f8d7a4294e 100644 (file)
@@ -11,19 +11,9 @@ cdef enum:
 
     # autogenerated: start
 
-    # Generated from PostgreSQL 13.1
+    # Generated from PostgreSQL 13.0
 
     ACLITEM_OID = 1033
-    ANY_OID = 2276
-    ANYARRAY_OID = 2277
-    ANYCOMPATIBLE_OID = 5077
-    ANYCOMPATIBLEARRAY_OID = 5078
-    ANYCOMPATIBLENONARRAY_OID = 5079
-    ANYCOMPATIBLERANGE_OID = 5080
-    ANYELEMENT_OID = 2283
-    ANYENUM_OID = 3500
-    ANYNONARRAY_OID = 2776
-    ANYRANGE_OID = 3831
     BIT_OID = 1560
     BOOL_OID = 16
     BOX_OID = 603
@@ -33,10 +23,8 @@ cdef enum:
     CID_OID = 29
     CIDR_OID = 650
     CIRCLE_OID = 718
-    CSTRING_OID = 2275
     DATE_OID = 1082
     DATERANGE_OID = 3912
-    EVENT_TRIGGER_OID = 3838
     FLOAT4_OID = 700
     FLOAT8_OID = 701
     GTSVECTOR_OID = 3642
@@ -47,7 +35,6 @@ cdef enum:
     INT4RANGE_OID = 3904
     INT8_OID = 20
     INT8RANGE_OID = 3926
-    INTERNAL_OID = 2281
     INTERVAL_OID = 1186
     JSON_OID = 114
     JSONB_OID = 3802
@@ -84,17 +71,14 @@ cdef enum:
     TIMESTAMP_OID = 1114
     TIMESTAMPTZ_OID = 1184
     TIMETZ_OID = 1266
-    TRIGGER_OID = 2279
     TSQUERY_OID = 3615
     TSRANGE_OID = 3908
     TSTZRANGE_OID = 3910
     TSVECTOR_OID = 3614
     TXID_SNAPSHOT_OID = 2970
-    UNKNOWN_OID = 705
     UUID_OID = 2950
     VARBIT_OID = 1562
     VARCHAR_OID = 1043
-    VOID_OID = 2278
     XID_OID = 28
     XID8_OID = 5069
     XML_OID = 142
index 22087683143fc8acc10849567ef1b83f05e866be..ea1805df1151e50a5fb7bfde8135d8b771230446 100644 (file)
@@ -4,7 +4,7 @@ from psycopg3 import pq
 from psycopg3 import sql
 from psycopg3.oids import postgres_types as builtins
 from psycopg3.adapt import Format, Transformer
-from psycopg3.types import array
+from psycopg3.types import TypeInfo
 
 
 tests_str = [
@@ -118,9 +118,9 @@ def test_array_register(conn):
     assert res[0] == "(foo)"
     assert res[1] == "{(foo)}"
 
-    array.register(
-        cur.description[1].type_code, cur.description[0].type_code, context=cur
-    )
+    info = TypeInfo.fetch(conn, "mytype")
+    info.register(cur)
+
     cur.execute("""select '(foo)'::mytype, '{"(foo)"}'::mytype[] -- 2""")
     res = cur.fetchone()
     assert res[0] == "(foo)"
index 32d2ad75a5f8e3c975351949cceb3a4cbe70ffcc..078e4f3b45d0f0e9db2273afcd1e7d884038b848 100644 (file)
@@ -4,7 +4,7 @@ from psycopg3 import pq
 from psycopg3.sql import Identifier
 from psycopg3.oids import postgres_types as builtins
 from psycopg3.adapt import Format, global_adapters
-from psycopg3.types.composite import CompositeInfo
+from psycopg3.types import CompositeInfo
 
 
 tests_str = [
@@ -135,10 +135,11 @@ def test_fetch_info(conn, testcomp, name, fields):
     assert info.name == "testcomp"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
-    assert len(info.fields) == 3
+    assert len(info.field_names) == 3
+    assert len(info.field_types) == 3
     for i, (name, t) in enumerate(fields):
-        assert info.fields[i].name == name
-        assert info.fields[i].type_oid == builtins[t].oid
+        assert info.field_names[i] == name
+        assert info.field_types[i] == builtins[t].oid
 
 
 @pytest.mark.asyncio
@@ -148,10 +149,11 @@ async def test_fetch_info_async(aconn, testcomp, name, fields):
     assert info.name == "testcomp"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
-    assert len(info.fields) == 3
+    assert len(info.field_names) == 3
+    assert len(info.field_types) == 3
     for i, (name, t) in enumerate(fields):
-        assert info.fields[i].name == name
-        assert info.fields[i].type_oid == builtins[t].oid
+        assert info.field_names[i] == name
+        assert info.field_types[i] == builtins[t].oid
 
 
 @pytest.mark.parametrize("fmt_in", [Format.AUTO, Format.TEXT, Format.BINARY])
index 0ccbc9545008055c0c3e58bac478d06b5032e7f3..f170cbb9e9bea286f72c29f6fdfd49b6c2bb4e26 100644 (file)
@@ -5,8 +5,7 @@ from decimal import Decimal
 import pytest
 
 from psycopg3.sql import Identifier
-from psycopg3.types import range as mrange
-from psycopg3.types.range import Range
+from psycopg3.types import Range, RangeInfo
 
 
 type2sub = {
@@ -159,36 +158,36 @@ fetch_cases = [
 
 @pytest.mark.parametrize("name, subtype", fetch_cases)
 def test_fetch_info(conn, testrange, name, subtype):
-    info = mrange.RangeInfo.fetch(conn, name)
+    info = RangeInfo.fetch(conn, name)
     assert info.name == "testrange"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
-    assert info.range_subtype == conn.adapters.types[subtype].oid
+    assert info.subtype_oid == conn.adapters.types[subtype].oid
 
 
 def test_fetch_info_not_found(conn):
     with pytest.raises(conn.ProgrammingError):
-        mrange.RangeInfo.fetch(conn, "nosuchrange")
+        RangeInfo.fetch(conn, "nosuchrange")
 
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("name, subtype", fetch_cases)
 async def test_fetch_info_async(aconn, testrange, name, subtype):
-    info = await mrange.RangeInfo.fetch_async(aconn, name)
+    info = await RangeInfo.fetch_async(aconn, name)
     assert info.name == "testrange"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
-    assert info.range_subtype == aconn.adapters.types[subtype].oid
+    assert info.subtype_oid == aconn.adapters.types[subtype].oid
 
 
 @pytest.mark.asyncio
 async def test_fetch_info_not_found_async(aconn):
     with pytest.raises(aconn.ProgrammingError):
-        await mrange.RangeInfo.fetch_async(aconn, "nosuchrange")
+        await RangeInfo.fetch_async(aconn, "nosuchrange")
 
 
 def test_dump_custom_empty(conn, testrange):
-    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info = RangeInfo.fetch(conn, "testrange")
     info.register(conn)
 
     r = Range(empty=True)
@@ -197,7 +196,7 @@ def test_dump_custom_empty(conn, testrange):
 
 
 def test_dump_quoting(conn, testrange):
-    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info = RangeInfo.fetch(conn, "testrange")
     info.register(conn)
     cur = conn.cursor()
     for i in range(1, 254):
@@ -212,16 +211,16 @@ def test_dump_quoting(conn, testrange):
 
 
 def test_load_custom_empty(conn, testrange):
-    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info = RangeInfo.fetch(conn, "testrange")
     info.register(conn)
 
     (got,) = conn.execute("select 'empty'::testrange").fetchone()
-    assert isinstance(got, mrange.Range)
+    assert isinstance(got, Range)
     assert got.isempty
 
 
 def test_load_quoting(conn, testrange):
-    info = mrange.RangeInfo.fetch(conn, "testrange")
+    info = RangeInfo.fetch(conn, "testrange")
     info.register(conn)
     cur = conn.cursor()
     for i in range(1, 254):
@@ -230,7 +229,7 @@ def test_load_quoting(conn, testrange):
             {"low": i, "up": i + 1},
         )
         got = cur.fetchone()[0]
-        assert isinstance(got, mrange.Range)
+        assert isinstance(got, Range)
         assert ord(got.lower) == i
         assert ord(got.upper) == i + 1
 
@@ -483,7 +482,7 @@ class TestRangeObject:
         string conversion.
         """
         tz = dt.timezone(dt.timedelta(hours=-5))
-        r = mrange.Range(
+        r = Range(
             dt.datetime(2010, 1, 1, tzinfo=tz),
             dt.datetime(2011, 1, 1, tzinfo=tz),
         )
index 62c90d39513b43a8c8e49eb5b785c50f885228d1..85349db2281d876039f8ea3c8b57767b190f4c70 100755 (executable)
@@ -22,29 +22,57 @@ $$,
     where name = 'server_version_num'
 """
 
-py_oids_sql = """
-select format(
-        '(%L, %s, %s, %s, %L, %L),',
-        typname, oid, typarray, coalesce(rngsubtype, 0), oid::regtype, typdelim)
-    from pg_type t
-    left join pg_range r on t.oid = rngtypid
-    where oid < 10000
-    and typname !~ all('{^(_|pg_),_handler$}')
-    order by typname
+# Note: "record" is a pseudotype but still a useful one to have.
+py_types_sql = """
+select
+    'TypeInfo('
+    || array_to_string(array_remove(array[
+        format('%L', typname),
+        oid::text,
+        typarray::text,
+        case when oid::regtype::text != typname
+            then format('alt_name=%L', oid::regtype)
+        end,
+        case when typdelim != ','
+            then format('delimiter=%L', typdelim)
+        end
+    ], null), ',')
+    || '),'
+from pg_type t
+where
+    oid < 10000
+    and (typtype = 'b' or typname = 'record')
+    and typname !~ '^(_|pg_)'
+order by typname
 """
 
+py_ranges_sql = """
+select
+    format('RangeInfo(%L, %s, %s, subtype_oid=%s),',
+        typname, oid, typarray, rngsubtype)
+from
+    pg_type t
+    join pg_range r on t.oid = rngtypid
+where
+    oid < 10000
+    and typtype = 'r'
+    and typname !~ '^(_|pg_)'
+order by typname
+"""
 
 cython_oids_sql = """
 select format('%s_OID = %s', upper(typname), oid)
-    from pg_type
-    where oid < 10000
-    and typname !~ all('{^(_|pg_),_handler$}')
-    order by typname
+from pg_type
+where
+    oid < 10000
+    and (typtype = any('{b,r}') or typname = 'record')
+    and typname !~ '^(_|pg_)'
+order by typname
 """
 
 
 def update_python_oids() -> None:
-    queries = [version_sql, py_oids_sql]
+    queries = [version_sql, py_types_sql, py_ranges_sql]
     fn = ROOT / "psycopg3/psycopg3/oids.py"
     update_file(fn, queries)
     sp.check_call(["black", "-q", fn])