From 90561265fbbb3d90a7a60180b78d750dac04f9b8 Mon Sep 17 00:00:00 2001 From: Lele Gaifax Date: Wed, 26 Oct 2022 19:32:14 +0200 Subject: [PATCH] Introduce a _get_discrete_step() to handle (1,2] vs [2,3) --- lib/sqlalchemy/dialects/postgresql/ranges.py | 69 +++++++++++++++----- test/dialect/postgresql/test_dialect.py | 2 + 2 files changed, 56 insertions(+), 15 deletions(-) diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 93ca39b543..c5c6c8f8d2 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -8,6 +8,8 @@ from __future__ import annotations import dataclasses +from datetime import date +from datetime import timedelta from typing import Any from typing import Generic from typing import Optional @@ -111,6 +113,18 @@ class Range(Generic[_T]): else value <= self.upper ) + def _get_discrete_step(self): + "Determine the “step” for this range, if it is a discrete one." + + # See + # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE + # for the rationale + + if isinstance(self.lower, int) or isinstance(self.upper, int): + return 1 + if isinstance(self.lower, date) or isinstance(self.upper, date): + return timedelta(days=1) + def _contained_by(self, other: Range) -> bool: "Determine whether this range is a contained by `other`." @@ -122,38 +136,63 @@ class Range(Generic[_T]): if other.empty: return False + slower = self.lower + slower_inc = self.bounds[0] == '[' + supper = self.upper + supper_inc = self.bounds[1] == ']' + olower = other.lower + olower_inc = other.bounds[0] == '[' + oupper = other.upper + oupper_inc = other.bounds[1] == ']' + # A bilateral unbound range contains any other range - if other.lower is other.upper is None: + if olower is oupper is None: return True # A lower-bound range cannot contain a lower-unbound range - if self.lower is None and other.lower is not None: + if slower is None and olower is not None: return False # Likewise on the right side - if self.upper is None and other.upper is not None: + if supper is None and oupper is not None: return False # Check the lower end - if self.lower is not None and other.lower is not None: - lower_side = other.lower < self.lower - if not lower_side: - if self.bounds[0] == "(" or other.bounds[0] == "[": - lower_side = other.lower == self.lower - if not lower_side: + if slower is not None and olower is not None: + lside = olower < slower + if not lside: + if not slower_inc or olower_inc: + lside = olower == slower + if not lside: + step = self._get_discrete_step() + if step is not None: + # Cover (1,x] vs [2,x) and (0,x] vs [1,x) + if not slower_inc and olower_inc and slower < olower: + lside = olower == (slower + step) + elif slower_inc and not olower_inc and slower > olower: + lside = (olower + step) == slower + if not lside: return False # Lower end already considered, an upper-unbound range surely contains # this - if other.upper is None: + if oupper is None: return True # Check the upper end - upper_side = other.upper > self.upper - if not upper_side: - if self.bounds[1] == ")" or other.bounds[1] == "]": - upper_side = other.upper == self.upper - return upper_side + uside = oupper > supper + if not uside: + if not supper_inc or oupper_inc: + uside = oupper == supper + if not uside: + step = self._get_discrete_step() + if step is not None: + # Cover (x,2] vs [x,3) and (x,1] vs [x,2) + if supper_inc and not oupper_inc and supper < oupper: + uside = oupper == (supper + step) + elif not supper_inc and oupper_inc and supper > oupper: + uside = (oupper + step) == supper + return uside def contains(self, value: Union[_T, Range]) -> bool: "Determine whether this range contains `value`." diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 3c36d3b787..4ca4847691 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -1235,6 +1235,7 @@ class TestRange(fixtures.TestBase): @testing.combinations( (Range(empty=True), "empty"), (Range(1, 2, bounds="(]"), "(1,2]"), + (Range(1, 2, bounds="[)"), "[1,2)"), (Range(None, None, bounds="()"), "(,)"), (Range(None, 1, bounds="[)"), "[,1)"), (Range(1, None, bounds="[)"), "[1,)"), @@ -1255,6 +1256,7 @@ class TestRange(fixtures.TestBase): (Range(1, 4, bounds="()"), "(1,4)"), (Range(-4, 1, bounds="[)"), "[-4,1)"), (Range(2, 3, bounds="[)"), "[2,3)"), + (Range(0, 1, bounds="(]"), "(0,1]"), (Range(0, 6, bounds="[)"), "[0,6)"), argnames="r1,r1repr", ) -- 2.47.3