]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Have a single Range class, not one per subtype
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Thu, 21 Jan 2021 03:05:33 +0000 (04:05 +0100)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sat, 23 Jan 2021 01:55:41 +0000 (02:55 +0100)
There are two problems here, so this doesn't work. One, which will be
solved, is to have a registry of types attached to an adaptation
context: a TODO in already several places, so this will be solved.
Before that we don't really have a way to find back e.g. the oid of a
custom range on strings starting from Range('a', 'b').

The second problem, more serious, is that Postgres doesn't cast
int4range <-> int8range and tsrange <-> tstzrange. The latter pair can
be solved with a two steps dumper choosing between tz aware and not. The
first I don't have in mind how to solve it: Given Range(1,2) I wouldn't
know if int4 or int8 should be used, and Postgres doesn't seem very
forgiving. Probably we should go unknown.

psycopg3/psycopg3/oids.py
psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/range.py
tests/types/test_range.py
tools/update_oids.py

index a30b52f72eaf2b8f0f6b0493c14125e7685a20da..daedcf3121ab3f6e05d9a7a9ab622b5442fabd7d 100644 (file)
@@ -8,10 +8,13 @@ from typing import Dict, Iterator, Optional, Union
 
 
 class TypeInfo:
-    def __init__(self, name: str, oid: int, array_oid: int):
+    def __init__(
+        self, name: str, oid: int, array_oid: int, range_subtype: int = 0
+    ):
         self.name = name
         self.oid = oid
         self.array_oid = array_oid
+        self.range_subtype = range_subtype
 
     def __repr__(self) -> str:
         return (
@@ -26,10 +29,11 @@ class BuiltinTypeInfo(TypeInfo):
         name: str,
         oid: int,
         array_oid: int,
+        range_subtype: int,
         alt_name: str,
         delimiter: str,
     ):
-        super().__init__(name, oid, array_oid)
+        super().__init__(name, oid, array_oid, range_subtype)
         self.alt_name = alt_name
         self.delimiter = delimiter
 
@@ -42,12 +46,16 @@ class TypesRegistry:
     def __init__(self) -> None:
         self._by_oid: Dict[int, TypeInfo] = {}
         self._by_name: Dict[str, TypeInfo] = {}
+        self._by_range_subtype: Dict[int, TypeInfo] = {}
 
     def add(self, info: TypeInfo) -> None:
         self._by_oid[info.oid] = info
         if info.array_oid:
             self._by_oid[info.array_oid] = info
         self._by_name[info.name] = info
+        if info.range_subtype:
+            # Note: actually not unique (e.g. range of different collation?)
+            self._by_range_subtype[info.range_subtype] = info
 
         if isinstance(info, BuiltinTypeInfo):
             if info.alt_name not in self._by_name:
@@ -106,6 +114,18 @@ class TypesRegistry:
         else:
             return t.oid
 
+    def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]:
+        """
+        Return info about a range by its element name or oid
+
+        Return None if the element or its range are not found.
+        """
+        try:
+            info = self[key]
+        except KeyError:
+            return None
+        return self._by_range_subtype.get(info.oid)
+
 
 builtins = TypesRegistry()
 
@@ -113,91 +133,91 @@ builtins = TypesRegistry()
 for r in [
     # autogenerated: start
     # Generated from PostgreSQL 13.1
-    ("aclitem", 1033, 1034, "aclitem", ","),
-    ("any", 2276, 0, '"any"', ","),
-    ("anyarray", 2277, 0, "anyarray", ","),
-    ("anycompatible", 5077, 0, "anycompatible", ","),
-    ("anycompatiblearray", 5078, 0, "anycompatiblearray", ","),
-    ("anycompatiblenonarray", 5079, 0, "anycompatiblenonarray", ","),
-    ("anycompatiblerange", 5080, 0, "anycompatiblerange", ","),
-    ("anyelement", 2283, 0, "anyelement", ","),
-    ("anyenum", 3500, 0, "anyenum", ","),
-    ("anynonarray", 2776, 0, "anynonarray", ","),
-    ("anyrange", 3831, 0, "anyrange", ","),
-    ("bit", 1560, 1561, "bit", ","),
-    ("bool", 16, 1000, "boolean", ","),
-    ("box", 603, 1020, "box", ";"),
-    ("bpchar", 1042, 1014, "character", ","),
-    ("bytea", 17, 1001, "bytea", ","),
-    ("char", 18, 1002, '"char"', ","),
-    ("cid", 29, 1012, "cid", ","),
-    ("cidr", 650, 651, "cidr", ","),
-    ("circle", 718, 719, "circle", ","),
-    ("cstring", 2275, 1263, "cstring", ","),
-    ("date", 1082, 1182, "date", ","),
-    ("daterange", 3912, 3913, "daterange", ","),
-    ("event_trigger", 3838, 0, "event_trigger", ","),
-    ("float4", 700, 1021, "real", ","),
-    ("float8", 701, 1022, "double precision", ","),
-    ("gtsvector", 3642, 3644, "gtsvector", ","),
-    ("inet", 869, 1041, "inet", ","),
-    ("int2", 21, 1005, "smallint", ","),
-    ("int2vector", 22, 1006, "int2vector", ","),
-    ("int4", 23, 1007, "integer", ","),
-    ("int4range", 3904, 3905, "int4range", ","),
-    ("int8", 20, 1016, "bigint", ","),
-    ("int8range", 3926, 3927, "int8range", ","),
-    ("internal", 2281, 0, "internal", ","),
-    ("interval", 1186, 1187, "interval", ","),
-    ("json", 114, 199, "json", ","),
-    ("jsonb", 3802, 3807, "jsonb", ","),
-    ("jsonpath", 4072, 4073, "jsonpath", ","),
-    ("line", 628, 629, "line", ","),
-    ("lseg", 601, 1018, "lseg", ","),
-    ("macaddr", 829, 1040, "macaddr", ","),
-    ("macaddr8", 774, 775, "macaddr8", ","),
-    ("money", 790, 791, "money", ","),
-    ("name", 19, 1003, "name", ","),
-    ("numeric", 1700, 1231, "numeric", ","),
-    ("numrange", 3906, 3907, "numrange", ","),
-    ("oid", 26, 1028, "oid", ","),
-    ("oidvector", 30, 1013, "oidvector", ","),
-    ("path", 602, 1019, "path", ","),
-    ("point", 600, 1017, "point", ","),
-    ("polygon", 604, 1027, "polygon", ","),
-    ("record", 2249, 2287, "record", ","),
-    ("refcursor", 1790, 2201, "refcursor", ","),
-    ("regclass", 2205, 2210, "regclass", ","),
-    ("regcollation", 4191, 4192, "regcollation", ","),
-    ("regconfig", 3734, 3735, "regconfig", ","),
-    ("regdictionary", 3769, 3770, "regdictionary", ","),
-    ("regnamespace", 4089, 4090, "regnamespace", ","),
-    ("regoper", 2203, 2208, "regoper", ","),
-    ("regoperator", 2204, 2209, "regoperator", ","),
-    ("regproc", 24, 1008, "regproc", ","),
-    ("regprocedure", 2202, 2207, "regprocedure", ","),
-    ("regrole", 4096, 4097, "regrole", ","),
-    ("regtype", 2206, 2211, "regtype", ","),
-    ("text", 25, 1009, "text", ","),
-    ("tid", 27, 1010, "tid", ","),
-    ("time", 1083, 1183, "time without time zone", ","),
-    ("timestamp", 1114, 1115, "timestamp without time zone", ","),
-    ("timestamptz", 1184, 1185, "timestamp with time zone", ","),
-    ("timetz", 1266, 1270, "time with time zone", ","),
-    ("trigger", 2279, 0, "trigger", ","),
-    ("tsquery", 3615, 3645, "tsquery", ","),
-    ("tsrange", 3908, 3909, "tsrange", ","),
-    ("tstzrange", 3910, 3911, "tstzrange", ","),
-    ("tsvector", 3614, 3643, "tsvector", ","),
-    ("txid_snapshot", 2970, 2949, "txid_snapshot", ","),
-    ("unknown", 705, 0, "unknown", ","),
-    ("uuid", 2950, 2951, "uuid", ","),
-    ("varbit", 1562, 1563, "bit varying", ","),
-    ("varchar", 1043, 1015, "character varying", ","),
-    ("void", 2278, 0, "void", ","),
-    ("xid", 28, 1011, "xid", ","),
-    ("xid8", 5069, 271, "xid8", ","),
-    ("xml", 142, 143, "xml", ","),
+    ("aclitem", 1033, 1034, 0, "aclitem", ","),
+    ("any", 2276, 0, 0, '"any"', ","),
+    ("anyarray", 2277, 0, 0, "anyarray", ","),
+    ("anycompatible", 5077, 0, 0, "anycompatible", ","),
+    ("anycompatiblearray", 5078, 0, 0, "anycompatiblearray", ","),
+    ("anycompatiblenonarray", 5079, 0, 0, "anycompatiblenonarray", ","),
+    ("anycompatiblerange", 5080, 0, 0, "anycompatiblerange", ","),
+    ("anyelement", 2283, 0, 0, "anyelement", ","),
+    ("anyenum", 3500, 0, 0, "anyenum", ","),
+    ("anynonarray", 2776, 0, 0, "anynonarray", ","),
+    ("anyrange", 3831, 0, 0, "anyrange", ","),
+    ("bit", 1560, 1561, 0, "bit", ","),
+    ("bool", 16, 1000, 0, "boolean", ","),
+    ("box", 603, 1020, 0, "box", ";"),
+    ("bpchar", 1042, 1014, 0, "character", ","),
+    ("bytea", 17, 1001, 0, "bytea", ","),
+    ("char", 18, 1002, 0, '"char"', ","),
+    ("cid", 29, 1012, 0, "cid", ","),
+    ("cidr", 650, 651, 0, "cidr", ","),
+    ("circle", 718, 719, 0, "circle", ","),
+    ("cstring", 2275, 1263, 0, "cstring", ","),
+    ("date", 1082, 1182, 0, "date", ","),
+    ("daterange", 3912, 3913, 1082, "daterange", ","),
+    ("event_trigger", 3838, 0, 0, "event_trigger", ","),
+    ("float4", 700, 1021, 0, "real", ","),
+    ("float8", 701, 1022, 0, "double precision", ","),
+    ("gtsvector", 3642, 3644, 0, "gtsvector", ","),
+    ("inet", 869, 1041, 0, "inet", ","),
+    ("int2", 21, 1005, 0, "smallint", ","),
+    ("int2vector", 22, 1006, 0, "int2vector", ","),
+    ("int4", 23, 1007, 0, "integer", ","),
+    ("int4range", 3904, 3905, 23, "int4range", ","),
+    ("int8", 20, 1016, 0, "bigint", ","),
+    ("int8range", 3926, 3927, 20, "int8range", ","),
+    ("internal", 2281, 0, 0, "internal", ","),
+    ("interval", 1186, 1187, 0, "interval", ","),
+    ("json", 114, 199, 0, "json", ","),
+    ("jsonb", 3802, 3807, 0, "jsonb", ","),
+    ("jsonpath", 4072, 4073, 0, "jsonpath", ","),
+    ("line", 628, 629, 0, "line", ","),
+    ("lseg", 601, 1018, 0, "lseg", ","),
+    ("macaddr", 829, 1040, 0, "macaddr", ","),
+    ("macaddr8", 774, 775, 0, "macaddr8", ","),
+    ("money", 790, 791, 0, "money", ","),
+    ("name", 19, 1003, 0, "name", ","),
+    ("numeric", 1700, 1231, 0, "numeric", ","),
+    ("numrange", 3906, 3907, 1700, "numrange", ","),
+    ("oid", 26, 1028, 0, "oid", ","),
+    ("oidvector", 30, 1013, 0, "oidvector", ","),
+    ("path", 602, 1019, 0, "path", ","),
+    ("point", 600, 1017, 0, "point", ","),
+    ("polygon", 604, 1027, 0, "polygon", ","),
+    ("record", 2249, 2287, 0, "record", ","),
+    ("refcursor", 1790, 2201, 0, "refcursor", ","),
+    ("regclass", 2205, 2210, 0, "regclass", ","),
+    ("regcollation", 4191, 4192, 0, "regcollation", ","),
+    ("regconfig", 3734, 3735, 0, "regconfig", ","),
+    ("regdictionary", 3769, 3770, 0, "regdictionary", ","),
+    ("regnamespace", 4089, 4090, 0, "regnamespace", ","),
+    ("regoper", 2203, 2208, 0, "regoper", ","),
+    ("regoperator", 2204, 2209, 0, "regoperator", ","),
+    ("regproc", 24, 1008, 0, "regproc", ","),
+    ("regprocedure", 2202, 2207, 0, "regprocedure", ","),
+    ("regrole", 4096, 4097, 0, "regrole", ","),
+    ("regtype", 2206, 2211, 0, "regtype", ","),
+    ("text", 25, 1009, 0, "text", ","),
+    ("tid", 27, 1010, 0, "tid", ","),
+    ("time", 1083, 1183, 0, "time without time zone", ","),
+    ("timestamp", 1114, 1115, 0, "timestamp without time zone", ","),
+    ("timestamptz", 1184, 1185, 0, "timestamp with time zone", ","),
+    ("timetz", 1266, 1270, 0, "time with time zone", ","),
+    ("trigger", 2279, 0, 0, "trigger", ","),
+    ("tsquery", 3615, 3645, 0, "tsquery", ","),
+    ("tsrange", 3908, 3909, 1114, "tsrange", ","),
+    ("tstzrange", 3910, 3911, 1184, "tstzrange", ","),
+    ("tsvector", 3614, 3643, 0, "tsvector", ","),
+    ("txid_snapshot", 2970, 2949, 0, "txid_snapshot", ","),
+    ("unknown", 705, 0, 0, "unknown", ","),
+    ("uuid", 2950, 2951, 0, "uuid", ","),
+    ("varbit", 1562, 1563, 0, "bit varying", ","),
+    ("varchar", 1043, 1015, 0, "character varying", ","),
+    ("void", 2278, 0, 0, "void", ","),
+    ("xid", 28, 1011, 0, "xid", ","),
+    ("xid8", 5069, 271, 0, "xid8", ","),
+    ("xml", 142, 143, 0, "xml", ","),
     # autogenerated: end
 ]:
     builtins.add(BuiltinTypeInfo(*r))
index 69b9ef51e6e909689dbfc1833f132d5da9500eca..e89a41b5e8c8f11a6f6aa6a6ddb28cbc2e34f012 100644 (file)
@@ -14,8 +14,7 @@ from . import range
 # Wrapper objects
 from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric, Oid
 from .json import Json, Jsonb
-from .range import Range, Int4Range, Int8Range, DecimalRange
-from .range import DateRange, DateTimeRange, DateTimeTZRange
+from .range import Range
 
 # Supper objects
 from .range import RangeInfo
@@ -99,12 +98,6 @@ from .network import (
 )
 from .range import (
     RangeDumper,
-    Int4RangeDumper,
-    Int8RangeDumper,
-    NumRangeDumper,
-    DateRangeDumper,
-    TimestampRangeDumper,
-    TimestampTZRangeDumper,
     RangeLoader,
     Int4RangeLoader,
     Int8RangeLoader,
@@ -217,12 +210,7 @@ def register_default_globals(ctx: AdaptContext) -> None:
     InetLoader.register("inet", ctx)
     CidrLoader.register("cidr", ctx)
 
-    Int4RangeDumper.register(Int4Range, ctx)
-    Int8RangeDumper.register(Int8Range, ctx)
-    NumRangeDumper.register(DecimalRange, ctx)
-    DateRangeDumper.register(DateRange, ctx)
-    TimestampRangeDumper.register(DateTimeRange, ctx)
-    TimestampTZRangeDumper.register(DateTimeTZRange, ctx)
+    RangeDumper.register(Range, ctx)
     Int4RangeLoader.register("int4range", ctx)
     Int8RangeLoader.register("int8range", ctx)
     NumericRangeLoader.register("numrange", ctx)
index bdd7b4f2d09913ec60e8e6bbb1738ecdab15d2b2..1bc17bffeb2c808b644b8222353dc699bda89143 100644 (file)
@@ -6,7 +6,7 @@ Support for range types adaptation.
 
 import re
 from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Type, Union
-from typing import cast, TYPE_CHECKING
+from typing import cast, Tuple, TYPE_CHECKING
 from decimal import Decimal
 from datetime import date, datetime
 
@@ -14,7 +14,7 @@ from .. import sql
 from .. import errors as e
 from ..pq import Format
 from ..oids import builtins, TypeInfo
-from ..adapt import Buffer, Dumper, Loader
+from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format
 from ..proto import AdaptContext
 
 from . import array
@@ -217,11 +217,17 @@ class Range(Generic[T]):
 
 class RangeDumper(SequenceDumper):
     """
-    Generic dumper for a range.
+    Dumper for range types.
 
-    Subclasses shoud specify the type oid.
+    The dumper can upgrade to one specific for a different range type.
     """
 
+    format = Format.TEXT
+
+    def __init__(self, cls: type, context: Optional[AdaptContext] = None):
+        super().__init__(cls, context)
+        self.sub_dumper: Optional[Dumper] = None
+
     def dump(self, obj: Range[Any]) -> bytes:
         if not obj:
             return b"empty"
@@ -235,6 +241,66 @@ class RangeDumper(SequenceDumper):
 
     _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
 
+    def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]:
+        item = self._get_item(obj)
+        if item is not None:
+            # TODO: binary range support
+            sd = self._tx.get_dumper(item, Pg3Format.TEXT)
+            return (self.cls, sd.cls)
+        else:
+            return (self.cls,)
+
+    def upgrade(self, obj: Range[Any], format: Pg3Format) -> "RangeDumper":
+        item = self._get_item(obj)
+        if item is None:
+            return RangeDumper(self.cls)
+
+        # TODO: binary range support
+        sd = self._tx.get_dumper(item, Pg3Format.TEXT)
+        dumper = type(self)(self.cls, self._tx)
+        dumper.sub_dumper = sd
+        dumper.oid = self._get_range_oid(sd.oid)
+        return dumper
+
+    def _get_item(self, obj: Range[Any]) -> Any:
+        """
+        Return a member representative of the range
+
+        If the range is numeric return the bigger number in absolute value, to
+        decide the best type to use. Otherwise return any non-null in the
+        range. Return None for an empty range.
+        """
+        lo, up = obj._lower, obj._upper
+        if lo is None:
+            rv = up
+        elif up is None:
+            rv = lo
+        else:
+            if isinstance(lo, int):
+                rv = up if abs(up) > abs(lo) else lo
+            else:
+                rv = up
+
+        # Upgrade int2 -> int4 as there's no int2range
+        if isinstance(rv, int) and -(2 ** 15) <= rv < 2 ** 15:
+            rv = 2 ** 15
+        return rv
+
+    def _get_range_oid(self, sub_oid: int) -> int:
+        """
+        Return the oid of the range from the oid of its elements.
+
+        Raise InterfaceError if not found.
+
+        TODO: we shouldn't consider builtins only, but other adaptation
+        contexts too
+        """
+        info = builtins.get_range(sub_oid)
+        if info:
+            return info.oid
+        else:
+            raise e.InterfaceError(f"range for type {sub_oid} unknown")
+
 
 class RangeLoader(BaseCompositeLoader, Generic[T]):
     """Generic loader for a range.
@@ -243,11 +309,10 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
     """
 
     subtype_oid: int
-    cls: Type[Range[T]]
 
     def load(self, data: Buffer) -> Range[T]:
         if data == b"empty":
-            return self.cls(empty=True)
+            return Range(empty=True)
 
         cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load
         bounds = _int2parens[data[0]] + _int2parens[data[-1]]
@@ -255,97 +320,37 @@ class RangeLoader(BaseCompositeLoader, Generic[T]):
             cast(token) if token is not None else None
             for token in self._parse_record(data[1:-1])
         )
-        return self.cls(min, max, bounds)
+        return Range(min, max, bounds)
 
 
 _int2parens = {ord(c): c for c in "[]()"}
 
 
-# Python wrappers for builtin range types
-
-
-class Int4Range(Range[int]):
-    pass
-
-
-class Int8Range(Range[int]):
-    pass
-
-
-class DecimalRange(Range[Decimal]):
-    pass
-
-
-class DateRange(Range[date]):
-    pass
-
-
-class DateTimeRange(Range[datetime]):
-    pass
-
-
-class DateTimeTZRange(Range[datetime]):
-    pass
-
-
-# Dumpers for builtin range types
-
-
-class Int4RangeDumper(RangeDumper):
-    _oid = builtins["int4range"].oid
-
-
-class Int8RangeDumper(RangeDumper):
-    _oid = builtins["int8range"].oid
-
-
-class NumRangeDumper(RangeDumper):
-    _oid = builtins["numrange"].oid
-
-
-class DateRangeDumper(RangeDumper):
-    _oid = builtins["daterange"].oid
-
-
-class TimestampRangeDumper(RangeDumper):
-    _oid = builtins["tsrange"].oid
-
-
-class TimestampTZRangeDumper(RangeDumper):
-    _oid = builtins["tstzrange"].oid
-
-
 # Loaders for builtin range types
 
 
 class Int4RangeLoader(RangeLoader[int]):
     subtype_oid = builtins["int4"].oid
-    cls = Int4Range
 
 
 class Int8RangeLoader(RangeLoader[int]):
     subtype_oid = builtins["int8"].oid
-    cls = Int8Range
 
 
 class NumericRangeLoader(RangeLoader[Decimal]):
     subtype_oid = builtins["numeric"].oid
-    cls = DecimalRange
 
 
 class DateRangeLoader(RangeLoader[date]):
     subtype_oid = builtins["date"].oid
-    cls = DateRange
 
 
 class TimestampRangeLoader(RangeLoader[datetime]):
     subtype_oid = builtins["timestamp"].oid
-    cls = DateTimeRange
 
 
 class TimestampTZRangeLoader(RangeLoader[datetime]):
     subtype_oid = builtins["timestamptz"].oid
-    cls = DateTimeTZRange
 
 
 class RangeInfo(TypeInfo):
@@ -357,16 +362,6 @@ class RangeInfo(TypeInfo):
     - configure a composite type adaptation using `register()`
     """
 
-    def __init__(
-        self,
-        name: str,
-        oid: int,
-        array_oid: int,
-        subtype_oid: int,
-    ):
-        super().__init__(name, oid, array_oid)
-        self.subtype_oid = subtype_oid
-
     @classmethod
     def fetch(
         cls, conn: "Connection", name: Union[str, sql.Identifier]
@@ -392,22 +387,15 @@ class RangeInfo(TypeInfo):
     def register(
         self,
         context: Optional[AdaptContext] = None,
-        range_class: Optional[Type[Range[Any]]] = None,
     ) -> None:
-        if not range_class:
-            range_class = type(self.name.title(), (Range,), {})
-
-        # generate and register a customized text dumper
-        dumper: Type[Dumper] = type(
-            f"{self.name.title()}Dumper", (RangeDumper,), {"_oid": self.oid}
-        )
-        dumper.register(range_class, context=context)
+        # A new dumper is not required. However TODO we will need to register
+        # the dumper in the adapters type registry, when we have one.
 
         # generate and register a customized text loader
         loader: Type[Loader] = type(
             f"{self.name.title()}Loader",
             (RangeLoader,),
-            {"cls": range_class, "subtype_oid": self.subtype_oid},
+            {"subtype_oid": self.range_subtype},
         )
         loader.register(self.oid, context=context)
 
@@ -430,7 +418,7 @@ class RangeInfo(TypeInfo):
 
     _info_query = """\
 select t.typname as name, t.oid as oid, t.typarray as array_oid,
-    r.rngsubtype as subtype_oid
+    r.rngsubtype as range_subtype
 from pg_type t
 join pg_range r on t.oid = r.rngtypid
 where t.oid = %(name)s::regtype
index 003532797634bbf23835954f9a7778f7d8bc3ffb..ac601c230436cd894e3993e7ab8e82ddfcd6af63 100644 (file)
@@ -10,14 +10,6 @@ from psycopg3.types import range as mrange
 from psycopg3.types.range import Range
 
 
-type2cls = {
-    "int4range": mrange.Int4Range,
-    "int8range": mrange.Int8Range,
-    "numrange": mrange.DecimalRange,
-    "daterange": mrange.DateRange,
-    "tsrange": mrange.DateTimeRange,
-    "tstzrange": mrange.DateTimeTZRange,
-}
 type2sub = {
     "int4range": "int4",
     "int8range": "int8",
@@ -58,7 +50,7 @@ samples = [
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
 def test_dump_builtin_empty(conn, pgtype):
-    r = type2cls[pgtype](empty=True)
+    r = Range(empty=True)
     cur = conn.execute(f"select 'empty'::{pgtype} = %s", (r,))
     assert cur.fetchone()[0] is True
 
@@ -68,8 +60,8 @@ def test_dump_builtin_empty(conn, pgtype):
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
 def test_dump_builtin_array(conn, pgtype):
-    r1 = type2cls[pgtype](empty=True)
-    r2 = type2cls[pgtype](bounds="()")
+    r1 = Range(empty=True)
+    r2 = Range(bounds="()")
     cur = conn.execute(
         f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %s",
         ([r1, r2],),
@@ -79,10 +71,10 @@ def test_dump_builtin_array(conn, pgtype):
 
 @pytest.mark.parametrize("pgtype, min, max, bounds", samples)
 def test_dump_builtin_range(conn, pgtype, min, max, bounds):
-    r = type2cls[pgtype](min, max, bounds)
+    r = Range(min, max, bounds)
     sub = type2sub[pgtype]
     cur = conn.execute(
-        f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s",
+        f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s::{pgtype}",
         (min, max, bounds, r),
     )
     assert cur.fetchone()[0] is True
@@ -93,9 +85,9 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds):
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
 def test_load_builtin_empty(conn, pgtype):
-    r = type2cls[pgtype](empty=True)
+    r = Range(empty=True)
     (got,) = conn.execute(f"select 'empty'::{pgtype}").fetchone()
-    assert type(got) is type2cls[pgtype]
+    assert type(got) is Range
     assert got == r
     assert not got
     assert got.isempty
@@ -106,9 +98,9 @@ def test_load_builtin_empty(conn, pgtype):
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
 def test_load_builtin_inf(conn, pgtype):
-    r = type2cls[pgtype](bounds="()")
+    r = Range(bounds="()")
     (got,) = conn.execute(f"select '(,)'::{pgtype}").fetchone()
-    assert type(got) is type2cls[pgtype]
+    assert type(got) is Range
     assert got == r
     assert got
     assert not got.isempty
@@ -121,8 +113,8 @@ def test_load_builtin_inf(conn, pgtype):
     "int4range int8range numrange daterange tsrange tstzrange".split(),
 )
 def test_load_builtin_array(conn, pgtype):
-    r1 = type2cls[pgtype](empty=True)
-    r2 = type2cls[pgtype](bounds="()")
+    r1 = Range(empty=True)
+    r2 = Range(bounds="()")
     (got,) = conn.execute(
         f"select array['empty'::{pgtype}, '(,)'::{pgtype}]"
     ).fetchone()
@@ -131,7 +123,7 @@ def test_load_builtin_array(conn, pgtype):
 
 @pytest.mark.parametrize("pgtype, min, max, bounds", samples)
 def test_load_builtin_range(conn, pgtype, min, max, bounds):
-    r = type2cls[pgtype](min, max, bounds)
+    r = Range(min, max, bounds)
     sub = type2sub[pgtype]
     cur = conn.execute(
         f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds)
@@ -172,7 +164,7 @@ def test_fetch_info(conn, testrange, name, subtype):
     assert info.name == "testrange"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
-    assert info.subtype_oid == builtins[subtype].oid
+    assert info.range_subtype == builtins[subtype].oid
 
 
 def test_fetch_info_not_found(conn):
@@ -187,7 +179,7 @@ async def test_fetch_info_async(aconn, testrange, name, subtype):
     assert info.name == "testrange"
     assert info.oid > 0
     assert info.oid != info.array_oid > 0
-    assert info.subtype_oid == builtins[subtype].oid
+    assert info.range_subtype == builtins[subtype].oid
 
 
 @pytest.mark.asyncio
@@ -197,28 +189,22 @@ async def test_fetch_info_not_found_async(aconn):
 
 
 def test_dump_custom_empty(conn, testrange):
-    class StrRange(mrange.Range):
-        pass
-
     info = mrange.RangeInfo.fetch(conn, "testrange")
-    info.register(conn, range_class=StrRange)
+    info.register(conn)
 
-    r = StrRange(empty=True)
+    r = Range(empty=True)
     cur = conn.execute("select 'empty'::testrange = %s", (r,))
     assert cur.fetchone()[0] is True
 
 
 def test_dump_quoting(conn, testrange):
-    class StrRange(mrange.Range):
-        pass
-
     info = mrange.RangeInfo.fetch(conn, "testrange")
-    info.register(conn, range_class=StrRange)
+    info.register(conn)
     cur = conn.cursor()
     for i in range(1, 254):
         cur.execute(
             "select ascii(lower(%(r)s)) = %(low)s and ascii(upper(%(r)s)) = %(up)s",
-            {"r": StrRange(chr(i), chr(i + 1)), "low": i, "up": i + 1},
+            {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1},
         )
         assert cur.fetchone()[0] is True
 
@@ -400,16 +386,6 @@ class TestRangeObject:
     def test_eq_wrong_type(self):
         assert Range(10, 20) != ()
 
-    def test_eq_subclass(self):
-        class IntRange(mrange.DecimalRange):
-            pass
-
-        class PositiveIntRange(IntRange):
-            pass
-
-        assert Range(10, 20) == IntRange(10, 20)
-        assert PositiveIntRange(10, 20) == IntRange(10, 20)
-
     # as the postgres docs describe for the server-side stuff,
     # ordering is rather arbitrary, but will remain stable
     # and consistent.
@@ -481,13 +457,7 @@ class TestRangeObject:
     def test_str(self):
         """
         Range types should have a short and readable ``str`` implementation.
-
-        Using ``repr`` for all string conversions can be very unreadable for
-        longer types like ``DateTimeTZRange``.
         """
-
-        # Using the "u" prefix to make sure we have the proper return types in
-        # Python2
         expected = [
             "(0, 4)",
             "[0, 4]",
@@ -511,7 +481,7 @@ class TestRangeObject:
         string conversion.
         """
         tz = dt.timezone(dt.timedelta(hours=-5))
-        r = mrange.DateTimeTZRange(
+        r = mrange.Range(
             dt.datetime(2010, 1, 1, tzinfo=tz),
             dt.datetime(2011, 1, 1, tzinfo=tz),
         )
index 56171d00619cba0b3c1a88b7dd191ce5a8fa65b3..62c90d39513b43a8c8e49eb5b785c50f885228d1 100755 (executable)
@@ -24,9 +24,10 @@ $$,
 
 py_oids_sql = """
 select format(
-        '(%L, %s, %s, %L, %L),',
-        typname, oid, typarray, oid::regtype, typdelim)
-    from pg_type
+        '(%L, %s, %s, %s, %L, %L),',
+        typname, oid, typarray, coalesce(rngsubtype, 0), oid::regtype, typdelim)
+    from pg_type t
+    left join pg_range r on t.oid = rngtypid
     where oid < 10000
     and typname !~ all('{^(_|pg_),_handler$}')
     order by typname