]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement and test Range.union(), aliased to __add__
authorLele Gaifax <lele@metapensiero.it>
Mon, 7 Nov 2022 18:29:28 +0000 (19:29 +0100)
committerLele Gaifax <lele@metapensiero.it>
Mon, 7 Nov 2022 18:42:48 +0000 (19:42 +0100)
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index 5631fea220d5109508df91136d1a7bec987010b5..0822e1b653c697fc020dd7bac6495d120d121c82 100644 (file)
@@ -406,12 +406,51 @@ class Range(Generic[_T]):
             oupper, oupper_b, slower, slower_b
         )
 
-    def __add__(self, other):
-        """Range expression. Returns the union of the two ranges.
-        Will raise an exception if the resulting range is not
-        contiguous.
+    def union(self, other: Range) -> Range:
+        """Compute the union of this range with the `other`.
+
+        This raises a ``ValueError`` exception if the two ranges are
+        "disjunct", that is neither adjacent nor overlapping.
         """
-        raise NotImplementedError("not yet implemented")
+
+        # Empty ranges are "additive identities"
+        if self.empty:
+            return other
+        if other.empty:
+            return self
+
+        if not self.overlaps(other) and not self.adjacent_to(other):
+            raise ValueError(
+                "Adding non-overlapping and non-adjacent"
+                " ranges is not implemented"
+            )
+
+        slower = self.lower
+        slower_b = self.bounds[0]
+        supper = self.upper
+        supper_b = self.bounds[1]
+        olower = other.lower
+        olower_b = other.bounds[0]
+        oupper = other.upper
+        oupper_b = other.bounds[1]
+
+        if self._compare_edges(slower, slower_b, olower, olower_b) < 0:
+            rlower = slower
+            rlower_b = slower_b
+        else:
+            rlower = olower
+            rlower_b = olower_b
+
+        if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0:
+            rupper = supper
+            rupper_b = supper_b
+        else:
+            rupper = oupper
+            rupper_b = oupper_b
+
+        return Range(rlower, rupper, bounds=rlower_b + rupper_b)
+
+    __add__ = union
 
     def __str__(self):
         return self._stringify()
index e66a159955f6474663ffe1be975e208f42dbf8e4..b847ef02d9a81b1e86cfc4794568273a53f7383c 100644 (file)
@@ -60,11 +60,13 @@ from sqlalchemy.dialects.postgresql import TSRANGE
 from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE
 from sqlalchemy.dialects.postgresql import TSTZRANGE
 from sqlalchemy.exc import CompileError
+from sqlalchemy.exc import DataError
 from sqlalchemy.orm import declarative_base
 from sqlalchemy.orm import Session
 from sqlalchemy.sql import bindparam
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
+from sqlalchemy.testing import expect_raises
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_false
@@ -4213,6 +4215,51 @@ class _RangeComparisonFixtures:
             f"{r1}.adjacent_to({r2}) != {adjacent}",
         )
 
+    @testing.combinations(
+        ("empty", "empty", False),
+        ("empty", "[{le},{re_}]", False),
+        ("[{le},{re_}]", "empty", False),
+        ("[{le},{re_}]", "[{le},{re_}]", False),
+        ("[{ll},{rh})", "[{le},{re_}]", False),
+        ("[{ll},{ll}]", "({le},{rh}]", True),
+        ("[{ll},{ll}]", "[{rh},{rh}]", True),
+        argnames="r1repr,r2repr,err",
+    )
+    def test_union(self, connection, r1repr, r2repr, err):
+        data = self._value_values()
+
+        if r1repr != "empty":
+            r1repr = r1repr.format(**data)
+        if r2repr != "empty":
+            r2repr = r2repr.format(**data)
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+        )
+
+        r1, r2 = connection.execute(q).first()
+
+        resq = select(
+            literal_column(f"'{r1}'::{range_typ}+'{r2}'::{range_typ}", RANGE),
+        )
+
+        if err:
+            with expect_raises(DataError):
+                connection.execute(resq).scalar()
+            with expect_raises(ValueError):
+                r1.union(r2)
+        else:
+            union = connection.execute(resq).scalar()
+            eq_(
+                r1.union(r2),
+                union,
+                f"{r1}.union({r2}) != {union}",
+            )
+
 
 class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
     __requires__ = ("range_types",)