]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement and test comparison methods
authorLele Gaifax <lele@metapensiero.it>
Mon, 7 Nov 2022 08:16:10 +0000 (09:16 +0100)
committerLele Gaifax <lele@metapensiero.it>
Mon, 7 Nov 2022 08:23:13 +0000 (09:23 +0100)
See issue #8765: some further study and understanding is required to
implement the last missing method, __add__().

lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index 44472ab32d33477b4d802c555c72191f1c510988..c1552874a31111ecf11097ab2798e048f4e71e97 100644 (file)
@@ -251,45 +251,160 @@ class Range(Generic[_T]):
         else:
             return self._contains_value(value)
 
-    def overlaps(self, other):
-        """Boolean expression. Returns true if the column overlaps
-        (has points in common with) the right hand operand.
-        """
-        raise NotImplementedError("not yet implemented")
+    def overlaps(self, other: Range) -> bool:
+        "Determine whether this range overlaps with `other`."
 
-    def strictly_left_of(self, other):
-        """Boolean expression. Returns true if the column is strictly
-        left of the right hand operand.
-        """
-        raise NotImplementedError("not yet implemented")
+        # Empty ranges never overlap with any other range
+        if self.empty or other.empty:
+            return False
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        supper = self.upper
+        supper_b = self.bounds[1]
+        olower = other.lower
+        olower_b = other.bounds[0]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        # Check whether this lower bound is contained in the other range
+        if (
+            self._compare_edges(slower, slower_b, olower, olower_b) >= 0
+            and self._compare_edges(slower, slower_b, oupper, oupper_b) <= 0
+        ):
+            return True
+
+        # Check whether other lower bound is contained in this range
+        if (
+            self._compare_edges(olower, olower_b, slower, slower_b) >= 0
+            and self._compare_edges(olower, olower_b, supper, supper_b) <= 0
+        ):
+            return True
+
+        return False
+
+    def strictly_left_of(self, other: Range) -> bool:
+        "Determine whether this range is completely to the left of `other`."
+
+        # Empty ranges are neither to left nor to the right of any other range
+        if self.empty or other.empty:
+            return False
+
+        supper = self.upper
+        supper_b = self.bounds[1]
+        olower = other.lower
+        olower_b = other.bounds[0]
+
+        # Check whether this upper edge is less than other's lower end
+        return self._compare_edges(supper, supper_b, olower, olower_b) < 0
 
     __lshift__ = strictly_left_of
 
-    def strictly_right_of(self, other):
-        """Boolean expression. Returns true if the column is strictly
-        right of the right hand operand.
-        """
-        raise NotImplementedError("not yet implemented")
+    def strictly_right_of(self, other: Range) -> bool:
+        "Determine whether this range is completely to the right of `other`."
+
+        # Empty ranges are neither to left nor to the right of any other range
+        if self.empty or other.empty:
+            return False
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        # Check whether this lower edge is greater than other's upper end
+        return self._compare_edges(slower, slower_b, oupper, oupper_b) > 0
 
     __rshift__ = strictly_right_of
 
-    def not_extend_right_of(self, other):
-        """Boolean expression. Returns true if the range in the column
-        does not extend right of the range in the operand.
-        """
-        raise NotImplementedError("not yet implemented")
+    def not_extend_left_of(self, other: Range) -> bool:
+        "Determine whether this does not extend to the left of `other`."
 
-    def not_extend_left_of(self, other):
-        """Boolean expression. Returns true if the range in the column
-        does not extend left of the range in the operand.
-        """
-        raise NotImplementedError("not yet implemented")
+        # Empty ranges are neither to left nor to the right of any other range
+        if self.empty or other.empty:
+            return False
 
-    def adjacent_to(self, other):
-        """Boolean expression. Returns true if the range in the column
-        is adjacent to the range in the operand.
-        """
-        raise NotImplementedError("not yet implemented")
+        slower = self.lower
+        slower_b = self.bounds[0]
+        olower = other.lower
+        olower_b = other.bounds[0]
+
+        # Check whether this lower edge is not less than other's lower end
+        return self._compare_edges(slower, slower_b, olower, olower_b) >= 0
+
+    def not_extend_right_of(self, other: Range) -> bool:
+        "Determine whether this does not extend to the right of `other`."
+
+        # Empty ranges are neither to left nor to the right of any other range
+        if self.empty or other.empty:
+            return False
+
+        supper = self.upper
+        supper_b = self.bounds[1]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        # Check whether this upper edge is not greater than other's upper end
+        return self._compare_edges(supper, supper_b, oupper, oupper_b) <= 0
+
+    def _upper_edge_adjacent_to_lower(
+        self,
+        value1: Optional[_T],
+        bound1: str,
+        value2: Optional[_T],
+        bound2: str,
+    ) -> bool:
+        """Determine whether an upper bound is immediately successive to a
+        lower bound."""
+
+        # Since we need a peculiar way to handle the bounds inclusivity,
+        # just do a comparison by value here
+        res = self._compare_edges(value1, bound1, value2, bound2, True)
+        if res == -1:
+            step = self._get_discrete_step()
+            if step is None:
+                return False
+            if bound1 == "]":
+                if bound2 == "[":
+                    return value1 == value2 - step
+                else:
+                    return value1 == value2
+            else:
+                if bound2 == "[":
+                    return value1 == value2
+                else:
+                    return value1 == value2 - step
+        elif res == 0:
+            return (
+                bound1 == ")"
+                and bound2 == "["
+                or bound1 == "]"
+                and bound2 == "("
+            )
+        else:
+            return False
+
+    def adjacent_to(self, other: Range) -> bool:
+        "Determine whether this range is adjacent to the `other`."
+
+        # Empty ranges are not adjacent to any other range
+        if self.empty or other.empty:
+            return False
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        supper = self.upper
+        supper_b = self.bounds[1]
+        olower = other.lower
+        olower_b = other.bounds[0]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        return self._upper_edge_adjacent_to_lower(
+            supper, supper_b, olower, olower_b
+        ) or self._upper_edge_adjacent_to_lower(
+            oupper, oupper_b, slower, slower_b
+        )
 
     def __add__(self, other):
         """Range expression. Returns the union of the two ranges.
index 491ca2ff7702cbd64f8f210ad3626433c738c5ec..ef2972a0b880db8d60c5d4e129c169b5b4bac42d 100644 (file)
@@ -4039,6 +4039,174 @@ class _RangeComparisonFixtures:
         )
         eq_(r2.contains(r1), contained, f"{r2}.contains({r1} != {contained})")
 
+    def test_overlaps(
+        self,
+        connection,
+        bounds_obj_combinations,
+        contains_range_obj_combinations,
+    ):
+        r1repr = contains_range_obj_combinations._stringify()
+        r2repr = bounds_obj_combinations._stringify()
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(contains_range_obj_combinations, RANGE).label("r1"),
+            cast(bounds_obj_combinations, RANGE).label("r2"),
+            cast(contains_range_obj_combinations, RANGE).overlaps(
+                bounds_obj_combinations
+            ),
+        )
+        validate_q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            literal_column(
+                f"'{r1repr}'::{range_typ} && '{r2repr}'::{range_typ}"
+            ),
+        )
+        orig_row = connection.execute(q).first()
+        validate_row = connection.execute(validate_q).first()
+        eq_(orig_row, validate_row)
+
+        r1, r2, overlaps = orig_row
+        eq_(r1.overlaps(r2), overlaps, f"{r1}.overlaps({r2}) != {overlaps}")
+
+    def test_strictly_left_or_right_of(
+        self,
+        connection,
+        bounds_obj_combinations,
+        contains_range_obj_combinations,
+    ):
+        r1repr = contains_range_obj_combinations._stringify()
+        r2repr = bounds_obj_combinations._stringify()
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(contains_range_obj_combinations, RANGE).label("r1"),
+            cast(bounds_obj_combinations, RANGE).label("r2"),
+            cast(contains_range_obj_combinations, RANGE).strictly_left_of(
+                bounds_obj_combinations
+            ),
+            cast(contains_range_obj_combinations, RANGE).strictly_right_of(
+                bounds_obj_combinations
+            ),
+        )
+        validate_q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            literal_column(
+                f"'{r1repr}'::{range_typ} << '{r2repr}'::{range_typ}"
+            ),
+            literal_column(
+                f"'{r1repr}'::{range_typ} >> '{r2repr}'::{range_typ}"
+            ),
+        )
+        orig_row = connection.execute(q).first()
+        validate_row = connection.execute(validate_q).first()
+        eq_(orig_row, validate_row)
+
+        r1, r2, leftof, rightof = orig_row
+        eq_(
+            r1.strictly_left_of(r2),
+            leftof,
+            f"{r1}.strictly_left_of({r2}) != {leftof}",
+        )
+        eq_(r1 << r2, leftof, f"{r1} << {r2} != {leftof}")
+        eq_(
+            r1.strictly_right_of(r2),
+            rightof,
+            f"{r1}.strictly_right_of({r2}) != {rightof}",
+        )
+        eq_(r1 >> r2, rightof, f"{r1} >> {r2} != {rightof}")
+
+    def test_not_extend_left_or_right_of(
+        self,
+        connection,
+        bounds_obj_combinations,
+        contains_range_obj_combinations,
+    ):
+        r1repr = contains_range_obj_combinations._stringify()
+        r2repr = bounds_obj_combinations._stringify()
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(contains_range_obj_combinations, RANGE).label("r1"),
+            cast(bounds_obj_combinations, RANGE).label("r2"),
+            cast(contains_range_obj_combinations, RANGE).not_extend_left_of(
+                bounds_obj_combinations
+            ),
+            cast(contains_range_obj_combinations, RANGE).not_extend_right_of(
+                bounds_obj_combinations
+            ),
+        )
+        validate_q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            literal_column(
+                f"'{r1repr}'::{range_typ} &> '{r2repr}'::{range_typ}"
+            ),
+            literal_column(
+                f"'{r1repr}'::{range_typ} &< '{r2repr}'::{range_typ}"
+            ),
+        )
+        orig_row = connection.execute(q).first()
+        validate_row = connection.execute(validate_q).first()
+        eq_(orig_row, validate_row)
+
+        r1, r2, leftof, rightof = orig_row
+        eq_(
+            r1.not_extend_left_of(r2),
+            leftof,
+            f"{r1}.not_extend_left_of({r2}) != {leftof}",
+        )
+        eq_(
+            r1.not_extend_right_of(r2),
+            rightof,
+            f"{r1}.not_extend_right_of({r2}) != {rightof}",
+        )
+
+    def test_adjacent(
+        self,
+        connection,
+        bounds_obj_combinations,
+        contains_range_obj_combinations,
+    ):
+        r1repr = contains_range_obj_combinations._stringify()
+        r2repr = bounds_obj_combinations._stringify()
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(contains_range_obj_combinations, RANGE).label("r1"),
+            cast(bounds_obj_combinations, RANGE).label("r2"),
+            cast(contains_range_obj_combinations, RANGE).adjacent_to(
+                bounds_obj_combinations
+            ),
+        )
+        validate_q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            literal_column(
+                f"'{r1repr}'::{range_typ} -|- '{r2repr}'::{range_typ}"
+            ),
+        )
+        orig_row = connection.execute(q).first()
+        validate_row = connection.execute(validate_q).first()
+        eq_(orig_row, validate_row)
+
+        r1, r2, adjacent = orig_row
+        eq_(
+            r1.adjacent_to(r2),
+            adjacent,
+            f"{r1}.adjacent_to({r2}) != {adjacent}",
+        )
+
 
 class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
     __requires__ = ("range_types",)