]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
First cut at difference() implementation and tests
authorLele Gaifax <lele@metapensiero.it>
Mon, 7 Nov 2022 21:29:19 +0000 (22:29 +0100)
committerLele Gaifax <lele@metapensiero.it>
Mon, 7 Nov 2022 21:29:19 +0000 (22:29 +0100)
This is not yet ready, in particular the tests will need more work: the
set used for Int4Range works, but the one generated for NumRange does
not.

lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index 0822e1b653c697fc020dd7bac6495d120d121c82..a8863ba05ae961e0e2500414b1a1cd90b4485db0 100644 (file)
@@ -452,6 +452,67 @@ class Range(Generic[_T]):
 
     __add__ = union
 
+    def difference(self, other: Range) -> Range:
+        """Compute the difference between this range and the `other`.
+
+        This raises a ``ValueError`` exception if the two ranges are
+        "disjunct", that is neither adjacent nor overlapping.
+        """
+
+        # Subtracting an empty range is a no-op
+        if self.empty or other.empty:
+            return self
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        supper = self.upper
+        supper_b = self.bounds[1]
+        olower = other.lower
+        olower_b = other.bounds[0]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        sl_vs_ol = self._compare_edges(slower, slower_b, olower, olower_b)
+        su_vs_ou = self._compare_edges(supper, supper_b, oupper, oupper_b)
+        if sl_vs_ol < 0 and su_vs_ou > 0:
+            raise ValueError(
+                "Subtracting a strictly inner range is not implemented"
+            )
+
+        sl_vs_ou = self._compare_edges(slower, slower_b, oupper, oupper_b)
+        su_vs_ol = self._compare_edges(supper, supper_b, olower, olower_b)
+
+        # If the ranges do not overlap, result is simply the first
+        if sl_vs_ou > 0 or su_vs_ol < 0:
+            return self
+
+        # If this range is completely contained by the other, result is empty
+        if sl_vs_ol >= 0 and su_vs_ou <= 0:
+            return Range(None, None, empty=True)
+
+        # If this range extends to the left of the other and ends in its
+        # middle
+        if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0:
+            rupper_b = ")" if olower_b == "[" else "]"
+            if self._compare_edges(slower, slower_b, olower, rupper_b) == 0:
+                return Range(None, None, empty=True)
+            else:
+                return Range(slower, olower, bounds=slower_b + rupper_b)
+
+        # If this range starts in the middle of the other and extends to its
+        # right
+        if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0:
+            rlower_b = "(" if oupper_b == "]" else "("
+            if self._compare_edges(oupper, rlower_b, supper, supper_b) == 0:
+                return Range(None, None, empty=True)
+            else:
+                return Range(oupper, supper, bounds=rlower_b + supper_b)
+
+        # TODO: figure out if I handled all the cases above
+        assert False
+
+    __sub__ = difference
+
     def __str__(self):
         return self._stringify()
 
@@ -584,6 +645,15 @@ class AbstractRange(sqltypes.TypeEngine):
 
         __add__ = union
 
+        def difference(self, other):
+            """Range expression. Returns the union of the two ranges.
+            Will raise an exception if the resulting range is not
+            contiguous.
+            """
+            return self.expr.op("-")(other)
+
+        __sub__ = difference
+
 
 class AbstractRangeImpl(AbstractRange):
     """marker for AbstractRange that will apply a subclass-specific
index b847ef02d9a81b1e86cfc4794568273a53f7383c..c10e06d45e6f53ffc279d25ff698761a9a4dff38 100644 (file)
@@ -3849,13 +3849,19 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             self.col.type,
         )
 
-    def test_different(self):
+    def test_difference(self):
         self._test_clause(
             self.col - self.col,
             "data_table.range - data_table.range",
             self.col.type,
         )
 
+        self._test_clause(
+            self.col.difference(self._data_str()),
+            "data_table.range - %(range_1)s",
+            self.col.type,
+        )
+
 
 class _RangeComparisonFixtures:
     def _data_str(self):
@@ -4260,6 +4266,51 @@ class _RangeComparisonFixtures:
                 f"{r1}.union({r2}) != {union}",
             )
 
+    @testing.combinations(
+        ("empty", "empty", False),
+        ("empty", "[{le},{re_}]", False),
+        ("[{le},{re_}]", "empty", False),
+        ("[{ll},{ih}]", "({le},{ih}]", False),
+        ("[{ll},{rh})", "[{le},{re_}]", False),
+        ("[{le},{re_}]", "({le},{re_})", True),
+        ("[{ll},{rh}]", "[{le},{re_}]", True),
+        argnames="r1repr,r2repr,err",
+    )
+    def test_difference(self, connection, r1repr, r2repr, err):
+        data = self._value_values()
+
+        if r1repr != "empty":
+            r1repr = r1repr.format(**data)
+        if r2repr != "empty":
+            r2repr = r2repr.format(**data)
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+        )
+
+        r1, r2 = connection.execute(q).first()
+
+        resq = select(
+            literal_column(f"'{r1}'::{range_typ}-'{r2}'::{range_typ}", RANGE),
+        )
+
+        if err:
+            with expect_raises(DataError):
+                connection.execute(resq).scalar()
+            with expect_raises(ValueError):
+                r1.difference(r2)
+        else:
+            difference = connection.execute(resq).scalar()
+            eq_(
+                r1.difference(r2),
+                difference,
+                f"{r1}.difference({r2}) != {difference}",
+            )
+
 
 class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
     __requires__ = ("range_types",)
@@ -4759,7 +4810,7 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
             self.col.type,
         )
 
-    def test_different(self):
+    def test_difference(self):
         self._test_clause(
             self.col - self.col,
             "data_table.multirange - data_table.multirange",