]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add intersection method to Range class
authorYurii Karabas <1998uriyyo@gmail.com>
Sun, 19 Mar 2023 09:56:56 +0000 (11:56 +0200)
committerYurii Karabas <1998uriyyo@gmail.com>
Sun, 19 Mar 2023 10:00:59 +0000 (12:00 +0200)
doc/build/changelog/unreleased_20/9509.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/9509.rst b/doc/build/changelog/unreleased_20/9509.rst
new file mode 100644 (file)
index 0000000..7312da4
--- /dev/null
@@ -0,0 +1,6 @@
+.. change::
+      :tags: usecase, postgresql
+      :tickets: 9509
+
+      Add missed :meth:`_postgresql.Range.intersection` method.
+      Pull request courtesy Yurii Karabas.
\ No newline at end of file
index 3cf2ceb445c6732a51203173958663599b91b8bc..a1dfb1409c426771a1a5117263770c3c13429dea 100644 (file)
@@ -641,6 +641,38 @@ class Range(Generic[_T]):
     def __sub__(self, other: Range[_T]) -> Range[_T]:
         return self.difference(other)
 
+    def intersection(self, other: Range[_T]) -> Range[_T]:
+        """Compute the intersection of this range with the `other`.
+        """
+        if self.empty or other.empty or not self.overlaps(other):
+            return Range(empty=True)
+
+        slower = self.lower
+        supper = self.upper
+        slower_b, supper_b = self.bounds
+        olower = other.lower
+        oupper = other.upper
+        olower_b, oupper_b = other.bounds
+
+        if self._compare_edges(slower, slower_b, olower, olower_b) < 0:
+            rlower = olower
+            rlower_b = olower_b
+        else:
+            rlower = slower
+            rlower_b = slower_b
+
+        if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0:
+            rupper = oupper
+            rupper_b = oupper_b
+        else:
+            rupper = supper
+            rupper_b = supper_b
+
+        return Range(rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b))
+
+    def __mul__(self, other: Range[_T]) -> Range[_T]:
+        return self.intersection(other)
+
     def __str__(self) -> str:
         return self._stringify()
 
@@ -809,6 +841,15 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
 
         __sub__ = difference
 
+        def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
+            """Range expression. Returns the intersection of the two ranges.
+            Will raise an exception if the resulting range is not
+            contiguous.
+            """
+            return self.expr.op("*")(other)
+
+        __mul__ = intersection
+
 
 class AbstractRangeImpl(AbstractRange[Range[_T]]):
     """Marker for AbstractRange that will apply a subclass-specific
index 1ff9d785fd1b6d0705d96a7dbe9cd34bfc288d97..6824ac7ea3e126e18fcfb8b683b7da54059fda50 100644 (file)
@@ -4393,6 +4393,50 @@ class _RangeComparisonFixtures(_RangeTests):
                 f"{r1}.difference({r2}): got {py_res}, expected {pg_res}",
             )
 
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"),
+        lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"),
+        argnames="r1t",
+    )
+    @testing.combinations(
+        *_common_ranges_to_test,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper - e, bounds="(]"),
+        lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
+        lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"),
+        lambda r, e: Range(r.lower, r.upper, bounds="[]"),
+        lambda r, e: Range(r.lower, r.upper, bounds="()"),
+        argnames="r2t",
+    )
+    def test_intersection(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(r1, RANGE).intersection(r2),
+        )
+        validate_q = select(
+            literal_column(f"'{r1}'::{range_typ}*'{r2}'::{range_typ}", RANGE),
+        )
+
+        pg_res = connection.execute(q).scalar()
+
+        validate_difference = connection.execute(validate_q).scalar()
+        eq_(pg_res, validate_difference)
+        py_res = r1.difference(r2)
+        eq_(
+            py_res,
+            pg_res,
+            f"{r1}.intersection({r2}): got {py_res}, expected {pg_res}",
+        )
+
     @testing.combinations(
         *_common_ranges_to_test,
         lambda r, e: Range(r.lower, r.lower, bounds="[]"),