From: Daniele Varrazzo Date: Wed, 23 Jun 2021 16:04:09 +0000 (+0100) Subject: Use type-specific range subclasses to avoid the dump upgrade mechanism X-Git-Tag: 3.0.dev0~18^2 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=f45d4e049ab95bb56b6a967e43ecce05ac7de42e;p=thirdparty%2Fpsycopg.git Use type-specific range subclasses to avoid the dump upgrade mechanism This allows binary dump of empty ranges, for which the type can be inferred. In normal dumping these are dumped in text format with unknown oid but that would break binary copy. --- diff --git a/psycopg3/psycopg3/types/__init__.py b/psycopg3/psycopg3/types/__init__.py index 25d66d7b5..b455ea12c 100644 --- a/psycopg3/psycopg3/types/__init__.py +++ b/psycopg3/psycopg3/types/__init__.py @@ -4,9 +4,10 @@ psycopg3 types package # Copyright (C) 2020-2021 The Psycopg Team -from . import array from ..oids import INVALID_OID from ..proto import AdaptContext +from .array import register_all_arrays +from . import range as _range # Wrapper objects from ..wrappers.numeric import ( @@ -135,7 +136,19 @@ from .network import ( ) from .range import ( RangeDumper as RangeDumper, - BinaryRangeDumper as BinaryRangeDumper, + Int4RangeDumper as Int4RangeDumper, + Int8RangeDumper as Int8RangeDumper, + NumericRangeDumper as NumericRangeDumper, + DateRangeDumper as DateRangeDumper, + TimestampRangeDumper as TimestampRangeDumper, + TimestamptzRangeDumper as TimestamptzRangeDumper, + RangeBinaryDumper as RangeBinaryDumper, + Int4RangeBinaryDumper as Int4RangeBinaryDumper, + Int8RangeBinaryDumper as Int8RangeBinaryDumper, + NumericRangeBinaryDumper as NumericRangeBinaryDumper, + DateRangeBinaryDumper as DateRangeBinaryDumper, + TimestampRangeBinaryDumper as TimestampRangeBinaryDumper, + TimestamptzRangeBinaryDumper as TimestamptzRangeBinaryDumper, RangeLoader as RangeLoader, Int4RangeLoader as Int4RangeLoader, Int8RangeLoader as Int8RangeLoader, @@ -283,8 +296,20 @@ def register_default_globals(ctx: AdaptContext) -> None: CidrLoader.register("cidr", ctx) CidrBinaryLoader.register("cidr", ctx) + RangeBinaryDumper.register(Range, ctx) RangeDumper.register(Range, ctx) - BinaryRangeDumper.register(Range, ctx) + Int4RangeDumper.register(_range.Int4Range, ctx) + Int8RangeDumper.register(_range.Int8Range, ctx) + NumericRangeDumper.register(_range.NumericRange, ctx) + DateRangeDumper.register(_range.DateRange, ctx) + TimestampRangeDumper.register(_range.TimestampRange, ctx) + TimestamptzRangeDumper.register(_range.TimestamptzRange, ctx) + Int4RangeBinaryDumper.register(_range.Int4Range, ctx) + Int8RangeBinaryDumper.register(_range.Int8Range, ctx) + NumericRangeBinaryDumper.register(_range.NumericRange, ctx) + DateRangeBinaryDumper.register(_range.DateRange, ctx) + TimestampRangeBinaryDumper.register(_range.TimestampRange, ctx) + TimestamptzRangeBinaryDumper.register(_range.TimestamptzRange, ctx) Int4RangeLoader.register("int4range", ctx) Int8RangeLoader.register("int8range", ctx) NumericRangeLoader.register("numrange", ctx) @@ -305,4 +330,4 @@ def register_default_globals(ctx: AdaptContext) -> None: RecordLoader.register("record", ctx) RecordBinaryLoader.register("record", ctx) - array.register_all_arrays(ctx) + register_all_arrays(ctx) diff --git a/psycopg3/psycopg3/types/range.py b/psycopg3/psycopg3/types/range.py index f2819afa7..3a618be3c 100644 --- a/psycopg3/psycopg3/types/range.py +++ b/psycopg3/psycopg3/types/range.py @@ -225,6 +225,34 @@ class Range(Generic[T]): setattr(self, slot, value) +# Subclasses to specify a specific subtype. Usually not needed: only needed +# in binary copy, where switching to text is not an option. + + +class Int4Range(Range[int]): + pass + + +class Int8Range(Range[int]): + pass + + +class NumericRange(Range[Decimal]): + pass + + +class DateRange(Range[date]): + pass + + +class TimestampRange(Range[datetime]): + pass + + +class TimestamptzRange(Range[datetime]): + pass + + class BaseRangeDumper(RecursiveDumper): def __init__(self, cls: type, context: Optional[AdaptContext] = None): super().__init__(cls, context) @@ -232,7 +260,13 @@ class BaseRangeDumper(RecursiveDumper): self._types = context.adapters.types if context else builtins self._adapt_format = Pg3Format.from_pq(self.format) - def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]: + def get_key( + self, obj: Range[Any], format: Pg3Format + ) -> Union[type, Tuple[type, ...]]: + # If we are a subclass whose oid is specified we don't need upgrade + if self.oid != INVALID_OID: + return self.cls + item = self._get_item(obj) if item is not None: sd = self._tx.get_dumper(item, self._adapt_format) @@ -241,6 +275,10 @@ class BaseRangeDumper(RecursiveDumper): return (self.cls,) def upgrade(self, obj: Range[Any], format: Pg3Format) -> "BaseRangeDumper": + # If we are a subclass whose oid is specified we don't need upgrade + if self.oid != INVALID_OID: + return self + item = self._get_item(obj) if item is None: return RangeDumper(self.cls) @@ -306,7 +344,7 @@ class RangeDumper(BaseRangeDumper, SequenceDumper): _re_needs_quotes = re.compile(br'[",\\\s()\[\]]') -class BinaryRangeDumper(BaseRangeDumper): +class RangeBinaryDumper(BaseRangeDumper): format = Format.BINARY @@ -421,6 +459,64 @@ def register_adapters( bloader.register(info.oid, context=context) +# Text dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4RangeDumper(RangeDumper): + _oid = builtins["int4range"].oid + + +class Int8RangeDumper(RangeDumper): + _oid = builtins["int8range"].oid + + +class NumericRangeDumper(RangeDumper): + _oid = builtins["numrange"].oid + + +class DateRangeDumper(RangeDumper): + _oid = builtins["daterange"].oid + + +class TimestampRangeDumper(RangeDumper): + _oid = builtins["tsrange"].oid + + +class TimestamptzRangeDumper(RangeDumper): + _oid = builtins["tstzrange"].oid + + +# Binary dumpers for builtin range types wrappers +# These are registered on specific subtypes so that the upgrade mechanism +# doesn't kick in. + + +class Int4RangeBinaryDumper(RangeBinaryDumper): + _oid = builtins["int4range"].oid + + +class Int8RangeBinaryDumper(RangeBinaryDumper): + _oid = builtins["int8range"].oid + + +class NumericRangeBinaryDumper(RangeBinaryDumper): + _oid = builtins["numrange"].oid + + +class DateRangeBinaryDumper(RangeBinaryDumper): + _oid = builtins["daterange"].oid + + +class TimestampRangeBinaryDumper(RangeBinaryDumper): + _oid = builtins["tsrange"].oid + + +class TimestamptzRangeBinaryDumper(RangeBinaryDumper): + _oid = builtins["tstzrange"].oid + + # Text loaders for builtin range types diff --git a/tests/fix_faker.py b/tests/fix_faker.py index e10b02730..7b69d4779 100644 --- a/tests/fix_faker.py +++ b/tests/fix_faker.py @@ -436,8 +436,8 @@ class Faker: def make_Range(self, spec): # TODO: drop format check after fixing binary dumping of empty ranges - if random() < 0.02 and self.format == Format.TEXT: - return Range(empty=True) + if random() < 0.02 and spec[0] is Range and self.format == Format.TEXT: + return spec[0](empty=True) while True: bounds = [] @@ -459,23 +459,60 @@ class Faker: # avoid generating ranges with no type info if dumping in binary # TODO: lift this limitation after test_copy_in_empty xfail is fixed - if self.format == Format.BINARY: + if spec[0] is Range and self.format == Format.BINARY: if bounds[0] is bounds[1] is None: continue break - r = Range(bounds[0], bounds[1], choice("[(") + choice("])")) + r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])")) return r + def make_Int4Range(self, spec): + return self.make_Range((spec, Int4)) + + def make_Int8Range(self, spec): + return self.make_Range((spec, Int8)) + + def make_NumericRange(self, spec): + return self.make_Range((spec, Decimal)) + + def make_DateRange(self, spec): + return self.make_Range((spec, dt.date)) + + def make_TimestampRange(self, spec): + return self.make_Range((spec, (dt.datetime, False))) + + def make_TimestamptzRange(self, spec): + return self.make_Range((spec, (dt.datetime, True))) + def match_Range(self, spec, got, want): # normalise the bounds of unbounded ranges if want.lower is None and want.lower_inc: want = type(want)(want.lower, want.upper, "(" + want.bounds[1]) if want.upper is None and want.upper_inc: want = type(want)(want.lower, want.upper, want.bounds[0] + ")") + return got == want + def match_Int4Range(self, spec, got, want): + return self.match_Range((spec, Int4), got, want) + + def match_Int8Range(self, spec, got, want): + return self.match_Range((spec, Int8), got, want) + + def match_NumericRange(self, spec, got, want): + return self.match_Range((spec, Decimal), got, want) + + def match_DateRange(self, spec, got, want): + return self.match_Range((spec, dt.date), got, want) + + def match_TimestampRange(self, spec, got, want): + return self.match_Range((spec, (dt.datetime, False)), got, want) + + def match_TimestamptzRange(self, spec, got, want): + return self.match_Range((spec, (dt.datetime, True)), got, want) + def make_str(self, spec, length=0): if not length: length = randrange(self.str_max_length) diff --git a/tests/types/test_range.py b/tests/types/test_range.py index 884867a1e..fe6c940ca 100644 --- a/tests/types/test_range.py +++ b/tests/types/test_range.py @@ -188,6 +188,31 @@ def test_copy_in_empty(conn, min, max, bounds, format): assert rec[0] == r +@pytest.mark.parametrize("bounds", "() empty".split()) +@pytest.mark.parametrize( + "wrapper", + """Int4Range Int8Range NumericRange + DateRange TimestampRange TimestamptzRange""".split(), +) +@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY]) +def test_copy_in_empty_wrappers(conn, bounds, wrapper, format): + from psycopg3.types import range as range_module + + cur = conn.cursor() + cur.execute("create table copyrange (id serial primary key, r daterange)") + + cls = getattr(range_module, wrapper) + r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds) + + with cur.copy( + f"copy copyrange (r) from stdin (format {format.name})" + ) as copy: + copy.write_row([r]) + + rec = cur.execute("select r from copyrange order by id").fetchone() + assert rec[0] == r + + @pytest.fixture(scope="session") def testrange(svcconn): svcconn.execute(