]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement and test Range.__eq__
authorLele Gaifax <lele@metapensiero.it>
Wed, 9 Nov 2022 08:19:20 +0000 (09:19 +0100)
committerLele Gaifax <lele@metapensiero.it>
Wed, 9 Nov 2022 08:19:20 +0000 (09:19 +0100)
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index 5706487e4d3efb887a7813554d9c5ad62973bf22..7327f6f1cf1f5a8f2391ae41e9409bbca7edc040 100644 (file)
@@ -223,6 +223,28 @@ class Range(Generic[_T]):
             else:
                 return 0
 
+    def __eq__(self, other: Range) -> bool:
+        """Compare this range to the `other` taking into account
+        bounds inclusivity, returning ``True`` if they are equal.
+        """
+
+        if self.empty and other.empty:
+            return True
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        olower = other.lower
+        olower_b = other.bounds[0]
+        supper = self.upper
+        supper_b = self.bounds[1]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        return (
+            self._compare_edges(slower, slower_b, olower, olower_b) == 0
+            and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0
+        )
+
     def contained_by(self, other: Range) -> bool:
         "Determine whether this range is a contained by `other`."
 
index 58c8efe916427a5cc3386c3c6233b26445821379..128041e4adabb0163be51547ea1421a7a7ef58a5 100644 (file)
@@ -4324,6 +4324,47 @@ class _RangeComparisonFixtures(_RangeTests):
                 f"{r1}.difference({r2}) != {difference}",
             )
 
+    @testing.combinations(
+        lambda r, e: Range(empty=True),
+        lambda r, e: r,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        argnames="r1t",
+    )
+    @testing.combinations(
+        lambda r, e: Range(empty=True),
+        lambda r, e: r,
+        lambda r, e: Range(r.lower, r.lower, bounds="[]"),
+        lambda r, e: Range(r.lower, r.lower + e, bounds="[)"),
+        lambda r, e: Range(r.lower - e, r.lower, bounds="(]"),
+        lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"),
+        argnames="r2t",
+    )
+    def test_equality(self, connection, r1t, r2t):
+        r1 = r1t(self._data_obj(), self._epsilon)
+        r2 = r2t(self._data_obj(), self._epsilon)
+
+        range_typ = self._col_str
+
+        q = select(
+            literal_column(f"'{r1}'::{range_typ} = '{r2}'::{range_typ}")
+        )
+        equal = connection.execute(q).scalar()
+        eq_(
+            r1 == r2,
+            equal,
+            f"{r1} == {r2} != {equal}",
+        )
+
+        q = select(
+            literal_column(f"'{r1}'::{range_typ} <> '{r2}'::{range_typ}")
+        )
+        different = connection.execute(q).scalar()
+        eq_(
+            r1 != r2,
+            different,
+            f"{r1} != {r2} != {different}",
+        )
+
 
 class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
     __requires__ = ("range_types",)
@@ -4482,6 +4523,8 @@ class _Int4RangeTests:
     def _data_obj(self):
         return Range(1, 4)
 
+    _epsilon = 1
+
     def _step_value_up(self, value):
         return value + 1
 
@@ -4500,6 +4543,8 @@ class _Int8RangeTests:
     def _data_obj(self):
         return Range(9223372036854775306, 9223372036854775800)
 
+    _epsilon = 1
+
     def _step_value_up(self, value):
         return value + 5
 
@@ -4518,6 +4563,8 @@ class _NumRangeTests:
     def _data_obj(self):
         return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0"))
 
+    _epsilon = decimal.Decimal(1)
+
     def _step_value_up(self, value):
         return value + decimal.Decimal("1.8")
 
@@ -4536,6 +4583,8 @@ class _DateRangeTests:
     def _data_obj(self):
         return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30))
 
+    _epsilon = datetime.timedelta(days=1)
+
     def _step_value_up(self, value):
         return value + datetime.timedelta(days=1)
 
@@ -4557,6 +4606,8 @@ class _DateTimeRangeTests:
             datetime.datetime(2013, 3, 30, 23, 30),
         )
 
+    _epsilon = datetime.timedelta(days=1)
+
     def _step_value_up(self, value):
         return value + datetime.timedelta(days=1)
 
@@ -4584,6 +4635,8 @@ class _DateTimeTZRangeTests:
     def _data_obj(self):
         return Range(*self.tstzs())
 
+    _epsilon = datetime.timedelta(days=1)
+
     def _step_value_up(self, value):
         return value + datetime.timedelta(days=1)