From fe8f6a19ab2686f2baa231130eb8a762042face0 Mon Sep 17 00:00:00 2001 From: Yurii Karabas <1998uriyyo@gmail.com> Date: Sun, 19 Mar 2023 11:56:56 +0200 Subject: [PATCH] Add intersection method to Range class --- doc/build/changelog/unreleased_20/9509.rst | 6 +++ lib/sqlalchemy/dialects/postgresql/ranges.py | 41 ++++++++++++++++++ test/dialect/postgresql/test_types.py | 44 ++++++++++++++++++++ 3 files changed, 91 insertions(+) create mode 100644 doc/build/changelog/unreleased_20/9509.rst diff --git a/doc/build/changelog/unreleased_20/9509.rst b/doc/build/changelog/unreleased_20/9509.rst new file mode 100644 index 0000000000..7312da421e --- /dev/null +++ b/doc/build/changelog/unreleased_20/9509.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 9509 + + Add missed :meth:`_postgresql.Range.intersection` method. + Pull request courtesy Yurii Karabas. \ No newline at end of file diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 3cf2ceb445..a1dfb1409c 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -641,6 +641,38 @@ class Range(Generic[_T]): def __sub__(self, other: Range[_T]) -> Range[_T]: return self.difference(other) + def intersection(self, other: Range[_T]) -> Range[_T]: + """Compute the intersection of this range with the `other`. + """ + if self.empty or other.empty or not self.overlaps(other): + return Range(empty=True) + + slower = self.lower + supper = self.upper + slower_b, supper_b = self.bounds + olower = other.lower + oupper = other.upper + olower_b, oupper_b = other.bounds + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = olower + rlower_b = olower_b + else: + rlower = slower + rlower_b = slower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = oupper + rupper_b = oupper_b + else: + rupper = supper + rupper_b = supper_b + + return Range(rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b)) + + def __mul__(self, other: Range[_T]) -> Range[_T]: + return self.intersection(other) + def __str__(self) -> str: return self._stringify() @@ -809,6 +841,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): __sub__ = difference + def intersection(self, other: Any) -> ColumnElement[Range[_T]]: + """Range expression. Returns the intersection of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("*")(other) + + __mul__ = intersection + class AbstractRangeImpl(AbstractRange[Range[_T]]): """Marker for AbstractRange that will apply a subclass-specific diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 1ff9d785fd..6824ac7ea3 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4393,6 +4393,50 @@ class _RangeComparisonFixtures(_RangeTests): f"{r1}.difference({r2}): got {py_res}, expected {pg_res}", ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="()"), + argnames="r2t", + ) + def test_intersection(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).intersection(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE), + ) + + pg_res = connection.execute(q).scalar() + + validate_difference = connection.execute(validate_q).scalar() + eq_(pg_res, validate_difference) + py_res = r1.difference(r2) + eq_( + py_res, + pg_res, + f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}", + ) + @testing.combinations( *_common_ranges_to_test, lambda r, e: Range(r.lower, r.lower, bounds="[]"), -- 2.47.3