From: Yurii Karabas <1998uriyyo@gmail.com> Date: Fri, 14 Apr 2023 17:37:40 +0000 (-0400) Subject: Add intersection method to Range class X-Git-Tag: rel_2_0_10~6^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=609f432563954167b8f0148e43c70c08380e8ba4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add intersection method to Range class ### Description Fixes: #9509 ### Checklist This pull request is: - [ ] A documentation / typographical error fix - Good to go, no issue or tests are needed - [ ] A short code fix - please include the issue number, and create an issue if none exists, which must include a complete example of the issue. one line code fixes without an issue and demonstration will not be accepted. - Please include: `Fixes: #` in the commit message - please include tests. one line code fixes without tests will not be accepted. - [x] A new feature implementation - please include the issue number, and create an issue if none exists, which must include a complete example of how the feature would look. - Please include: `Fixes: #` in the commit message - please include tests. **Have a nice day!** Closes: #9510 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9510 Pull-request-sha: 596648e7989327eef1807057519b2295b48f1adf Change-Id: I7b527edda09eb78dee6948edd4d49b00ea437011 --- diff --git a/doc/build/changelog/unreleased_20/9509.rst b/doc/build/changelog/unreleased_20/9509.rst new file mode 100644 index 0000000000..b50a4a0286 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9509.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 9509 + + Add missing :meth:`_postgresql.Range.intersection` method. + Pull request courtesy Yurii Karabas. diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 3cf2ceb445..cefd280ea4 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -641,6 +641,43 @@ class Range(Generic[_T]): def __sub__(self, other: Range[_T]) -> Range[_T]: return self.difference(other) + def intersection(self, other: Range[_T]) -> Range[_T]: + """Compute the intersection of this range with the `other`.""" + if self.empty or other.empty or not self.overlaps(other): + return Range(None, None, empty=True) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = olower + rlower_b = olower_b + else: + rlower = slower + rlower_b = slower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = oupper + rupper_b = oupper_b + else: + rupper = supper + rupper_b = supper_b + + return Range( + rlower, + rupper, + bounds=cast(_BoundsType, rlower_b + rupper_b), + ) + + def __mul__(self, other: Range[_T]) -> Range[_T]: + return self.intersection(other) + def __str__(self) -> str: return self._stringify() @@ -809,6 +846,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): __sub__ = difference + def intersection(self, other: Any) -> ColumnElement[Range[_T]]: + """Range expression. Returns the intersection of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("*")(other) # type: ignore + + __mul__ = intersection + class AbstractRangeImpl(AbstractRange[Range[_T]]): """Marker for AbstractRange that will apply a subclass-specific diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 5f5be3c571..f322bf3548 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4496,6 +4496,50 @@ class _RangeComparisonFixtures(_RangeTests): f"{r1}.difference({r2}): got {py_res}, expected {pg_res}", ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="()"), + argnames="r2t", + ) + def test_intersection(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).intersection(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE), + ) + + pg_res = connection.execute(q).scalar() + + validate_intersection = connection.execute(validate_q).scalar() + eq_(pg_res, validate_intersection) + py_res = r1.intersection(r2) + eq_( + py_res, + pg_res, + f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}", + ) + @testing.combinations( *_common_ranges_to_test, lambda r, e: Range(r.lower, r.lower, bounds="[]"),