From: Lele Gaifax Date: Tue, 15 Nov 2022 20:27:34 +0000 (-0500) Subject: Issue #8765: implement missing methods on PG Range X-Git-Tag: rel_2_0_0b4~49^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bc696220bb0e183e26e52b3bd771459e387964a1;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Issue #8765: implement missing methods on PG Range ### Description This PR implements missing methods on the PG `Range` class, as described by issue #8765. ### Checklist This pull request is: - [ ] A documentation / typographical error fix - [ ] A short code fix - [x] A new feature implementation Closes: #8766 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8766 Pull-request-sha: 21c0df86cc0d1502855527e29425fbffc3f45d64 Change-Id: I86fabd966ad1f14a3a86132be741df46965b9aa9 --- diff --git a/doc/build/changelog/unreleased_20/8765.rst b/doc/build/changelog/unreleased_20/8765.rst new file mode 100644 index 0000000000..a210fb3486 --- /dev/null +++ b/doc/build/changelog/unreleased_20/8765.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 8765 + + Complementing :ticket:`8690`, new comparison methods such as + ``adjacent_to()``, ``difference()``, ``union()``, etc., were added to the + PG-specific range objects, bringing them in par with the standard + operators implemented by the underlying + :attr:`_postgresql.AbstractRange.comparator_factory`. Pull request + courtesy Lele Gaifax. diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 6729f3785f..a4c39d0639 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -134,6 +134,119 @@ class Range(Generic[_T]): else: return None + def _compare_edges( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + only_values: bool = False, + ) -> int: + """Compare two range bounds. + + Return -1, 0 or 1 respectively when `value1` is less than, + equal to or greater than `value2`. + + When `only_value` is ``True``, do not consider the *inclusivity* + of the edges, just their values. + """ + + value1_is_lower_bound = bound1 in {"[", "("} + value2_is_lower_bound = bound2 in {"[", "("} + + # Infinite edges are equal when they are on the same side, + # otherwise a lower edge is considered less than the upper end + if value1 is value2 is None: + if value1_is_lower_bound == value2_is_lower_bound: + return 0 + else: + return -1 if value1_is_lower_bound else 1 + elif value1 is None: + return -1 if value1_is_lower_bound else 1 + elif value2 is None: + return 1 if value2_is_lower_bound else -1 + + # Short path for trivial case + if bound1 == bound2 and value1 == value2: + return 0 + + value1_inc = bound1 in {"[", "]"} + value2_inc = bound2 in {"[", "]"} + step = self._get_discrete_step() + + if step is not None: + # "Normalize" the two edges as '[)', to simplify successive + # logic when the range is discrete: otherwise we would need + # to handle the comparison between ``(0`` and ``[1`` that + # are equal when dealing with integers while for floats the + # former is lesser than the latter + + if value1_is_lower_bound: + if not value1_inc: + value1 += step + value1_inc = True + else: + if value1_inc: + value1 += step + value1_inc = False + if value2_is_lower_bound: + if not value2_inc: + value2 += step + value2_inc = True + else: + if value2_inc: + value2 += step + value2_inc = False + + if value1 < value2: + return -1 + elif value1 > value2: + return 1 + elif only_values: + return 0 + else: + # Neither one is infinite but are equal, so we + # need to consider the respective inclusive/exclusive + # flag + + if value1_inc and value2_inc: + return 0 + elif not value1_inc and not value2_inc: + if value1_is_lower_bound == value2_is_lower_bound: + return 0 + else: + return 1 if value1_is_lower_bound else -1 + elif not value1_inc: + return 1 if value1_is_lower_bound else -1 + elif not value2_inc: + return -1 if value2_is_lower_bound else 1 + else: + return 0 + + def __eq__(self, other: Range) -> bool: + """Compare this range to the `other` taking into account + bounds inclusivity, returning ``True`` if they are equal. + """ + + if self.empty and other.empty: + return True + elif self.empty != other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + return ( + self._compare_edges(slower, slower_b, olower, olower_b) == 0 + and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0 + ) + def contained_by(self, other: Range) -> bool: "Determine whether this range is a contained by `other`." @@ -145,72 +258,23 @@ class Range(Generic[_T]): if other.empty: return False + slower = self.lower + slower_b = self.bounds[0] olower = other.lower - oupper = other.upper + olower_b = other.bounds[0] - # A bilateral unbound range contains any other range - if olower is oupper is None: - return True + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + return False - slower = self.lower supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] - # A lower-bound range cannot contain a lower-unbound range - if slower is None and olower is not None: - return False - - # Likewise on the right side - if supper is None and oupper is not None: + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: return False - slower_inc = self.bounds[0] == "[" - supper_inc = self.bounds[1] == "]" - olower_inc = other.bounds[0] == "[" - oupper_inc = other.bounds[1] == "]" - - # Check the lower end - step = -1 - if slower is not None and olower is not None: - lside = olower < slower - if not lside: - if not slower_inc or olower_inc: - lside = olower == slower - if not lside: - # Cover (1,x] vs [2,x) and (0,x] vs [1,x) - if not slower_inc and olower_inc and slower < olower: - step = self._get_discrete_step() - if step is not None: - lside = olower == (slower + step) - elif slower_inc and not olower_inc and slower > olower: - step = self._get_discrete_step() - if step is not None: - lside = (olower + step) == slower - if not lside: - return False - - # Lower end already considered, an upper-unbound range surely contains - # this - if oupper is None: - return True - - # Check the upper end - uside = oupper > supper - if not uside: - if not supper_inc or oupper_inc: - uside = oupper == supper - if not uside: - # Cover (x,2] vs [x,3) and (x,1] vs [x,2) - if supper_inc and not oupper_inc and supper < oupper: - if step == -1: - step = self._get_discrete_step() - if step is not None: - uside = oupper == (supper + step) - elif not supper_inc and oupper_inc and supper > oupper: - if step == -1: - step = self._get_discrete_step() - if step is not None: - uside = (oupper + step) == supper - return uside + return True def contains(self, value: Union[_T, Range]) -> bool: "Determine whether this range contains `value`." @@ -220,52 +284,286 @@ class Range(Generic[_T]): else: return self._contains_value(value) - def overlaps(self, other): - """Boolean expression. Returns true if the column overlaps - (has points in common with) the right hand operand. - """ - raise NotImplementedError("not yet implemented") + def overlaps(self, other: Range) -> bool: + "Determine whether this range overlaps with `other`." - def strictly_left_of(self, other): - """Boolean expression. Returns true if the column is strictly - left of the right hand operand. - """ - raise NotImplementedError("not yet implemented") + # Empty ranges never overlap with any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower bound is contained in the other range + if ( + self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + and self._compare_edges(slower, slower_b, oupper, oupper_b) <= 0 + ): + return True + + # Check whether other lower bound is contained in this range + if ( + self._compare_edges(olower, olower_b, slower, slower_b) >= 0 + and self._compare_edges(olower, olower_b, supper, supper_b) <= 0 + ): + return True + + return False + + def strictly_left_of(self, other: Range) -> bool: + "Determine whether this range is completely to the left of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this upper edge is less than other's lower end + return self._compare_edges(supper, supper_b, olower, olower_b) < 0 __lshift__ = strictly_left_of - def strictly_right_of(self, other): - """Boolean expression. Returns true if the column is strictly - right of the right hand operand. - """ - raise NotImplementedError("not yet implemented") + def strictly_right_of(self, other: Range) -> bool: + "Determine whether this range is completely to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower edge is greater than other's upper end + return self._compare_edges(slower, slower_b, oupper, oupper_b) > 0 __rshift__ = strictly_right_of - def not_extend_right_of(self, other): - """Boolean expression. Returns true if the range in the column - does not extend right of the range in the operand. - """ - raise NotImplementedError("not yet implemented") + def not_extend_left_of(self, other: Range) -> bool: + "Determine whether this does not extend to the left of `other`." - def not_extend_left_of(self, other): - """Boolean expression. Returns true if the range in the column - does not extend left of the range in the operand. - """ - raise NotImplementedError("not yet implemented") + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this lower edge is not less than other's lower end + return self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + + def not_extend_right_of(self, other: Range) -> bool: + "Determine whether this does not extend to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False - def adjacent_to(self, other): - """Boolean expression. Returns true if the range in the column - is adjacent to the range in the operand. + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this upper edge is not greater than other's upper end + return self._compare_edges(supper, supper_b, oupper, oupper_b) <= 0 + + def _upper_edge_adjacent_to_lower( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + ) -> bool: + """Determine whether an upper bound is immediately successive to a + lower bound.""" + + # Since we need a peculiar way to handle the bounds inclusivity, + # just do a comparison by value here + res = self._compare_edges(value1, bound1, value2, bound2, True) + if res == -1: + step = self._get_discrete_step() + if step is None: + return False + if bound1 == "]": + if bound2 == "[": + return value1 == value2 - step + else: + return value1 == value2 + else: + if bound2 == "[": + return value1 == value2 + else: + return value1 == value2 - step + elif res == 0: + # Cover cases like [0,0] -|- [1,] and [0,2) -|- (1,3] + if ( + bound1 == "]" + and bound2 == "[" + or bound1 == ")" + and bound2 == "(" + ): + step = self._get_discrete_step() + if step is not None: + return True + return ( + bound1 == ")" + and bound2 == "[" + or bound1 == "]" + and bound2 == "(" + ) + else: + return False + + def adjacent_to(self, other: Range) -> bool: + "Determine whether this range is adjacent to the `other`." + + # Empty ranges are not adjacent to any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + return self._upper_edge_adjacent_to_lower( + supper, supper_b, olower, olower_b + ) or self._upper_edge_adjacent_to_lower( + oupper, oupper_b, slower, slower_b + ) + + def union(self, other: Range) -> Range: + """Compute the union of this range with the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. """ - raise NotImplementedError("not yet implemented") - def __add__(self, other): - """Range expression. Returns the union of the two ranges. - Will raise an exception if the resulting range is not - contiguous. + # Empty ranges are "additive identities" + if self.empty: + return other + if other.empty: + return self + + if not self.overlaps(other) and not self.adjacent_to(other): + raise ValueError( + "Adding non-overlapping and non-adjacent" + " ranges is not implemented" + ) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = slower + rlower_b = slower_b + else: + rlower = olower + rlower_b = olower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = supper + rupper_b = supper_b + else: + rupper = oupper + rupper_b = oupper_b + + return Range(rlower, rupper, bounds=rlower_b + rupper_b) + + __add__ = union + + def difference(self, other: Range) -> Range: + """Compute the difference between this range and the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. """ - raise NotImplementedError("not yet implemented") + + # Subtracting an empty range is a no-op + if self.empty or other.empty: + return self + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + sl_vs_ol = self._compare_edges(slower, slower_b, olower, olower_b) + su_vs_ou = self._compare_edges(supper, supper_b, oupper, oupper_b) + if sl_vs_ol < 0 and su_vs_ou > 0: + raise ValueError( + "Subtracting a strictly inner range is not implemented" + ) + + sl_vs_ou = self._compare_edges(slower, slower_b, oupper, oupper_b) + su_vs_ol = self._compare_edges(supper, supper_b, olower, olower_b) + + # If the ranges do not overlap, result is simply the first + if sl_vs_ou > 0 or su_vs_ol < 0: + return self + + # If this range is completely contained by the other, result is empty + if sl_vs_ol >= 0 and su_vs_ou <= 0: + return Range(None, None, empty=True) + + # If this range extends to the left of the other and ends in its + # middle + if sl_vs_ol <= 0 and su_vs_ol >= 0 and su_vs_ou <= 0: + rupper_b = ")" if olower_b == "[" else "]" + if ( + slower_b != "[" + and rupper_b != "]" + and self._compare_edges(slower, slower_b, olower, rupper_b) + == 0 + ): + return Range(None, None, empty=True) + else: + return Range(slower, olower, bounds=slower_b + rupper_b) + + # If this range starts in the middle of the other and extends to its + # right + if sl_vs_ol >= 0 and su_vs_ou >= 0 and sl_vs_ou <= 0: + rlower_b = "(" if oupper_b == "]" else "[" + if ( + rlower_b != "[" + and supper_b != "]" + and self._compare_edges(oupper, rlower_b, supper, supper_b) + == 0 + ): + return Range(None, None, empty=True) + else: + return Range(oupper, supper, bounds=rlower_b + supper_b) + + assert False, f"Unhandled case computing {self} - {other}" + + __sub__ = difference def __str__(self): return self._stringify() @@ -390,13 +688,24 @@ class AbstractRange(sqltypes.TypeEngine): """ return self.expr.op("-|-", is_comparison=True)(other) - def __add__(self, other): + def union(self, other): """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contiguous. """ return self.expr.op("+")(other) + __add__ = union + + def difference(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + return self.expr.op("-")(other) + + __sub__ = difference + class AbstractRangeImpl(AbstractRange): """marker for AbstractRange that will apply a subclass-specific diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 39e7d73172..1e0e3df658 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -60,11 +60,13 @@ from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE from sqlalchemy.exc import CompileError +from sqlalchemy.exc import DBAPIError from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import bindparam from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false @@ -3669,7 +3671,28 @@ class HStoreRoundTripTest(fixtures.TablesTest): eq_(s.query(Data.data, Data).all(), [(d.data, d)]) -class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): +class _RangeTests: + _col_type = None + "The concrete range class these tests are for." + + _col_str = None + "The corresponding PG type name." + + _epsilon = None + """A small value used to generate range variants""" + + def _data_str(self): + """return string form of a sample range""" + raise NotImplementedError() + + def _data_obj(self): + """return Range form of the same range""" + raise NotImplementedError() + + +class _RangeTypeCompilation( + AssertsCompiledSQL, _RangeTests, fixtures.TestBase +): __dialect__ = "postgresql" # operator tests @@ -3835,6 +3858,12 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): self.col.type, ) + self._test_clause( + self.col.union(self._data_str()), + "data_table.range + %(range_1)s", + self.col.type, + ) + def test_intersection(self): self._test_clause( self.col * self.col, @@ -3842,23 +3871,21 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): self.col.type, ) - def test_different(self): + def test_difference(self): self._test_clause( self.col - self.col, "data_table.range - data_table.range", self.col.type, ) + self._test_clause( + self.col.difference(self._data_str()), + "data_table.range - %(range_1)s", + self.col.type, + ) -class _RangeComparisonFixtures: - def _data_str(self): - """return string form of a sample range""" - raise NotImplementedError() - - def _data_obj(self): - """return Range form of the same range""" - raise NotImplementedError() +class _RangeComparisonFixtures(_RangeTests): def _step_value_up(self, value): """given a value, return a step up @@ -3995,46 +4022,394 @@ class _RangeComparisonFixtures: r, expected = connection.execute(q).first() eq_(r.contains(v), expected) - def test_contains_range( - self, - connection, - bounds_obj_combinations, - contains_range_obj_combinations, - ): - r1repr = contains_range_obj_combinations._stringify() - r2repr = bounds_obj_combinations._stringify() + _common_ranges_to_test = ( + lambda r, e: Range(empty=True), + lambda r, e: Range(None, None, bounds="()"), + lambda r, e: Range(r.lower, None, bounds="[)"), + lambda r, e: Range(None, r.upper, bounds="(]"), + lambda r, e: r, + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="(]"), + lambda r, e: Range(r.lower, r.upper, bounds="()"), + ) + + @testing.combinations( + *_common_ranges_to_test, + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower + e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower + e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower + e, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower + e, r.upper, bounds="(]"), + lambda r, e: Range(r.lower + e, r.upper, bounds="[]"), + lambda r, e: Range(r.lower + e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower - 2 * e, r.lower - e, bounds="(]"), + lambda r, e: Range(r.lower - 4 * e, r.lower, bounds="[)"), + lambda r, e: Range(r.upper + 4 * e, r.upper + 6 * e, bounds="()"), + argnames="r2t", + ) + def test_contains_range(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) RANGE = self._col_type range_typ = self._col_str q = select( - cast(contains_range_obj_combinations, RANGE).label("r1"), - cast(bounds_obj_combinations, RANGE).label("r2"), - cast(contains_range_obj_combinations, RANGE).contains( - bounds_obj_combinations - ), - cast(contains_range_obj_combinations, RANGE).contained_by( - bounds_obj_combinations - ), + cast(r1, RANGE).contains(r2), + cast(r1, RANGE).contained_by(r2), ) + validate_q = select( - literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), - literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), - literal_column( - f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}" - ), - literal_column( - f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}" - ), + literal_column(f"'{r1}'::{range_typ} @> '{r2}'::{range_typ}"), + literal_column(f"'{r1}'::{range_typ} <@ '{r2}'::{range_typ}"), ) - orig_row = connection.execute(q).first() + + row = connection.execute(q).first() validate_row = connection.execute(validate_q).first() - eq_(orig_row, validate_row) + eq_(row, validate_row) - r1, r2, contains, contained = orig_row - eq_(r1.contains(r2), contains) - eq_(r1.contained_by(r2), contained) - eq_(r2.contains(r1), contained) + pg_contains, pg_contained = row + py_contains = r1.contains(r2) + eq_( + py_contains, + pg_contains, + f"{r1}.contains({r2}): got {py_contains}," + f" expected {pg_contains}", + ) + py_contained = r1.contained_by(r2) + eq_( + py_contained, + pg_contained, + f"{r1}.contained_by({r2}): got {py_contained}," + f" expected {pg_contained}", + ) + eq_( + r2.contains(r1), + pg_contained, + f"{r2}.contains({r1}: got {r2.contains(r1)}," + f" expected {pg_contained})", + ) + + @testing.combinations( + *_common_ranges_to_test, + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower - 2 * e, r.lower - e, bounds="(]"), + lambda r, e: Range(r.upper + e, r.upper + 2 * e, bounds="[)"), + argnames="r2t", + ) + def test_overlaps(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).overlaps(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ} && '{r2}'::{range_typ}"), + ) + row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(row, validate_row) + + pg_res = row[0] + py_res = r1.overlaps(r2) + eq_( + py_res, + pg_res, + f"{r1}.overlaps({r2}): got {py_res}, expected {pg_res}", + ) + + @testing.combinations( + *_common_ranges_to_test, + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="[]"), + lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="(]"), + lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[]"), + lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[)"), + argnames="r2t", + ) + def test_strictly_left_or_right_of(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).strictly_left_of(r2), + cast(r1, RANGE).strictly_right_of(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ} << '{r2}'::{range_typ}"), + literal_column(f"'{r1}'::{range_typ} >> '{r2}'::{range_typ}"), + ) + + row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(row, validate_row) + + pg_left, pg_right = row + py_left = r1.strictly_left_of(r2) + eq_( + py_left, + pg_left, + f"{r1}.strictly_left_of({r2}): got {py_left}, expected {pg_left}", + ) + py_left = r1 << r2 + eq_( + py_left, + pg_left, + f"{r1} << {r2}: got {py_left}, expected {pg_left}", + ) + py_right = r1.strictly_right_of(r2) + eq_( + py_right, + pg_right, + f"{r1}.strictly_right_of({r2}): got {py_left}," + f" expected {pg_right}", + ) + py_right = r1 >> r2 + eq_( + py_right, + pg_right, + f"{r1} >> {r2}: got {py_left}, expected {pg_right}", + ) + + @testing.combinations( + *_common_ranges_to_test, + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="[]"), + lambda r, e: Range(r.upper, r.upper + 2 * e, bounds="(]"), + lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[]"), + lambda r, e: Range(r.lower - 2 * e, r.lower, bounds="[)"), + argnames="r2t", + ) + def test_not_extend_left_or_right_of(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).not_extend_left_of(r2), + cast(r1, RANGE).not_extend_right_of(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ} &> '{r2}'::{range_typ}"), + literal_column(f"'{r1}'::{range_typ} &< '{r2}'::{range_typ}"), + ) + row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(row, validate_row) + + pg_left, pg_right = row + py_left = r1.not_extend_left_of(r2) + eq_( + py_left, + pg_left, + f"{r1}.not_extend_left_of({r2}): got {py_left}," + f" expected {pg_left}", + ) + py_right = r1.not_extend_right_of(r2) + eq_( + py_right, + pg_right, + f"{r1}.not_extend_right_of({r2}): got {py_right}," + f" expected {pg_right}", + ) + + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower - e, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower - e, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower - e, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower - e, bounds="[]"), + lambda r, e: Range(r.lower + e, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower + e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower + e, r.upper, bounds="(]"), + lambda r, e: Range(r.lower + e, r.upper, bounds="[]"), + lambda r, e: Range(r.lower + e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower - e, bounds="[]"), + lambda r, e: Range(r.lower - 2 * e, r.lower - e, bounds="(]"), + lambda r, e: Range(r.lower - 4 * e, r.lower, bounds="[)"), + lambda r, e: Range(r.upper + 4 * e, r.upper + 6 * e, bounds="()"), + argnames="r2t", + ) + def test_adjacent(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).adjacent_to(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ} -|- '{r2}'::{range_typ}"), + ) + + row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(row, validate_row) + + pg_res = row[0] + py_res = r1.adjacent_to(r2) + eq_( + py_res, + pg_res, + f"{r1}.adjacent_to({r2}): got {py_res}, expected {pg_res}", + ) + + @testing.combinations( + *_common_ranges_to_test, + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower + e, bounds="[]"), + lambda r, e: Range(r.upper + 4 * e, r.upper + 6 * e, bounds="()"), + argnames="r2t", + ) + def test_union(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).union(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ}+'{r2}'::{range_typ}", RANGE), + ) + + try: + pg_res = connection.execute(q).scalar() + except DBAPIError: + connection.rollback() + with expect_raises(DBAPIError): + connection.execute(validate_q).scalar() + with expect_raises(ValueError): + r1.union(r2) + else: + validate_union = connection.execute(validate_q).scalar() + eq_(pg_res, validate_union) + py_res = r1.union(r2) + eq_( + py_res, + pg_res, + f"{r1}.union({r2}): got {py_res}, expected {pg_res}", + ) + + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper - e, bounds="[]"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.upper + e, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.upper - e, bounds="(]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + lambda r, e: Range(r.lower, r.upper, bounds="[]"), + lambda r, e: Range(r.lower, r.upper, bounds="()"), + argnames="r2t", + ) + def test_difference(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(r1, RANGE).difference(r2), + ) + validate_q = select( + literal_column(f"'{r1}'::{range_typ}-'{r2}'::{range_typ}", RANGE), + ) + + try: + pg_res = connection.execute(q).scalar() + except DBAPIError: + connection.rollback() + with expect_raises(DBAPIError): + connection.execute(validate_q).scalar() + with expect_raises(ValueError): + r1.difference(r2) + else: + validate_difference = connection.execute(validate_q).scalar() + eq_(pg_res, validate_difference) + py_res = r1.difference(r2) + eq_( + py_res, + pg_res, + f"{r1}.difference({r2}): got {py_res}, expected {pg_res}", + ) + + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + *_common_ranges_to_test, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + argnames="r2t", + ) + def test_equality(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + range_typ = self._col_str + + q = select( + literal_column(f"'{r1}'::{range_typ} = '{r2}'::{range_typ}") + ) + equal = connection.execute(q).scalar() + eq_(r1 == r2, equal, f"{r1} == {r2}: got {r1 == r2}, expected {equal}") + + q = select( + literal_column(f"'{r1}'::{range_typ} <> '{r2}'::{range_typ}") + ) + different = connection.execute(q).scalar() + eq_( + r1 != r2, + different, + f"{r1} != {r2}: got {r1 != r2}, expected {different}", + ) class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): @@ -4194,6 +4569,8 @@ class _Int4RangeTests: def _data_obj(self): return Range(1, 4) + _epsilon = 1 + def _step_value_up(self, value): return value + 1 @@ -4212,6 +4589,8 @@ class _Int8RangeTests: def _data_obj(self): return Range(9223372036854775306, 9223372036854775800) + _epsilon = 1 + def _step_value_up(self, value): return value + 5 @@ -4230,6 +4609,8 @@ class _NumRangeTests: def _data_obj(self): return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0")) + _epsilon = decimal.Decimal(1) + def _step_value_up(self, value): return value + decimal.Decimal("1.8") @@ -4248,6 +4629,8 @@ class _DateRangeTests: def _data_obj(self): return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30)) + _epsilon = datetime.timedelta(days=1) + def _step_value_up(self, value): return value + datetime.timedelta(days=1) @@ -4269,6 +4652,8 @@ class _DateTimeRangeTests: datetime.datetime(2013, 3, 30, 23, 30), ) + _epsilon = datetime.timedelta(days=1) + def _step_value_up(self, value): return value + datetime.timedelta(days=1) @@ -4296,6 +4681,8 @@ class _DateTimeTZRangeTests: def _data_obj(self): return Range(*self.tstzs()) + _epsilon = datetime.timedelta(days=1) + def _step_value_up(self, value): return value + datetime.timedelta(days=1) @@ -4535,7 +4922,7 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): self.col.type, ) - def test_different(self): + def test_difference(self): self._test_clause( self.col - self.col, "data_table.multirange - data_table.multirange",