def __sub__(self, other: Range[_T]) -> Range[_T]:
return self.difference(other)
+ def intersection(self, other: Range[_T]) -> Range[_T]:
+ """Compute the intersection of this range with the `other`."""
+ if self.empty or other.empty or not self.overlaps(other):
+ return Range(None, None, empty=True)
+
+ slower = self.lower
+ slower_b = self.bounds[0]
+ supper = self.upper
+ supper_b = self.bounds[1]
+ olower = other.lower
+ olower_b = other.bounds[0]
+ oupper = other.upper
+ oupper_b = other.bounds[1]
+
+ if self._compare_edges(slower, slower_b, olower, olower_b) < 0:
+ rlower = olower
+ rlower_b = olower_b
+ else:
+ rlower = slower
+ rlower_b = slower_b
+
+ if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0:
+ rupper = oupper
+ rupper_b = oupper_b
+ else:
+ rupper = supper
+ rupper_b = supper_b
+
+ return Range(
+ rlower,
+ rupper,
+ bounds=cast(_BoundsType, rlower_b + rupper_b),
+ )
+
+ def __mul__(self, other: Range[_T]) -> Range[_T]:
+ return self.intersection(other)
+
def __str__(self) -> str:
return self._stringify()
__sub__ = difference
+ def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
+ """Range expression. Returns the intersection of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ return self.expr.op("*")(other) # type: ignore
+
+ __mul__ = intersection
+
class AbstractRangeImpl(AbstractRange[Range[_T]]):
"""Marker for AbstractRange that will apply a subclass-specific
f"{r1}.difference({r2}): got {py_res}, expected {pg_res}",
)
+ @testing.combinations(
+ *_common_ranges_to_test,
+ lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+ lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
+ lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"),
+ lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"),
+ argnames="r1t",
+ )
+ @testing.combinations(
+ *_common_ranges_to_test,
+ lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+ lambda r, e: Range(r.lower, r.upper - e, bounds="(]"),
+ lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
+ lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
+ lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"),
+ lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+ lambda r, e: Range(r.lower, r.upper, bounds="()"),
+ argnames="r2t",
+ )
+ def test_intersection(self, connection, r1t, r2t):
+ r1 = r1t(self._data_obj(), self._epsilon)
+ r2 = r2t(self._data_obj(), self._epsilon)
+
+ RANGE = self._col_type
+ range_typ = self._col_str
+
+ q = select(
+ cast(r1, RANGE).intersection(r2),
+ )
+ validate_q = select(
+ literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE),
+ )
+
+ pg_res = connection.execute(q).scalar()
+
+ validate_intersection = connection.execute(validate_q).scalar()
+ eq_(pg_res, validate_intersection)
+ py_res = r1.intersection(r2)
+ eq_(
+ py_res,
+ pg_res,
+ f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}",
+ )
+
@testing.combinations(
*_common_ranges_to_test,
lambda r, e: Range(r.lower, r.lower, bounds="[]"),