From: Daniele Varrazzo Date: Thu, 21 Jan 2021 03:05:33 +0000 (+0100) Subject: Have a single Range class, not one per subtype X-Git-Tag: 3.0.dev0~133 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=113ad4503e2cb760f6acc5d444a4f235e25bfeca;p=thirdparty%2Fpsycopg.git Have a single Range class, not one per subtype There are two problems here, so this doesn't work. One, which will be solved, is to have a registry of types attached to an adaptation context: a TODO in already several places, so this will be solved. Before that we don't really have a way to find back e.g. the oid of a custom range on strings starting from Range('a', 'b'). The second problem, more serious, is that Postgres doesn't cast int4range <-> int8range and tsrange <-> tstzrange. The latter pair can be solved with a two steps dumper choosing between tz aware and not. The first I don't have in mind how to solve it: Given Range(1,2) I wouldn't know if int4 or int8 should be used, and Postgres doesn't seem very forgiving. Probably we should go unknown. --- diff --git a/psycopg3/psycopg3/oids.py b/psycopg3/psycopg3/oids.py index a30b52f72..daedcf312 100644 --- a/psycopg3/psycopg3/oids.py +++ b/psycopg3/psycopg3/oids.py @@ -8,10 +8,13 @@ from typing import Dict, Iterator, Optional, Union class TypeInfo: - def __init__(self, name: str, oid: int, array_oid: int): + def __init__( + self, name: str, oid: int, array_oid: int, range_subtype: int = 0 + ): self.name = name self.oid = oid self.array_oid = array_oid + self.range_subtype = range_subtype def __repr__(self) -> str: return ( @@ -26,10 +29,11 @@ class BuiltinTypeInfo(TypeInfo): name: str, oid: int, array_oid: int, + range_subtype: int, alt_name: str, delimiter: str, ): - super().__init__(name, oid, array_oid) + super().__init__(name, oid, array_oid, range_subtype) self.alt_name = alt_name self.delimiter = delimiter @@ -42,12 +46,16 @@ class TypesRegistry: def __init__(self) -> None: self._by_oid: Dict[int, TypeInfo] = {} self._by_name: Dict[str, TypeInfo] = {} + self._by_range_subtype: Dict[int, TypeInfo] = {} def add(self, info: TypeInfo) -> None: self._by_oid[info.oid] = info if info.array_oid: self._by_oid[info.array_oid] = info self._by_name[info.name] = info + if info.range_subtype: + # Note: actually not unique (e.g. range of different collation?) + self._by_range_subtype[info.range_subtype] = info if isinstance(info, BuiltinTypeInfo): if info.alt_name not in self._by_name: @@ -106,6 +114,18 @@ class TypesRegistry: else: return t.oid + def get_range(self, key: Union[str, int]) -> Optional[TypeInfo]: + """ + Return info about a range by its element name or oid + + Return None if the element or its range are not found. + """ + try: + info = self[key] + except KeyError: + return None + return self._by_range_subtype.get(info.oid) + builtins = TypesRegistry() @@ -113,91 +133,91 @@ builtins = TypesRegistry() for r in [ # autogenerated: start # Generated from PostgreSQL 13.1 - ("aclitem", 1033, 1034, "aclitem", ","), - ("any", 2276, 0, '"any"', ","), - ("anyarray", 2277, 0, "anyarray", ","), - ("anycompatible", 5077, 0, "anycompatible", ","), - ("anycompatiblearray", 5078, 0, "anycompatiblearray", ","), - ("anycompatiblenonarray", 5079, 0, "anycompatiblenonarray", ","), - ("anycompatiblerange", 5080, 0, "anycompatiblerange", ","), - ("anyelement", 2283, 0, "anyelement", ","), - ("anyenum", 3500, 0, "anyenum", ","), - ("anynonarray", 2776, 0, "anynonarray", ","), - ("anyrange", 3831, 0, "anyrange", ","), - ("bit", 1560, 1561, "bit", ","), - ("bool", 16, 1000, "boolean", ","), - ("box", 603, 1020, "box", ";"), - ("bpchar", 1042, 1014, "character", ","), - ("bytea", 17, 1001, "bytea", ","), - ("char", 18, 1002, '"char"', ","), - ("cid", 29, 1012, "cid", ","), - ("cidr", 650, 651, "cidr", ","), - ("circle", 718, 719, "circle", ","), - ("cstring", 2275, 1263, "cstring", ","), - ("date", 1082, 1182, "date", ","), - ("daterange", 3912, 3913, "daterange", ","), - ("event_trigger", 3838, 0, "event_trigger", ","), - ("float4", 700, 1021, "real", ","), - ("float8", 701, 1022, "double precision", ","), - ("gtsvector", 3642, 3644, "gtsvector", ","), - ("inet", 869, 1041, "inet", ","), - ("int2", 21, 1005, "smallint", ","), - ("int2vector", 22, 1006, "int2vector", ","), - ("int4", 23, 1007, "integer", ","), - ("int4range", 3904, 3905, "int4range", ","), - ("int8", 20, 1016, "bigint", ","), - ("int8range", 3926, 3927, "int8range", ","), - ("internal", 2281, 0, "internal", ","), - ("interval", 1186, 1187, "interval", ","), - ("json", 114, 199, "json", ","), - ("jsonb", 3802, 3807, "jsonb", ","), - ("jsonpath", 4072, 4073, "jsonpath", ","), - ("line", 628, 629, "line", ","), - ("lseg", 601, 1018, "lseg", ","), - ("macaddr", 829, 1040, "macaddr", ","), - ("macaddr8", 774, 775, "macaddr8", ","), - ("money", 790, 791, "money", ","), - ("name", 19, 1003, "name", ","), - ("numeric", 1700, 1231, "numeric", ","), - ("numrange", 3906, 3907, "numrange", ","), - ("oid", 26, 1028, "oid", ","), - ("oidvector", 30, 1013, "oidvector", ","), - ("path", 602, 1019, "path", ","), - ("point", 600, 1017, "point", ","), - ("polygon", 604, 1027, "polygon", ","), - ("record", 2249, 2287, "record", ","), - ("refcursor", 1790, 2201, "refcursor", ","), - ("regclass", 2205, 2210, "regclass", ","), - ("regcollation", 4191, 4192, "regcollation", ","), - ("regconfig", 3734, 3735, "regconfig", ","), - ("regdictionary", 3769, 3770, "regdictionary", ","), - ("regnamespace", 4089, 4090, "regnamespace", ","), - ("regoper", 2203, 2208, "regoper", ","), - ("regoperator", 2204, 2209, "regoperator", ","), - ("regproc", 24, 1008, "regproc", ","), - ("regprocedure", 2202, 2207, "regprocedure", ","), - ("regrole", 4096, 4097, "regrole", ","), - ("regtype", 2206, 2211, "regtype", ","), - ("text", 25, 1009, "text", ","), - ("tid", 27, 1010, "tid", ","), - ("time", 1083, 1183, "time without time zone", ","), - ("timestamp", 1114, 1115, "timestamp without time zone", ","), - ("timestamptz", 1184, 1185, "timestamp with time zone", ","), - ("timetz", 1266, 1270, "time with time zone", ","), - ("trigger", 2279, 0, "trigger", ","), - ("tsquery", 3615, 3645, "tsquery", ","), - ("tsrange", 3908, 3909, "tsrange", ","), - ("tstzrange", 3910, 3911, "tstzrange", ","), - ("tsvector", 3614, 3643, "tsvector", ","), - ("txid_snapshot", 2970, 2949, "txid_snapshot", ","), - ("unknown", 705, 0, "unknown", ","), - ("uuid", 2950, 2951, "uuid", ","), - ("varbit", 1562, 1563, "bit varying", ","), - ("varchar", 1043, 1015, "character varying", ","), - ("void", 2278, 0, "void", ","), - ("xid", 28, 1011, "xid", ","), - ("xid8", 5069, 271, "xid8", ","), - ("xml", 142, 143, "xml", ","), + ("aclitem", 1033, 1034, 0, "aclitem", ","), + ("any", 2276, 0, 0, '"any"', ","), + ("anyarray", 2277, 0, 0, "anyarray", ","), + ("anycompatible", 5077, 0, 0, "anycompatible", ","), + ("anycompatiblearray", 5078, 0, 0, "anycompatiblearray", ","), + ("anycompatiblenonarray", 5079, 0, 0, "anycompatiblenonarray", ","), + ("anycompatiblerange", 5080, 0, 0, "anycompatiblerange", ","), + ("anyelement", 2283, 0, 0, "anyelement", ","), + ("anyenum", 3500, 0, 0, "anyenum", ","), + ("anynonarray", 2776, 0, 0, "anynonarray", ","), + ("anyrange", 3831, 0, 0, "anyrange", ","), + ("bit", 1560, 1561, 0, "bit", ","), + ("bool", 16, 1000, 0, "boolean", ","), + ("box", 603, 1020, 0, "box", ";"), + ("bpchar", 1042, 1014, 0, "character", ","), + ("bytea", 17, 1001, 0, "bytea", ","), + ("char", 18, 1002, 0, '"char"', ","), + ("cid", 29, 1012, 0, "cid", ","), + ("cidr", 650, 651, 0, "cidr", ","), + ("circle", 718, 719, 0, "circle", ","), + ("cstring", 2275, 1263, 0, "cstring", ","), + ("date", 1082, 1182, 0, "date", ","), + ("daterange", 3912, 3913, 1082, "daterange", ","), + ("event_trigger", 3838, 0, 0, "event_trigger", ","), + ("float4", 700, 1021, 0, "real", ","), + ("float8", 701, 1022, 0, "double precision", ","), + ("gtsvector", 3642, 3644, 0, "gtsvector", ","), + ("inet", 869, 1041, 0, "inet", ","), + ("int2", 21, 1005, 0, "smallint", ","), + ("int2vector", 22, 1006, 0, "int2vector", ","), + ("int4", 23, 1007, 0, "integer", ","), + ("int4range", 3904, 3905, 23, "int4range", ","), + ("int8", 20, 1016, 0, "bigint", ","), + ("int8range", 3926, 3927, 20, "int8range", ","), + ("internal", 2281, 0, 0, "internal", ","), + ("interval", 1186, 1187, 0, "interval", ","), + ("json", 114, 199, 0, "json", ","), + ("jsonb", 3802, 3807, 0, "jsonb", ","), + ("jsonpath", 4072, 4073, 0, "jsonpath", ","), + ("line", 628, 629, 0, "line", ","), + ("lseg", 601, 1018, 0, "lseg", ","), + ("macaddr", 829, 1040, 0, "macaddr", ","), + ("macaddr8", 774, 775, 0, "macaddr8", ","), + ("money", 790, 791, 0, "money", ","), + ("name", 19, 1003, 0, "name", ","), + ("numeric", 1700, 1231, 0, "numeric", ","), + ("numrange", 3906, 3907, 1700, "numrange", ","), + ("oid", 26, 1028, 0, "oid", ","), + ("oidvector", 30, 1013, 0, "oidvector", ","), + ("path", 602, 1019, 0, "path", ","), + ("point", 600, 1017, 0, "point", ","), + ("polygon", 604, 1027, 0, "polygon", ","), + ("record", 2249, 2287, 0, "record", ","), + ("refcursor", 1790, 2201, 0, "refcursor", ","), + ("regclass", 2205, 2210, 0, "regclass", ","), + ("regcollation", 4191, 4192, 0, "regcollation", ","), + ("regconfig", 3734, 3735, 0, "regconfig", ","), + ("regdictionary", 3769, 3770, 0, "regdictionary", ","), + ("regnamespace", 4089, 4090, 0, "regnamespace", ","), + ("regoper", 2203, 2208, 0, "regoper", ","), + ("regoperator", 2204, 2209, 0, "regoperator", ","), + ("regproc", 24, 1008, 0, "regproc", ","), + ("regprocedure", 2202, 2207, 0, "regprocedure", ","), + ("regrole", 4096, 4097, 0, "regrole", ","), + ("regtype", 2206, 2211, 0, "regtype", ","), + ("text", 25, 1009, 0, "text", ","), + ("tid", 27, 1010, 0, "tid", ","), + ("time", 1083, 1183, 0, "time without time zone", ","), + ("timestamp", 1114, 1115, 0, "timestamp without time zone", ","), + ("timestamptz", 1184, 1185, 0, "timestamp with time zone", ","), + ("timetz", 1266, 1270, 0, "time with time zone", ","), + ("trigger", 2279, 0, 0, "trigger", ","), + ("tsquery", 3615, 3645, 0, "tsquery", ","), + ("tsrange", 3908, 3909, 1114, "tsrange", ","), + ("tstzrange", 3910, 3911, 1184, "tstzrange", ","), + ("tsvector", 3614, 3643, 0, "tsvector", ","), + ("txid_snapshot", 2970, 2949, 0, "txid_snapshot", ","), + ("unknown", 705, 0, 0, "unknown", ","), + ("uuid", 2950, 2951, 0, "uuid", ","), + ("varbit", 1562, 1563, 0, "bit varying", ","), + ("varchar", 1043, 1015, 0, "character varying", ","), + ("void", 2278, 0, 0, "void", ","), + ("xid", 28, 1011, 0, "xid", ","), + ("xid8", 5069, 271, 0, "xid8", ","), + ("xml", 142, 143, 0, "xml", ","), # autogenerated: end ]: builtins.add(BuiltinTypeInfo(*r)) diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index 69b9ef51e..e89a41b5e 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -14,8 +14,7 @@ from . import range # Wrapper objects from ..wrappers.numeric import Int2, Int4, Int8, IntNumeric, Oid from .json import Json, Jsonb -from .range import Range, Int4Range, Int8Range, DecimalRange -from .range import DateRange, DateTimeRange, DateTimeTZRange +from .range import Range # Supper objects from .range import RangeInfo @@ -99,12 +98,6 @@ from .network import ( ) from .range import ( RangeDumper, - Int4RangeDumper, - Int8RangeDumper, - NumRangeDumper, - DateRangeDumper, - TimestampRangeDumper, - TimestampTZRangeDumper, RangeLoader, Int4RangeLoader, Int8RangeLoader, @@ -217,12 +210,7 @@ def register_default_globals(ctx: AdaptContext) -> None: InetLoader.register("inet", ctx) CidrLoader.register("cidr", ctx) - Int4RangeDumper.register(Int4Range, ctx) - Int8RangeDumper.register(Int8Range, ctx) - NumRangeDumper.register(DecimalRange, ctx) - DateRangeDumper.register(DateRange, ctx) - TimestampRangeDumper.register(DateTimeRange, ctx) - TimestampTZRangeDumper.register(DateTimeTZRange, ctx) + RangeDumper.register(Range, ctx) Int4RangeLoader.register("int4range", ctx) Int8RangeLoader.register("int8range", ctx) NumericRangeLoader.register("numrange", ctx) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index bdd7b4f2d..1bc17bffe 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -6,7 +6,7 @@ Support for range types adaptation. import re from typing import Any, Dict, Generic, Optional, Sequence, TypeVar, Type, Union -from typing import cast, TYPE_CHECKING +from typing import cast, Tuple, TYPE_CHECKING from decimal import Decimal from datetime import date, datetime @@ -14,7 +14,7 @@ from .. import sql from .. import errors as e from ..pq import Format from ..oids import builtins, TypeInfo -from ..adapt import Buffer, Dumper, Loader +from ..adapt import Buffer, Dumper, Loader, Format as Pg3Format from ..proto import AdaptContext from . import array @@ -217,11 +217,17 @@ class Range(Generic[T]): class RangeDumper(SequenceDumper): """ - Generic dumper for a range. + Dumper for range types. - Subclasses shoud specify the type oid. + The dumper can upgrade to one specific for a different range type. """ + format = Format.TEXT + + def __init__(self, cls: type, context: Optional[AdaptContext] = None): + super().__init__(cls, context) + self.sub_dumper: Optional[Dumper] = None + def dump(self, obj: Range[Any]) -> bytes: if not obj: return b"empty" @@ -235,6 +241,66 @@ class RangeDumper(SequenceDumper): _re_needs_quotes = re.compile(br'[",\\\s()\[\]]') + def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]: + item = self._get_item(obj) + if item is not None: + # TODO: binary range support + sd = self._tx.get_dumper(item, Pg3Format.TEXT) + return (self.cls, sd.cls) + else: + return (self.cls,) + + def upgrade(self, obj: Range[Any], format: Pg3Format) -> "RangeDumper": + item = self._get_item(obj) + if item is None: + return RangeDumper(self.cls) + + # TODO: binary range support + sd = self._tx.get_dumper(item, Pg3Format.TEXT) + dumper = type(self)(self.cls, self._tx) + dumper.sub_dumper = sd + dumper.oid = self._get_range_oid(sd.oid) + return dumper + + def _get_item(self, obj: Range[Any]) -> Any: + """ + Return a member representative of the range + + If the range is numeric return the bigger number in absolute value, to + decide the best type to use. Otherwise return any non-null in the + range. Return None for an empty range. + """ + lo, up = obj._lower, obj._upper + if lo is None: + rv = up + elif up is None: + rv = lo + else: + if isinstance(lo, int): + rv = up if abs(up) > abs(lo) else lo + else: + rv = up + + # Upgrade int2 -> int4 as there's no int2range + if isinstance(rv, int) and -(2 ** 15) <= rv < 2 ** 15: + rv = 2 ** 15 + return rv + + def _get_range_oid(self, sub_oid: int) -> int: + """ + Return the oid of the range from the oid of its elements. + + Raise InterfaceError if not found. + + TODO: we shouldn't consider builtins only, but other adaptation + contexts too + """ + info = builtins.get_range(sub_oid) + if info: + return info.oid + else: + raise e.InterfaceError(f"range for type {sub_oid} unknown") + class RangeLoader(BaseCompositeLoader, Generic[T]): """Generic loader for a range. @@ -243,11 +309,10 @@ class RangeLoader(BaseCompositeLoader, Generic[T]): """ subtype_oid: int - cls: Type[Range[T]] def load(self, data: Buffer) -> Range[T]: if data == b"empty": - return self.cls(empty=True) + return Range(empty=True) cast = self._tx.get_loader(self.subtype_oid, format=Format.TEXT).load bounds = _int2parens[data[0]] + _int2parens[data[-1]] @@ -255,97 +320,37 @@ class RangeLoader(BaseCompositeLoader, Generic[T]): cast(token) if token is not None else None for token in self._parse_record(data[1:-1]) ) - return self.cls(min, max, bounds) + return Range(min, max, bounds) _int2parens = {ord(c): c for c in "[]()"} -# Python wrappers for builtin range types - - -class Int4Range(Range[int]): - pass - - -class Int8Range(Range[int]): - pass - - -class DecimalRange(Range[Decimal]): - pass - - -class DateRange(Range[date]): - pass - - -class DateTimeRange(Range[datetime]): - pass - - -class DateTimeTZRange(Range[datetime]): - pass - - -# Dumpers for builtin range types - - -class Int4RangeDumper(RangeDumper): - _oid = builtins["int4range"].oid - - -class Int8RangeDumper(RangeDumper): - _oid = builtins["int8range"].oid - - -class NumRangeDumper(RangeDumper): - _oid = builtins["numrange"].oid - - -class DateRangeDumper(RangeDumper): - _oid = builtins["daterange"].oid - - -class TimestampRangeDumper(RangeDumper): - _oid = builtins["tsrange"].oid - - -class TimestampTZRangeDumper(RangeDumper): - _oid = builtins["tstzrange"].oid - - # Loaders for builtin range types class Int4RangeLoader(RangeLoader[int]): subtype_oid = builtins["int4"].oid - cls = Int4Range class Int8RangeLoader(RangeLoader[int]): subtype_oid = builtins["int8"].oid - cls = Int8Range class NumericRangeLoader(RangeLoader[Decimal]): subtype_oid = builtins["numeric"].oid - cls = DecimalRange class DateRangeLoader(RangeLoader[date]): subtype_oid = builtins["date"].oid - cls = DateRange class TimestampRangeLoader(RangeLoader[datetime]): subtype_oid = builtins["timestamp"].oid - cls = DateTimeRange class TimestampTZRangeLoader(RangeLoader[datetime]): subtype_oid = builtins["timestamptz"].oid - cls = DateTimeTZRange class RangeInfo(TypeInfo): @@ -357,16 +362,6 @@ class RangeInfo(TypeInfo): - configure a composite type adaptation using `register()` """ - def __init__( - self, - name: str, - oid: int, - array_oid: int, - subtype_oid: int, - ): - super().__init__(name, oid, array_oid) - self.subtype_oid = subtype_oid - @classmethod def fetch( cls, conn: "Connection", name: Union[str, sql.Identifier] @@ -392,22 +387,15 @@ class RangeInfo(TypeInfo): def register( self, context: Optional[AdaptContext] = None, - range_class: Optional[Type[Range[Any]]] = None, ) -> None: - if not range_class: - range_class = type(self.name.title(), (Range,), {}) - - # generate and register a customized text dumper - dumper: Type[Dumper] = type( - f"{self.name.title()}Dumper", (RangeDumper,), {"_oid": self.oid} - ) - dumper.register(range_class, context=context) + # A new dumper is not required. However TODO we will need to register + # the dumper in the adapters type registry, when we have one. # generate and register a customized text loader loader: Type[Loader] = type( f"{self.name.title()}Loader", (RangeLoader,), - {"cls": range_class, "subtype_oid": self.subtype_oid}, + {"subtype_oid": self.range_subtype}, ) loader.register(self.oid, context=context) @@ -430,7 +418,7 @@ class RangeInfo(TypeInfo): _info_query = """\ select t.typname as name, t.oid as oid, t.typarray as array_oid, - r.rngsubtype as subtype_oid + r.rngsubtype as range_subtype from pg_type t join pg_range r on t.oid = r.rngtypid where t.oid = %(name)s::regtype diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 003532797..ac601c230 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -10,14 +10,6 @@ from psycopg3.types import range as mrange from psycopg3.types.range import Range -type2cls = { - "int4range": mrange.Int4Range, - "int8range": mrange.Int8Range, - "numrange": mrange.DecimalRange, - "daterange": mrange.DateRange, - "tsrange": mrange.DateTimeRange, - "tstzrange": mrange.DateTimeTZRange, -} type2sub = { "int4range": "int4", "int8range": "int8", @@ -58,7 +50,7 @@ samples = [ "int4range int8range numrange daterange tsrange tstzrange".split(), ) def test_dump_builtin_empty(conn, pgtype): - r = type2cls[pgtype](empty=True) + r = Range(empty=True) cur = conn.execute(f"select 'empty'::{pgtype} = %s", (r,)) assert cur.fetchone()[0] is True @@ -68,8 +60,8 @@ def test_dump_builtin_empty(conn, pgtype): "int4range int8range numrange daterange tsrange tstzrange".split(), ) def test_dump_builtin_array(conn, pgtype): - r1 = type2cls[pgtype](empty=True) - r2 = type2cls[pgtype](bounds="()") + r1 = Range(empty=True) + r2 = Range(bounds="()") cur = conn.execute( f"select array['empty'::{pgtype}, '(,)'::{pgtype}] = %s", ([r1, r2],), @@ -79,10 +71,10 @@ def test_dump_builtin_array(conn, pgtype): @pytest.mark.parametrize("pgtype, min, max, bounds", samples) def test_dump_builtin_range(conn, pgtype, min, max, bounds): - r = type2cls[pgtype](min, max, bounds) + r = Range(min, max, bounds) sub = type2sub[pgtype] cur = conn.execute( - f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s", + f"select {pgtype}(%s::{sub}, %s::{sub}, %s) = %s::{pgtype}", (min, max, bounds, r), ) assert cur.fetchone()[0] is True @@ -93,9 +85,9 @@ def test_dump_builtin_range(conn, pgtype, min, max, bounds): "int4range int8range numrange daterange tsrange tstzrange".split(), ) def test_load_builtin_empty(conn, pgtype): - r = type2cls[pgtype](empty=True) + r = Range(empty=True) (got,) = conn.execute(f"select 'empty'::{pgtype}").fetchone() - assert type(got) is type2cls[pgtype] + assert type(got) is Range assert got == r assert not got assert got.isempty @@ -106,9 +98,9 @@ def test_load_builtin_empty(conn, pgtype): "int4range int8range numrange daterange tsrange tstzrange".split(), ) def test_load_builtin_inf(conn, pgtype): - r = type2cls[pgtype](bounds="()") + r = Range(bounds="()") (got,) = conn.execute(f"select '(,)'::{pgtype}").fetchone() - assert type(got) is type2cls[pgtype] + assert type(got) is Range assert got == r assert got assert not got.isempty @@ -121,8 +113,8 @@ def test_load_builtin_inf(conn, pgtype): "int4range int8range numrange daterange tsrange tstzrange".split(), ) def test_load_builtin_array(conn, pgtype): - r1 = type2cls[pgtype](empty=True) - r2 = type2cls[pgtype](bounds="()") + r1 = Range(empty=True) + r2 = Range(bounds="()") (got,) = conn.execute( f"select array['empty'::{pgtype}, '(,)'::{pgtype}]" ).fetchone() @@ -131,7 +123,7 @@ def test_load_builtin_array(conn, pgtype): @pytest.mark.parametrize("pgtype, min, max, bounds", samples) def test_load_builtin_range(conn, pgtype, min, max, bounds): - r = type2cls[pgtype](min, max, bounds) + r = Range(min, max, bounds) sub = type2sub[pgtype] cur = conn.execute( f"select {pgtype}(%s::{sub}, %s::{sub}, %s)", (min, max, bounds) @@ -172,7 +164,7 @@ def test_fetch_info(conn, testrange, name, subtype): assert info.name == "testrange" assert info.oid > 0 assert info.oid != info.array_oid > 0 - assert info.subtype_oid == builtins[subtype].oid + assert info.range_subtype == builtins[subtype].oid def test_fetch_info_not_found(conn): @@ -187,7 +179,7 @@ async def test_fetch_info_async(aconn, testrange, name, subtype): assert info.name == "testrange" assert info.oid > 0 assert info.oid != info.array_oid > 0 - assert info.subtype_oid == builtins[subtype].oid + assert info.range_subtype == builtins[subtype].oid @pytest.mark.asyncio @@ -197,28 +189,22 @@ async def test_fetch_info_not_found_async(aconn): def test_dump_custom_empty(conn, testrange): - class StrRange(mrange.Range): - pass - info = mrange.RangeInfo.fetch(conn, "testrange") - info.register(conn, range_class=StrRange) + info.register(conn) - r = StrRange(empty=True) + r = Range(empty=True) cur = conn.execute("select 'empty'::testrange = %s", (r,)) assert cur.fetchone()[0] is True def test_dump_quoting(conn, testrange): - class StrRange(mrange.Range): - pass - info = mrange.RangeInfo.fetch(conn, "testrange") - info.register(conn, range_class=StrRange) + info.register(conn) cur = conn.cursor() for i in range(1, 254): cur.execute( "select ascii(lower(%(r)s)) = %(low)s and ascii(upper(%(r)s)) = %(up)s", - {"r": StrRange(chr(i), chr(i + 1)), "low": i, "up": i + 1}, + {"r": Range(chr(i), chr(i + 1)), "low": i, "up": i + 1}, ) assert cur.fetchone()[0] is True @@ -400,16 +386,6 @@ class TestRangeObject: def test_eq_wrong_type(self): assert Range(10, 20) != () - def test_eq_subclass(self): - class IntRange(mrange.DecimalRange): - pass - - class PositiveIntRange(IntRange): - pass - - assert Range(10, 20) == IntRange(10, 20) - assert PositiveIntRange(10, 20) == IntRange(10, 20) - # as the postgres docs describe for the server-side stuff, # ordering is rather arbitrary, but will remain stable # and consistent. @@ -481,13 +457,7 @@ class TestRangeObject: def test_str(self): """ Range types should have a short and readable ``str`` implementation. - - Using ``repr`` for all string conversions can be very unreadable for - longer types like ``DateTimeTZRange``. """ - - # Using the "u" prefix to make sure we have the proper return types in - # Python2 expected = [ "(0, 4)", "[0, 4]", @@ -511,7 +481,7 @@ class TestRangeObject: string conversion. """ tz = dt.timezone(dt.timedelta(hours=-5)) - r = mrange.DateTimeTZRange( + r = mrange.Range( dt.datetime(2010, 1, 1, tzinfo=tz), dt.datetime(2011, 1, 1, tzinfo=tz), ) diff --git a/tools/update_oids.py b/tools/update_oids.py index 56171d006..62c90d395 100755 --- a/tools/update_oids.py +++ b/tools/update_oids.py @@ -24,9 +24,10 @@ $$, py_oids_sql = """ select format( - '(%L, %s, %s, %L, %L),', - typname, oid, typarray, oid::regtype, typdelim) - from pg_type + '(%L, %s, %s, %s, %L, %L),', + typname, oid, typarray, coalesce(rngsubtype, 0), oid::regtype, typdelim) + from pg_type t + left join pg_range r on t.oid = rngtypid where oid < 10000 and typname !~ all('{^(_|pg_),_handler$}') order by typname