From: Lele Gaifax Date: Mon, 7 Nov 2022 18:29:28 +0000 (+0100) Subject: Implement and test Range.union(), aliased to __add__ X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=19582cdcf1cd06422c89c33c21e0038b694e25fb;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement and test Range.union(), aliased to __add__ --- diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 5631fea220..0822e1b653 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -406,12 +406,51 @@ class Range(Generic[_T]): oupper, oupper_b, slower, slower_b ) - def __add__(self, other): - """Range expression. Returns the union of the two ranges. - Will raise an exception if the resulting range is not - contiguous. + def union(self, other: Range) -> Range: + """Compute the union of this range with the `other`. + + This raises a ``ValueError`` exception if the two ranges are + "disjunct", that is neither adjacent nor overlapping. """ - raise NotImplementedError("not yet implemented") + + # Empty ranges are "additive identities" + if self.empty: + return other + if other.empty: + return self + + if not self.overlaps(other) and not self.adjacent_to(other): + raise ValueError( + "Adding non-overlapping and non-adjacent" + " ranges is not implemented" + ) + + slower = self.lower + slower_b = self.bounds[0] + supper = self.upper + supper_b = self.bounds[1] + olower = other.lower + olower_b = other.bounds[0] + oupper = other.upper + oupper_b = other.bounds[1] + + if self._compare_edges(slower, slower_b, olower, olower_b) < 0: + rlower = slower + rlower_b = slower_b + else: + rlower = olower + rlower_b = olower_b + + if self._compare_edges(supper, supper_b, oupper, oupper_b) > 0: + rupper = supper + rupper_b = supper_b + else: + rupper = oupper + rupper_b = oupper_b + + return Range(rlower, rupper, bounds=rlower_b + rupper_b) + + __add__ = union def __str__(self): return self._stringify() diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index e66a159955..b847ef02d9 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -60,11 +60,13 @@ from sqlalchemy.dialects.postgresql import TSRANGE from sqlalchemy.dialects.postgresql import TSTZMULTIRANGE from sqlalchemy.dialects.postgresql import TSTZRANGE from sqlalchemy.exc import CompileError +from sqlalchemy.exc import DataError from sqlalchemy.orm import declarative_base from sqlalchemy.orm import Session from sqlalchemy.sql import bindparam from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_false @@ -4213,6 +4215,51 @@ class _RangeComparisonFixtures: f"{r1}.adjacent_to({r2}) != {adjacent}", ) + @testing.combinations( + ("empty", "empty", False), + ("empty", "[{le},{re_}]", False), + ("[{le},{re_}]", "empty", False), + ("[{le},{re_}]", "[{le},{re_}]", False), + ("[{ll},{rh})", "[{le},{re_}]", False), + ("[{ll},{ll}]", "({le},{rh}]", True), + ("[{ll},{ll}]", "[{rh},{rh}]", True), + argnames="r1repr,r2repr,err", + ) + def test_union(self, connection, r1repr, r2repr, err): + data = self._value_values() + + if r1repr != "empty": + r1repr = r1repr.format(**data) + if r2repr != "empty": + r2repr = r2repr.format(**data) + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + ) + + r1, r2 = connection.execute(q).first() + + resq = select( + literal_column(f"'{r1}'::{range_typ}+'{r2}'::{range_typ}", RANGE), + ) + + if err: + with expect_raises(DataError): + connection.execute(resq).scalar() + with expect_raises(ValueError): + r1.union(r2) + else: + union = connection.execute(resq).scalar() + eq_( + r1.union(r2), + union, + f"{r1}.union({r2}) != {union}", + ) + class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): __requires__ = ("range_types",)