git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Use type-specific range subclasses to avoid the dump upgrade mechanism
author Daniele Varrazzo <daniele.varrazzo@gmail.com>
Wed, 23 Jun 2021 16:04:09 +0000 (17:04 +0100)
committer Daniele Varrazzo <daniele.varrazzo@gmail.com>
Fri, 25 Jun 2021 15:16:26 +0000 (16:16 +0100)
This allows binary dump of empty ranges, whose type cannot be inferred from
the bounds but is now given by the subclass. In normal dumping these are
dumped in text format with unknown oid, but that would break binary copy.

psycopg3/psycopg3/types/__init__.py
psycopg3/psycopg3/types/range.py
tests/fix_faker.py
tests/types/test_range.py

index 25d66d7b50032612e032d4e70e693d025c9322cb..b455ea12c917cbea767141dfff30320e25b293ff 100644 (file)
@@ -4,9 +4,10 @@ psycopg3 types package
 
 # Copyright (C) 2020-2021 The Psycopg Team
 
-from . import array
 from ..oids import INVALID_OID
 from ..proto import AdaptContext
+from .array import register_all_arrays
+from . import range as _range
 
 # Wrapper objects
 from ..wrappers.numeric import (
@@ -135,7 +136,19 @@ from .network import (
 )
 from .range import (
     RangeDumper as RangeDumper,
-    BinaryRangeDumper as BinaryRangeDumper,
+    Int4RangeDumper as Int4RangeDumper,
+    Int8RangeDumper as Int8RangeDumper,
+    NumericRangeDumper as NumericRangeDumper,
+    DateRangeDumper as DateRangeDumper,
+    TimestampRangeDumper as TimestampRangeDumper,
+    TimestamptzRangeDumper as TimestamptzRangeDumper,
+    RangeBinaryDumper as RangeBinaryDumper,
+    Int4RangeBinaryDumper as Int4RangeBinaryDumper,
+    Int8RangeBinaryDumper as Int8RangeBinaryDumper,
+    NumericRangeBinaryDumper as NumericRangeBinaryDumper,
+    DateRangeBinaryDumper as DateRangeBinaryDumper,
+    TimestampRangeBinaryDumper as TimestampRangeBinaryDumper,
+    TimestamptzRangeBinaryDumper as TimestamptzRangeBinaryDumper,
     RangeLoader as RangeLoader,
     Int4RangeLoader as Int4RangeLoader,
     Int8RangeLoader as Int8RangeLoader,
@@ -283,8 +296,20 @@ def register_default_globals(ctx: AdaptContext) -> None:
     CidrLoader.register("cidr", ctx)
     CidrBinaryLoader.register("cidr", ctx)
 
+    RangeBinaryDumper.register(Range, ctx)
     RangeDumper.register(Range, ctx)
-    BinaryRangeDumper.register(Range, ctx)
+    Int4RangeDumper.register(_range.Int4Range, ctx)
+    Int8RangeDumper.register(_range.Int8Range, ctx)
+    NumericRangeDumper.register(_range.NumericRange, ctx)
+    DateRangeDumper.register(_range.DateRange, ctx)
+    TimestampRangeDumper.register(_range.TimestampRange, ctx)
+    TimestamptzRangeDumper.register(_range.TimestamptzRange, ctx)
+    Int4RangeBinaryDumper.register(_range.Int4Range, ctx)
+    Int8RangeBinaryDumper.register(_range.Int8Range, ctx)
+    NumericRangeBinaryDumper.register(_range.NumericRange, ctx)
+    DateRangeBinaryDumper.register(_range.DateRange, ctx)
+    TimestampRangeBinaryDumper.register(_range.TimestampRange, ctx)
+    TimestamptzRangeBinaryDumper.register(_range.TimestamptzRange, ctx)
     Int4RangeLoader.register("int4range", ctx)
     Int8RangeLoader.register("int8range", ctx)
     NumericRangeLoader.register("numrange", ctx)
@@ -305,4 +330,4 @@ def register_default_globals(ctx: AdaptContext) -> None:
     RecordLoader.register("record", ctx)
     RecordBinaryLoader.register("record", ctx)
 
-    array.register_all_arrays(ctx)
+    register_all_arrays(ctx)
index f2819afa7170e319290163e3094ca39a3e6fde9c..3a618be3c6feaba77853dccca1ecc395868a7bf4 100644 (file)
@@ -225,6 +225,34 @@ class Range(Generic[T]):
             setattr(self, slot, value)
 
 
+# Subclasses tied to a specific range type. Usually not needed: only required
+# in binary copy, where falling back to the text format is not an option.
+
+
+class Int4Range(Range[int]):
+    pass
+
+
+class Int8Range(Range[int]):
+    pass
+
+
+class NumericRange(Range[Decimal]):
+    pass
+
+
+class DateRange(Range[date]):
+    pass
+
+
+class TimestampRange(Range[datetime]):
+    pass
+
+
+class TimestamptzRange(Range[datetime]):
+    pass
+
+
 class BaseRangeDumper(RecursiveDumper):
     def __init__(self, cls: type, context: Optional[AdaptContext] = None):
         super().__init__(cls, context)
@@ -232,7 +260,13 @@ class BaseRangeDumper(RecursiveDumper):
         self._types = context.adapters.types if context else builtins
         self._adapt_format = Pg3Format.from_pq(self.format)
 
-    def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]:
+    def get_key(
+        self, obj: Range[Any], format: Pg3Format
+    ) -> Union[type, Tuple[type, ...]]:
+        # If we are a subclass whose oid is specified we don't need upgrade
+        if self.oid != INVALID_OID:
+            return self.cls
+
         item = self._get_item(obj)
         if item is not None:
             sd = self._tx.get_dumper(item, self._adapt_format)
@@ -241,6 +275,10 @@ class BaseRangeDumper(RecursiveDumper):
             return (self.cls,)
 
     def upgrade(self, obj: Range[Any], format: Pg3Format) -> "BaseRangeDumper":
+        # If we are a subclass whose oid is specified we don't need upgrade
+        if self.oid != INVALID_OID:
+            return self
+
         item = self._get_item(obj)
         if item is None:
             return RangeDumper(self.cls)
@@ -306,7 +344,7 @@ class RangeDumper(BaseRangeDumper, SequenceDumper):
     _re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
 
 
-class BinaryRangeDumper(BaseRangeDumper):
+class RangeBinaryDumper(BaseRangeDumper):
 
     format = Format.BINARY
 
@@ -421,6 +459,64 @@ def register_adapters(
     bloader.register(info.oid, context=context)
 
 
+# Text dumpers for the builtin range type wrappers.
+# These are registered on the specific Range subclasses so that the upgrade
+# mechanism doesn't kick in.
+
+
+class Int4RangeDumper(RangeDumper):
+    _oid = builtins["int4range"].oid
+
+
+class Int8RangeDumper(RangeDumper):
+    _oid = builtins["int8range"].oid
+
+
+class NumericRangeDumper(RangeDumper):
+    _oid = builtins["numrange"].oid
+
+
+class DateRangeDumper(RangeDumper):
+    _oid = builtins["daterange"].oid
+
+
+class TimestampRangeDumper(RangeDumper):
+    _oid = builtins["tsrange"].oid
+
+
+class TimestamptzRangeDumper(RangeDumper):
+    _oid = builtins["tstzrange"].oid
+
+
+# Binary dumpers for the builtin range type wrappers.
+# These are registered on the specific Range subclasses so that the upgrade
+# mechanism doesn't kick in.
+
+
+class Int4RangeBinaryDumper(RangeBinaryDumper):
+    _oid = builtins["int4range"].oid
+
+
+class Int8RangeBinaryDumper(RangeBinaryDumper):
+    _oid = builtins["int8range"].oid
+
+
+class NumericRangeBinaryDumper(RangeBinaryDumper):
+    _oid = builtins["numrange"].oid
+
+
+class DateRangeBinaryDumper(RangeBinaryDumper):
+    _oid = builtins["daterange"].oid
+
+
+class TimestampRangeBinaryDumper(RangeBinaryDumper):
+    _oid = builtins["tsrange"].oid
+
+
+class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
+    _oid = builtins["tstzrange"].oid
+
+
 # Text loaders for builtin range types
 
 
index e10b02730994725123f620df4671270b769f46ec..7b69d477950c9f63984740c81f7de4b30882ca8c 100644 (file)
@@ -436,8 +436,8 @@ class Faker:
 
     def make_Range(self, spec):
         # TODO: drop format check after fixing binary dumping of empty ranges
-        if random() < 0.02 and self.format == Format.TEXT:
-            return Range(empty=True)
+        if random() < 0.02 and spec[0] is Range and self.format == Format.TEXT:
+            return spec[0](empty=True)
 
         while True:
             bounds = []
@@ -459,23 +459,60 @@ class Faker:
 
             # avoid generating ranges with no type info if dumping in binary
             # TODO: lift this limitation after test_copy_in_empty xfail is fixed
-            if self.format == Format.BINARY:
+            if spec[0] is Range and self.format == Format.BINARY:
                 if bounds[0] is bounds[1] is None:
                     continue
 
             break
 
-        r = Range(bounds[0], bounds[1], choice("[(") + choice("])"))
+        r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])"))
         return r
 
+    def make_Int4Range(self, spec):
+        return self.make_Range((spec, Int4))
+
+    def make_Int8Range(self, spec):
+        return self.make_Range((spec, Int8))
+
+    def make_NumericRange(self, spec):
+        return self.make_Range((spec, Decimal))
+
+    def make_DateRange(self, spec):
+        return self.make_Range((spec, dt.date))
+
+    def make_TimestampRange(self, spec):
+        return self.make_Range((spec, (dt.datetime, False)))
+
+    def make_TimestamptzRange(self, spec):
+        return self.make_Range((spec, (dt.datetime, True)))
+
     def match_Range(self, spec, got, want):
         # normalise the bounds of unbounded ranges
         if want.lower is None and want.lower_inc:
             want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
         if want.upper is None and want.upper_inc:
             want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+
         return got == want
 
+    def match_Int4Range(self, spec, got, want):
+        return self.match_Range((spec, Int4), got, want)
+
+    def match_Int8Range(self, spec, got, want):
+        return self.match_Range((spec, Int8), got, want)
+
+    def match_NumericRange(self, spec, got, want):
+        return self.match_Range((spec, Decimal), got, want)
+
+    def match_DateRange(self, spec, got, want):
+        return self.match_Range((spec, dt.date), got, want)
+
+    def match_TimestampRange(self, spec, got, want):
+        return self.match_Range((spec, (dt.datetime, False)), got, want)
+
+    def match_TimestamptzRange(self, spec, got, want):
+        return self.match_Range((spec, (dt.datetime, True)), got, want)
+
     def make_str(self, spec, length=0):
         if not length:
             length = randrange(self.str_max_length)
index 884867a1e1e5b8cbd4c068c71a9d6c56063e3e57..fe6c940ca4fa330763b0bd28302cadee25a40e05 100644 (file)
@@ -188,6 +188,31 @@ def test_copy_in_empty(conn, min, max, bounds, format):
     assert rec[0] == r
 
 
+@pytest.mark.parametrize("bounds", "() empty".split())
+@pytest.mark.parametrize(
+    "wrapper",
+    """Int4Range Int8Range NumericRange
+       DateRange TimestampRange TimestamptzRange""".split(),
+)
+@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
+def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
+    from psycopg3.types import range as range_module
+
+    cur = conn.cursor()
+    cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+    cls = getattr(range_module, wrapper)
+    r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds)
+
+    with cur.copy(
+        f"copy copyrange (r) from stdin (format {format.name})"
+    ) as copy:
+        copy.write_row([r])
+
+    rec = cur.execute("select r from copyrange order by id").fetchone()
+    assert rec[0] == r
+
+
 @pytest.fixture(scope="session")
 def testrange(svcconn):
     svcconn.execute(