]> git.ipfire.org Git - thirdparty/psycopg.git/commitdiff
Add total ordering on Multirange object
authorDaniele Varrazzo <daniele.varrazzo@gmail.com>
Sun, 3 Oct 2021 23:53:18 +0000 (01:53 +0200)
committerDaniele Varrazzo <daniele.varrazzo@gmail.com>
Mon, 4 Oct 2021 12:45:56 +0000 (14:45 +0200)
Also improved idiom for the Range ordering implementation.

psycopg/psycopg/types/multirange.py
psycopg/psycopg/types/range.py
tests/types/test_multirange.py

index dd80ce1d152c06ca404fe681ab509e2e1ce25a54..c3bae8635207c56ca58763759558c906ee975632 100644 (file)
@@ -72,6 +72,29 @@ class Multirange(MutableSequence[Range[T]]):
     def insert(self, index: int, value: Range[T]) -> None:
         self._ranges.insert(index, value)
 
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, Multirange):
+            return False
+        return self._ranges == other._ranges
+
+    # Order is arbitrary but consistent
+
+    def __lt__(self, other: Any) -> bool:
+        if not isinstance(other, Multirange):
+            return NotImplemented
+        return self._ranges < other._ranges
+
+    def __le__(self, other: Any) -> bool:
+        return self == other or self < other  # type: ignore
+
+    def __gt__(self, other: Any) -> bool:
+        if not isinstance(other, Multirange):
+            return NotImplemented
+        return self._ranges > other._ranges
+
+    def __ge__(self, other: Any) -> bool:
+        return self == other or self > other  # type: ignore
+
 
 # Subclasses to specify a specific subtype. Usually not needed
 
index 34c6c8e533bf34cbf3de159fa45d721c0bc20003..bf670b3f1d66d33541e9849cf98029bc8b2dd06d 100644 (file)
@@ -172,9 +172,6 @@ class Range(Generic[T]):
             and self._bounds == other._bounds
         )
 
-    def __ne__(self, other: Any) -> bool:
-        return not self.__eq__(other)
-
     def __hash__(self) -> int:
         return hash((self._lower, self._upper, self._bounds))
 
@@ -199,22 +196,16 @@ class Range(Generic[T]):
         return False
 
     def __le__(self, other: Any) -> bool:
-        if self == other:
-            return True
-        else:
-            return self.__lt__(other)
+        return self == other or self < other  # type: ignore
 
     def __gt__(self, other: Any) -> bool:
         if isinstance(other, Range):
-            return other.__lt__(self)
+            return other < self
         else:
             return NotImplemented
 
     def __ge__(self, other: Any) -> bool:
-        if self == other:
-            return True
-        else:
-            return self.__gt__(other)
+        return self == other or self > other  # type: ignore
 
     def __getstate__(self) -> Dict[str, Any]:
         return {
index 8416c1e577c2d0a15ef2b4d13478da47cc6e8861..63231e0175da9ae3b8157fca6e1f9ab9def4f737 100644 (file)
@@ -1,6 +1,9 @@
+import pickle
+
 import pytest
 
 from psycopg.adapt import PyFormat
+from psycopg.types.range import Range
 from psycopg.types.multirange import Multirange
 
 pytestmark = pytest.mark.pg(">= 14")
@@ -12,6 +15,65 @@ mr_classes = """Int4Multirange Int8Multirange NumericMultirange
     DateMultirange TimestampMultirange TimestamptzMultirange""".split()
 
 
+class TestMultirangeObject:
+    def test_empty(self):
+        mr = Multirange()
+        assert not mr
+        assert len(mr) == 0
+
+        mr = Multirange([])
+        assert not mr
+        assert len(mr) == 0
+
+    def test_sequence(self):
+        mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+        assert mr
+        assert len(mr) == 3
+        assert mr[2] == Range(50, 60)
+        assert mr[-2] == Range(30, 40)
+
+    def test_setitem(self):
+        mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+        mr[1] = Range(31, 41)
+        assert mr == Multirange([Range(10, 20), Range(31, 41), Range(50, 60)])
+
+    def test_delitem(self):
+        mr = Multirange([Range(10, 20), Range(30, 40), Range(50, 60)])
+        del mr[1]
+        assert mr == Multirange([Range(10, 20), Range(50, 60)])
+
+        del mr[-2]
+        assert mr == Multirange([Range(50, 60)])
+
+    def test_relations(self):
+        mr1 = Multirange([Range(10, 20), Range(30, 40)])
+        mr2 = Multirange([Range(11, 20), Range(30, 40)])
+        mr3 = Multirange([Range(9, 20), Range(30, 40)])
+        assert mr1 <= mr1
+        assert not mr1 < mr1
+        assert mr1 >= mr1
+        assert not mr1 > mr1
+        assert mr1 < mr2
+        assert mr1 <= mr2
+        assert mr1 > mr3
+        assert mr1 >= mr3
+        assert mr1 != mr2
+        assert not mr1 == mr2
+
+    def test_pickling(self):
+        r = Multirange([Range(0, 4)])
+        assert pickle.loads(pickle.dumps(r)) == r
+
+    def test_str(self):
+        mr = Multirange([Range(10, 20), Range(30, 40)])
+        assert str(mr) == "{[10, 20), [30, 40)}"
+
+    def test_repr(self):
+        mr = Multirange([Range(10, 20), Range(30, 40)])
+        expected = "Multirange([Range(10, 20, '[)'), Range(30, 40, '[)')])"
+        assert repr(mr) == expected
+
+
 @pytest.mark.parametrize("pgtype", mr_names)
 @pytest.mark.parametrize("fmt_in", PyFormat)
 def test_dump_builtin_empty(conn, pgtype, fmt_in):