From: Lele Gaifax Date: Mon, 7 Nov 2022 08:16:10 +0000 (+0100) Subject: Implement and test comparison methods X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=c833479c84b8e2ce65720b4ff9d6701dbe3aa802;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement and test comparison methods See issue #8765: some further study and understanding is required to implement the last missing method, __add__(). --- diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 44472ab32d..c1552874a3 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -251,45 +251,160 @@ class Range(Generic[_T]): else: return self._contains_value(value) - def overlaps(self, other): - """Boolean expression. Returns true if the column overlaps - (has points in common with) the right hand operand. - """ - raise NotImplementedError("not yet implemented") + def overlaps(self, other: Range) -> bool: + "Determine whether this range overlaps with `other`." - def strictly_left_of(self, other): - """Boolean expression. Returns true if the column is strictly - left of the right hand operand. - """ - raise NotImplementedError("not yet implemented") + # Empty ranges never overlap with any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower bound is contained in the other range + if ( + self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + and self._compare_edges(slower, slower_b, oupper, oupper_b) <= 0 + ): + return True + + # Check whether other lower bound is contained in this range + if ( + self._compare_edges(olower, olower_b, slower, slower_b) >= 0 + and self._compare_edges(olower, olower_b, supper, supper_b) <= 0 + ): + return True + + return False + + def strictly_left_of(self, other: Range) -> bool: + "Determine whether this range is completely to the left of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this upper edge is less than other's lower end + return self._compare_edges(supper, supper_b, olower, olower_b) < 0 __lshift__ = strictly_left_of - def strictly_right_of(self, other): - """Boolean expression. Returns true if the column is strictly - right of the right hand operand. - """ - raise NotImplementedError("not yet implemented") + def strictly_right_of(self, other: Range) -> bool: + "Determine whether this range is completely to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this lower edge is greater than other's upper end + return self._compare_edges(slower, slower_b, oupper, oupper_b) > 0 __rshift__ = strictly_right_of - def not_extend_right_of(self, other): - """Boolean expression. Returns true if the range in the column - does not extend right of the range in the operand. - """ - raise NotImplementedError("not yet implemented") + def not_extend_left_of(self, other: Range) -> bool: + "Determine whether this does not extend to the left of `other`." - def not_extend_left_of(self, other): - """Boolean expression. Returns true if the range in the column - does not extend left of the range in the operand. - """ - raise NotImplementedError("not yet implemented") + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False - def adjacent_to(self, other): - """Boolean expression. Returns true if the range in the column - is adjacent to the range in the operand. - """ - raise NotImplementedError("not yet implemented") + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + + # Check whether this lower edge is not less than other's lower end + return self._compare_edges(slower, slower_b, olower, olower_b) >= 0 + + def not_extend_right_of(self, other: Range) -> bool: + "Determine whether this does not extend to the right of `other`." + + # Empty ranges are neither to left nor to the right of any other range + if self.empty or other.empty: + return False + + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + # Check whether this upper edge is not greater than other's upper end + return self._compare_edges(supper, supper_b, oupper, oupper_b) <= 0 + + def _upper_edge_adjacent_to_lower( + self, + value1: Optional[_T], + bound1: str, + value2: Optional[_T], + bound2: str, + ) -> bool: + """Determine whether an upper bound is immediately successive to a + lower bound.""" + + # Since we need a peculiar way to handle the bounds inclusivity, + # just do a comparison by value here + res = self._compare_edges(value1, bound1, value2, bound2, True) + if res == -1: + step = self._get_discrete_step() + if step is None: + return False + if bound1 == "]": + if bound2 == "[": + return value1 == value2 - step + else: + return value1 == value2 + else: + if bound2 == "[": + return value1 == value2 + else: + return value1 == value2 - step + elif res == 0: + return ( + bound1 == ")" + and bound2 == "[" + or bound1 == "]" + and bound2 == "(" + ) + else: + return False + + def adjacent_to(self, other: Range) -> bool: + "Determine whether this range is adjacent to the `other`." + + # Empty ranges are not adjacent to any other range + if self.empty or other.empty: + return False + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + return self._upper_edge_adjacent_to_lower( + supper, supper_b, olower, olower_b + ) or self._upper_edge_adjacent_to_lower( + oupper, oupper_b, slower, slower_b + ) def __add__(self, other): """Range expression. Returns the union of the two ranges. diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 491ca2ff77..ef2972a0b8 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4039,6 +4039,174 @@ class _RangeComparisonFixtures: ) eq_(r2.contains(r1), contained, f"{r2}.contains({r1} != {contained})") + def test_overlaps( + self, + connection, + bounds_obj_combinations, + contains_range_obj_combinations, + ): + r1repr = contains_range_obj_combinations._stringify() + r2repr = bounds_obj_combinations._stringify() + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(contains_range_obj_combinations, RANGE).label("r1"), + cast(bounds_obj_combinations, RANGE).label("r2"), + cast(contains_range_obj_combinations, RANGE).overlaps( + bounds_obj_combinations + ), + ) + validate_q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column( + f"'{r1repr}'::{range_typ} && '{r2repr}'::{range_typ}" + ), + ) + orig_row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(orig_row, validate_row) + + r1, r2, overlaps = orig_row + eq_(r1.overlaps(r2), overlaps, f"{r1}.overlaps({r2}) != {overlaps}") + + def test_strictly_left_or_right_of( + self, + connection, + bounds_obj_combinations, + contains_range_obj_combinations, + ): + r1repr = contains_range_obj_combinations._stringify() + r2repr = bounds_obj_combinations._stringify() + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(contains_range_obj_combinations, RANGE).label("r1"), + cast(bounds_obj_combinations, RANGE).label("r2"), + cast(contains_range_obj_combinations, RANGE).strictly_left_of( + bounds_obj_combinations + ), + cast(contains_range_obj_combinations, RANGE).strictly_right_of( + bounds_obj_combinations + ), + ) + validate_q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column( + f"'{r1repr}'::{range_typ} << '{r2repr}'::{range_typ}" + ), + literal_column( + f"'{r1repr}'::{range_typ} >> '{r2repr}'::{range_typ}" + ), + ) + orig_row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(orig_row, validate_row) + + r1, r2, leftof, rightof = orig_row + eq_( + r1.strictly_left_of(r2), + leftof, + f"{r1}.strictly_left_of({r2}) != {leftof}", + ) + eq_(r1 << r2, leftof, f"{r1} << {r2} != {leftof}") + eq_( + r1.strictly_right_of(r2), + rightof, + f"{r1}.strictly_right_of({r2}) != {rightof}", + ) + eq_(r1 >> r2, rightof, f"{r1} >> {r2} != {rightof}") + + def test_not_extend_left_or_right_of( + self, + connection, + bounds_obj_combinations, + contains_range_obj_combinations, + ): + r1repr = contains_range_obj_combinations._stringify() + r2repr = bounds_obj_combinations._stringify() + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(contains_range_obj_combinations, RANGE).label("r1"), + cast(bounds_obj_combinations, RANGE).label("r2"), + cast(contains_range_obj_combinations, RANGE).not_extend_left_of( + bounds_obj_combinations + ), + cast(contains_range_obj_combinations, RANGE).not_extend_right_of( + bounds_obj_combinations + ), + ) + validate_q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column( + f"'{r1repr}'::{range_typ} &> '{r2repr}'::{range_typ}" + ), + literal_column( + f"'{r1repr}'::{range_typ} &< '{r2repr}'::{range_typ}" + ), + ) + orig_row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(orig_row, validate_row) + + r1, r2, leftof, rightof = orig_row + eq_( + r1.not_extend_left_of(r2), + leftof, + f"{r1}.not_extend_left_of({r2}) != {leftof}", + ) + eq_( + r1.not_extend_right_of(r2), + rightof, + f"{r1}.not_extend_right_of({r2}) != {rightof}", + ) + + def test_adjacent( + self, + connection, + bounds_obj_combinations, + contains_range_obj_combinations, + ): + r1repr = contains_range_obj_combinations._stringify() + r2repr = bounds_obj_combinations._stringify() + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(contains_range_obj_combinations, RANGE).label("r1"), + cast(bounds_obj_combinations, RANGE).label("r2"), + cast(contains_range_obj_combinations, RANGE).adjacent_to( + bounds_obj_combinations + ), + ) + validate_q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column( + f"'{r1repr}'::{range_typ} -|- '{r2repr}'::{range_typ}" + ), + ) + orig_row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(orig_row, validate_row) + + r1, r2, adjacent = orig_row + eq_( + r1.adjacent_to(r2), + adjacent, + f"{r1}.adjacent_to({r2}) != {adjacent}", + ) + class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): __requires__ = ("range_types",)