]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add intersection method to Range class
authorYurii Karabas <1998uriyyo@gmail.com>
Fri, 14 Apr 2023 17:37:40 +0000 (13:37 -0400)
committersqla-tester <sqla-tester@sqlalchemy.org>
Fri, 14 Apr 2023 17:37:40 +0000 (13:37 -0400)
<!-- Provide a general summary of your proposed changes in the Title field above -->

### Description
Fixes: #9509
<!-- Describe your changes in detail -->

### Checklist
<!-- go over following points. check them with an `x` if they do apply, (they turn into clickable checkboxes once the PR is submitted, so no need to do everything at once)

-->

This pull request is:

- [ ] A documentation / typographical error fix
- Good to go, no issue or tests are needed
- [ ] A short code fix
- please include the issue number, and create an issue if none exists, which
  must include a complete example of the issue.  one line code fixes without an
  issue and demonstration will not be accepted.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.   one line code fixes without tests will not be accepted.
- [x] A new feature implementation
- please include the issue number, and create an issue if none exists, which must
  include a complete example of how the feature would look.
- Please include: `Fixes: #<issue number>` in the commit message
- please include tests.

**Have a nice day!**

Closes: #9510
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9510
Pull-request-sha: 596648e7989327eef1807057519b2295b48f1adf

Change-Id: I7b527edda09eb78dee6948edd4d49b00ea437011

doc/build/changelog/unreleased_20/9509.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/9509.rst b/doc/build/changelog/unreleased_20/9509.rst
new file mode 100644 (file)
index 0000000..b50a4a0
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+      :tags: usecase, postgresql
+      :tickets: 9509
+
+      Add missing :meth:`_postgresql.Range.intersection` method.
+      Pull request courtesy Yurii Karabas.
index 3cf2ceb445c6732a51203173958663599b91b8bc..cefd280ea4df9c36d717f491c3fc3833843d4d2c 100644 (file)
@@ -641,6 +641,43 @@ class Range(Generic[_T]):
     def __sub__(self, other: Range[_T]) -> Range[_T]:
         return self.difference(other)
 
+    def intersection(self, other: Range[_T]) -> Range[_T]:
+        """Compute the intersection of this range with the `other`."""
+        if self.empty or other.empty or not self.overlaps(other):
+            return Range(None, None, empty=True)
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        supper = self.upper
+        supper_b = self.bounds[1]
+        olower = other.lower
+        olower_b = other.bounds[0]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        if self._compare_edges(slower, slower_b, olower, olower_b) < 0:
+            rlower = olower
+            rlower_b = olower_b
+        else:
+            rlower = slower
+            rlower_b = slower_b
+
+        if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0:
+            rupper = oupper
+            rupper_b = oupper_b
+        else:
+            rupper = supper
+            rupper_b = supper_b
+
+        return Range(
+            rlower,
+            rupper,
+            bounds=cast(_BoundsType, rlower_b + rupper_b),
+        )
+
+    def __mul__(self, other: Range[_T]) -> Range[_T]:
+        return self.intersection(other)
+
     def __str__(self) -> str:
         return self._stringify()
 
@@ -809,6 +846,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
 
         __sub__ = difference
 
+        def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
+            """Range expression. Returns the intersection of the two ranges.
+            Will raise an exception if the resulting range is not
+            contiguous.
+            """
+            return self.expr.op("*")(other)  # type: ignore
+
+        __mul__ = intersection
+
 
 class AbstractRangeImpl(AbstractRange[Range[_T]]):
     """Marker for AbstractRange that will apply a subclass-specific
index 5f5be3c571ec824ebb8b2798b84798b9cd37f473..f322bf35481b9a1f1f219700661f597b08cab9c2 100644 (file)
@@ -4496,6 +4496,50 @@ class _RangeComparisonFixtures(_RangeTests):
                 f"{r1}.difference({r2}): got {py_res}, expected {pg_res}",
             )
 
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
+        lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"),
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper - e, bounds="(]"),
+        lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
+        lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"),
+        lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper, bounds="()"),
+        argnames="r2t",
+    )
+    def test_intersection(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(r1, RANGE).intersection(r2),
+        )
+        validate_q = select(
+            literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE),
+        )
+
+        pg_res = connection.execute(q).scalar()
+
+        validate_intersection = connection.execute(validate_q).scalar()
+        eq_(pg_res, validate_intersection)
+        py_res = r1.intersection(r2)
+        eq_(
+            py_res,
+            pg_res,
+            f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}",
+        )
+
     @testing.combinations(
         *_common_ranges_to_test,
         lambda r, e: Range(r.lower, r.lower, bounds="[]"),