From bde39e92290c256ec1baeef9d5975be4e3c7132b Mon Sep 17 00:00:00 2001 From: Lele Gaifax Date: Mon, 7 Nov 2022 22:29:19 +0100 Subject: [PATCH] First cut at difference() implementation and tests This is not yet ready, in particular the tests will need more work: the set used for Int4Range works, but the one generated for NumRange does not. --- lib/sqlalchemy/dialects/postgresql/ranges.py | 70 ++++++++++++++++++++ test/dialect/postgresql/test_types.py | 55 ++++++++++++++- 2 files changed, 123 insertions(+), 2 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 0822e1b653..a8863ba05a 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -452,6 +452,67 @@ class Range(Generic[_T]): __add__ = union + def difference(self, other: Range) -> Range: + """Compute the difference between this range and the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. + """ + + # Subtracting an empty range is a no-op + if self.empty or other.empty: + return self + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + sl_vs_ol = self._compare_edges(slower, slower_b, olower, olower_b) + su_vs_ou = self._compare_edges(supper, supper_b, oupper, oupper_b) + if sl_vs_ol < 0 and su_vs_ou > 0: + raise ValueError( + "Subtracting a strictly inner range is not implemented" + ) + + sl_vs_ou = self._compare_edges(slower, slower_b, oupper, oupper_b) + su_vs_ol = self._compare_edges(supper, supper_b, olower, olower_b) + + # If the ranges do not overlap, result is simply the first + if sl_vs_ou > 0 or su_vs_ol < 0: + return self + + # If this range is completely contained by the other, result is empty + if sl_vs_ol >= 0 and su_vs_ou <= 0: + return Range(None, None, empty=True) + + # If this range extends to the left of the other and ends in its + # middle + if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0: + rupper_b = ")" if olower_b == "[" else "]" + if self._compare_edges(slower, slower_b, olower, rupper_b) == 0: + return Range(None, None, empty=True) + else: + return Range(slower, olower, bounds=slower_b + rupper_b) + + # If this range starts in the middle of the other and extends to its + # right + if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0: + rlower_b = "(" if oupper_b == "]" else "(" + if self._compare_edges(oupper, rlower_b, supper, supper_b) == 0: + return Range(None, None, empty=True) + else: + return Range(oupper, supper, bounds=rlower_b + supper_b) + + # TODO: figure out if I handled all the cases above + assert False + + __sub__ = difference + def __str__(self): return self._stringify() @@ -584,6 +645,15 @@ class AbstractRange(sqltypes.TypeEngine): __add__ = union + def difference(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("-")(other) + + __sub__ = difference + class AbstractRangeImpl(AbstractRange): """marker for AbstractRange that will apply a subclass-specific diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index b847ef02d9..c10e06d45e 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3849,13 +3849,19 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): self.col.type, ) - def test_different(self): + def test_difference(self): self._test_clause( self.col - self.col, "data_table.range - data_table.range", self.col.type, ) + self._test_clause( + self.col.difference(self._data_str()), + "data_table.range - %(range_1)s", + self.col.type, + ) + class _RangeComparisonFixtures: def _data_str(self): @@ -4260,6 +4266,51 @@ class _RangeComparisonFixtures: f"{r1}.union({r2}) != {union}", ) + @testing.combinations( + ("empty", "empty", False), + ("empty", "[{le},{re_}]", False), + ("[{le},{re_}]", "empty", False), + ("[{ll},{ih}]", "({le},{ih}]", False), + ("[{ll},{rh})", "[{le},{re_}]", False), + ("[{le},{re_}]", "({le},{re_})", True), + ("[{ll},{rh}]", "[{le},{re_}]", True), + argnames="r1repr,r2repr,err", + ) + def test_difference(self, connection, r1repr, r2repr, err): + data = self._value_values() + + if r1repr != "empty": + r1repr = r1repr.format(**data) + if r2repr != "empty": + r2repr = r2repr.format(**data) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + ) + + r1, r2 = connection.execute(q).first() + + resq = select( + literal_column(f"'{r1}'::{range_typ}-'{r2}'::{range_typ}", RANGE), + ) + + if err: + with expect_raises(DataError): + connection.execute(resq).scalar() + with expect_raises(ValueError): + r1.difference(r2) + else: + difference = connection.execute(resq).scalar() + eq_( + r1.difference(r2), + difference, + f"{r1}.difference({r2}) != {difference}", + ) + class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): __requires__ = ("range_types",) @@ -4759,7 +4810,7 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): self.col.type, ) - def test_different(self): + def test_difference(self): self._test_clause( self.col - self.col, "data_table.multirange - data_table.multirange", -- 2.47.3