# Copyright (C) 2020-2021 The Psycopg Team
-from . import array
from ..oids import INVALID_OID
from ..proto import AdaptContext
+from .array import register_all_arrays
+from . import range as _range
# Wrapper objects
from ..wrappers.numeric import (
)
from .range import (
RangeDumper as RangeDumper,
- BinaryRangeDumper as BinaryRangeDumper,
+ Int4RangeDumper as Int4RangeDumper,
+ Int8RangeDumper as Int8RangeDumper,
+ NumericRangeDumper as NumericRangeDumper,
+ DateRangeDumper as DateRangeDumper,
+ TimestampRangeDumper as TimestampRangeDumper,
+ TimestamptzRangeDumper as TimestamptzRangeDumper,
+ RangeBinaryDumper as RangeBinaryDumper,
+ Int4RangeBinaryDumper as Int4RangeBinaryDumper,
+ Int8RangeBinaryDumper as Int8RangeBinaryDumper,
+ NumericRangeBinaryDumper as NumericRangeBinaryDumper,
+ DateRangeBinaryDumper as DateRangeBinaryDumper,
+ TimestampRangeBinaryDumper as TimestampRangeBinaryDumper,
+ TimestamptzRangeBinaryDumper as TimestamptzRangeBinaryDumper,
RangeLoader as RangeLoader,
Int4RangeLoader as Int4RangeLoader,
Int8RangeLoader as Int8RangeLoader,
CidrLoader.register("cidr", ctx)
CidrBinaryLoader.register("cidr", ctx)
+ RangeBinaryDumper.register(Range, ctx)
RangeDumper.register(Range, ctx)
- BinaryRangeDumper.register(Range, ctx)
+ Int4RangeDumper.register(_range.Int4Range, ctx)
+ Int8RangeDumper.register(_range.Int8Range, ctx)
+ NumericRangeDumper.register(_range.NumericRange, ctx)
+ DateRangeDumper.register(_range.DateRange, ctx)
+ TimestampRangeDumper.register(_range.TimestampRange, ctx)
+ TimestamptzRangeDumper.register(_range.TimestamptzRange, ctx)
+ Int4RangeBinaryDumper.register(_range.Int4Range, ctx)
+ Int8RangeBinaryDumper.register(_range.Int8Range, ctx)
+ NumericRangeBinaryDumper.register(_range.NumericRange, ctx)
+ DateRangeBinaryDumper.register(_range.DateRange, ctx)
+ TimestampRangeBinaryDumper.register(_range.TimestampRange, ctx)
+ TimestamptzRangeBinaryDumper.register(_range.TimestamptzRange, ctx)
Int4RangeLoader.register("int4range", ctx)
Int8RangeLoader.register("int8range", ctx)
NumericRangeLoader.register("numrange", ctx)
RecordLoader.register("record", ctx)
RecordBinaryLoader.register("record", ctx)
- array.register_all_arrays(ctx)
+ register_all_arrays(ctx)
setattr(self, slot, value)
+# Subclasses to specify a specific subtype. Usually not needed: only needed
+# in binary copy, where switching to text is not an option.
+
+
+class Int4Range(Range[int]):
+ pass
+
+
+class Int8Range(Range[int]):
+ pass
+
+
+class NumericRange(Range[Decimal]):
+ pass
+
+
+class DateRange(Range[date]):
+ pass
+
+
+class TimestampRange(Range[datetime]):
+ pass
+
+
+class TimestamptzRange(Range[datetime]):
+ pass
+
+
class BaseRangeDumper(RecursiveDumper):
def __init__(self, cls: type, context: Optional[AdaptContext] = None):
super().__init__(cls, context)
self._types = context.adapters.types if context else builtins
self._adapt_format = Pg3Format.from_pq(self.format)
- def get_key(self, obj: Range[Any], format: Pg3Format) -> Tuple[type, ...]:
+ def get_key(
+ self, obj: Range[Any], format: Pg3Format
+ ) -> Union[type, Tuple[type, ...]]:
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.oid != INVALID_OID:
+ return self.cls
+
item = self._get_item(obj)
if item is not None:
sd = self._tx.get_dumper(item, self._adapt_format)
return (self.cls,)
def upgrade(self, obj: Range[Any], format: Pg3Format) -> "BaseRangeDumper":
+ # If we are a subclass whose oid is specified we don't need upgrade
+ if self.oid != INVALID_OID:
+ return self
+
item = self._get_item(obj)
if item is None:
return RangeDumper(self.cls)
_re_needs_quotes = re.compile(br'[",\\\s()\[\]]')
-class BinaryRangeDumper(BaseRangeDumper):
+class RangeBinaryDumper(BaseRangeDumper):
format = Format.BINARY
bloader.register(info.oid, context=context)
+# Text dumpers for builtin range types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeDumper(RangeDumper):
+ _oid = builtins["int4range"].oid
+
+
+class Int8RangeDumper(RangeDumper):
+ _oid = builtins["int8range"].oid
+
+
+class NumericRangeDumper(RangeDumper):
+ _oid = builtins["numrange"].oid
+
+
+class DateRangeDumper(RangeDumper):
+ _oid = builtins["daterange"].oid
+
+
+class TimestampRangeDumper(RangeDumper):
+ _oid = builtins["tsrange"].oid
+
+
+class TimestamptzRangeDumper(RangeDumper):
+ _oid = builtins["tstzrange"].oid
+
+
+# Binary dumpers for builtin range types wrappers
+# These are registered on specific subtypes so that the upgrade mechanism
+# doesn't kick in.
+
+
+class Int4RangeBinaryDumper(RangeBinaryDumper):
+ _oid = builtins["int4range"].oid
+
+
+class Int8RangeBinaryDumper(RangeBinaryDumper):
+ _oid = builtins["int8range"].oid
+
+
+class NumericRangeBinaryDumper(RangeBinaryDumper):
+ _oid = builtins["numrange"].oid
+
+
+class DateRangeBinaryDumper(RangeBinaryDumper):
+ _oid = builtins["daterange"].oid
+
+
+class TimestampRangeBinaryDumper(RangeBinaryDumper):
+ _oid = builtins["tsrange"].oid
+
+
+class TimestamptzRangeBinaryDumper(RangeBinaryDumper):
+ _oid = builtins["tstzrange"].oid
+
+
# Text loaders for builtin range types
def make_Range(self, spec):
# TODO: drop format check after fixing binary dumping of empty ranges
- if random() < 0.02 and self.format == Format.TEXT:
- return Range(empty=True)
+ if random() < 0.02 and spec[0] is Range and self.format == Format.TEXT:
+ return spec[0](empty=True)
while True:
bounds = []
# avoid generating ranges with no type info if dumping in binary
# TODO: lift this limitation after test_copy_in_empty xfail is fixed
- if self.format == Format.BINARY:
+ if spec[0] is Range and self.format == Format.BINARY:
if bounds[0] is bounds[1] is None:
continue
break
- r = Range(bounds[0], bounds[1], choice("[(") + choice("])"))
+ r = spec[0](bounds[0], bounds[1], choice("[(") + choice("])"))
return r
+ def make_Int4Range(self, spec):
+ return self.make_Range((spec, Int4))
+
+ def make_Int8Range(self, spec):
+ return self.make_Range((spec, Int8))
+
+ def make_NumericRange(self, spec):
+ return self.make_Range((spec, Decimal))
+
+ def make_DateRange(self, spec):
+ return self.make_Range((spec, dt.date))
+
+ def make_TimestampRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, False)))
+
+ def make_TimestamptzRange(self, spec):
+ return self.make_Range((spec, (dt.datetime, True)))
+
def match_Range(self, spec, got, want):
# normalise the bounds of unbounded ranges
if want.lower is None and want.lower_inc:
want = type(want)(want.lower, want.upper, "(" + want.bounds[1])
if want.upper is None and want.upper_inc:
want = type(want)(want.lower, want.upper, want.bounds[0] + ")")
+
return got == want
+ def match_Int4Range(self, spec, got, want):
+ return self.match_Range((spec, Int4), got, want)
+
+ def match_Int8Range(self, spec, got, want):
+ return self.match_Range((spec, Int8), got, want)
+
+ def match_NumericRange(self, spec, got, want):
+ return self.match_Range((spec, Decimal), got, want)
+
+ def match_DateRange(self, spec, got, want):
+ return self.match_Range((spec, dt.date), got, want)
+
+ def match_TimestampRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, False)), got, want)
+
+ def match_TimestamptzRange(self, spec, got, want):
+ return self.match_Range((spec, (dt.datetime, True)), got, want)
+
def make_str(self, spec, length=0):
if not length:
length = randrange(self.str_max_length)
assert rec[0] == r
+@pytest.mark.parametrize("bounds", "() empty".split())
+@pytest.mark.parametrize(
+ "wrapper",
+ """Int4Range Int8Range NumericRange
+ DateRange TimestampRange TimestamptzRange""".split(),
+)
+@pytest.mark.parametrize("format", [pq.Format.TEXT, pq.Format.BINARY])
+def test_copy_in_empty_wrappers(conn, bounds, wrapper, format):
+ from psycopg3.types import range as range_module
+
+ cur = conn.cursor()
+ cur.execute("create table copyrange (id serial primary key, r daterange)")
+
+ cls = getattr(range_module, wrapper)
+ r = cls(empty=True) if bounds == "empty" else cls(None, None, bounds)
+
+ with cur.copy(
+ f"copy copyrange (r) from stdin (format {format.name})"
+ ) as copy:
+ copy.write_row([r])
+
+ rec = cur.execute("select r from copyrange order by id").fetchone()
+ assert rec[0] == r
+
+
@pytest.fixture(scope="session")
def testrange(svcconn):
svcconn.execute(