From 4cc08251e5bcc8ad33147e2198a8b8f222a524eb Mon Sep 17 00:00:00 2001 From: Lele Gaifax Date: Wed, 9 Nov 2022 09:19:20 +0100 Subject: [PATCH] Implement and test Range.__eq__ --- lib/sqlalchemy/dialects/postgresql/ranges.py | 22 ++++++++ test/dialect/postgresql/test_types.py | 53 ++++++++++++++++++++ 2 files changed, 75 insertions(+) diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 5706487e4d..7327f6f1cf 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -223,6 +223,28 @@ class Range(Generic[_T]): else: return 0 + def __eq__(self, other: Range) -> bool: + """Compare this range to the `other` taking into account + bounds inclusivity, returning ``True`` if they are equal. + """ + + if self.empty and other.empty: + return True + + slower = self.lower + slower_b = self.bounds[0] + olower = other.lower + olower_b = other.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + oupper = other.upper + oupper_b = other.bounds[1] + + return ( + self._compare_edges(slower, slower_b, olower, olower_b) == 0 + and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0 + ) + def contained_by(self, other: Range) -> bool: "Determine whether this range is a contained by `other`." diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 58c8efe916..128041e4ad 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -4324,6 +4324,47 @@ class _RangeComparisonFixtures(_RangeTests): f"{r1}.difference({r2}) != {difference}", ) + @testing.combinations( + lambda r, e: Range(empty=True), + lambda r, e: r, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + argnames="r1t", + ) + @testing.combinations( + lambda r, e: Range(empty=True), + lambda r, e: r, + lambda r, e: Range(r.lower, r.lower, bounds="[]"), + lambda r, e: Range(r.lower, r.lower + e, bounds="[)"), + lambda r, e: Range(r.lower - e, r.lower, bounds="(]"), + lambda r, e: Range(r.lower - e, r.lower + e, bounds="()"), + argnames="r2t", + ) + def test_equality(self, connection, r1t, r2t): + r1 = r1t(self._data_obj(), self._epsilon) + r2 = r2t(self._data_obj(), self._epsilon) + + range_typ = self._col_str + + q = select( + literal_column(f"'{r1}'::{range_typ} = '{r2}'::{range_typ}") + ) + equal = connection.execute(q).scalar() + eq_( + r1 == r2, + equal, + f"{r1} == {r2} != {equal}", + ) + + q = select( + literal_column(f"'{r1}'::{range_typ} <> '{r2}'::{range_typ}") + ) + different = connection.execute(q).scalar() + eq_( + r1 != r2, + different, + f"{r1} != {r2} != {different}", + ) + class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): __requires__ = ("range_types",) @@ -4482,6 +4523,8 @@ class _Int4RangeTests: def _data_obj(self): return Range(1, 4) + _epsilon = 1 + def _step_value_up(self, value): return value + 1 @@ -4500,6 +4543,8 @@ class _Int8RangeTests: def _data_obj(self): return Range(9223372036854775306, 9223372036854775800) + _epsilon = 1 + def _step_value_up(self, value): return value + 5 @@ -4518,6 +4563,8 @@ class _NumRangeTests: def _data_obj(self): return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0")) + _epsilon = decimal.Decimal(1) + def _step_value_up(self, value): return value + decimal.Decimal("1.8") @@ -4536,6 +4583,8 @@ class _DateRangeTests: def _data_obj(self): return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30)) + _epsilon = datetime.timedelta(days=1) + def _step_value_up(self, value): return value + datetime.timedelta(days=1) @@ -4557,6 +4606,8 @@ class _DateTimeRangeTests: datetime.datetime(2013, 3, 30, 23, 30), ) + _epsilon = datetime.timedelta(days=1) + def _step_value_up(self, value): return value + datetime.timedelta(days=1) @@ -4584,6 +4635,8 @@ class _DateTimeTZRangeTests: def _data_obj(self): return Range(*self.tstzs()) + _epsilon = datetime.timedelta(days=1) + def _step_value_up(self, value): return value + datetime.timedelta(days=1) -- 2.47.3