]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Introduce a _get_discrete_step() to handle (1,2] vs [2,3)
authorLele Gaifax <lele@metapensiero.it>
Wed, 26 Oct 2022 17:32:14 +0000 (19:32 +0200)
committerLele Gaifax <lele@metapensiero.it>
Wed, 26 Oct 2022 17:32:14 +0000 (19:32 +0200)
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_dialect.py

index 93ca39b5430777affb22fa18062b17d9a7ffbf95..c5c6c8f8d2f53c36d36812f0ee33e698ff24c655 100644 (file)
@@ -8,6 +8,8 @@
 from __future__ import annotations
 
 import dataclasses
+from datetime import date
+from datetime import timedelta
 from typing import Any
 from typing import Generic
 from typing import Optional
@@ -111,6 +113,18 @@ class Range(Generic[_T]):
             else value <= self.upper
         )
 
+    def _get_discrete_step(self):
+        "Determine the “step” for this range, if it is a discrete one."
+
+        # See
+        # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE
+        # for the rationale
+
+        if isinstance(self.lower, int) or isinstance(self.upper, int):
+            return 1
+        if isinstance(self.lower, date) or isinstance(self.upper, date):
+            return timedelta(days=1)
+
     def _contained_by(self, other: Range) -> bool:
         "Determine whether this range is a contained by `other`."
 
@@ -122,38 +136,63 @@ class Range(Generic[_T]):
         if other.empty:
             return False
 
+        slower = self.lower
+        slower_inc = self.bounds[0] == '['
+        supper = self.upper
+        supper_inc = self.bounds[1] == ']'
+        olower = other.lower
+        olower_inc = other.bounds[0] == '['
+        oupper = other.upper
+        oupper_inc = other.bounds[1] == ']'
+
         # A bilateral unbound range contains any other range
-        if other.lower is other.upper is None:
+        if olower is oupper is None:
             return True
 
         # A lower-bound range cannot contain a lower-unbound range
-        if self.lower is None and other.lower is not None:
+        if slower is None and olower is not None:
             return False
 
         # Likewise on the right side
-        if self.upper is None and other.upper is not None:
+        if supper is None and oupper is not None:
             return False
 
         # Check the lower end
-        if self.lower is not None and other.lower is not None:
-            lower_side = other.lower < self.lower
-            if not lower_side:
-                if self.bounds[0] == "(" or other.bounds[0] == "[":
-                    lower_side = other.lower == self.lower
-            if not lower_side:
+        if slower is not None and olower is not None:
+            lside = olower < slower
+            if not lside:
+                if not slower_inc or olower_inc:
+                    lside = olower == slower
+            if not lside:
+                step = self._get_discrete_step()
+                if step is not None:
+                    # Cover (1,x] vs [2,x) and (0,x] vs [1,x)
+                    if not slower_inc and olower_inc and slower < olower:
+                        lside = olower == (slower + step)
+                    elif slower_inc and not olower_inc and slower > olower:
+                        lside = (olower + step) == slower
+            if not lside:
                 return False
 
         # Lower end already considered, an upper-unbound range surely contains
         # this
-        if other.upper is None:
+        if oupper is None:
             return True
 
         # Check the upper end
-        upper_side = other.upper > self.upper
-        if not upper_side:
-            if self.bounds[1] == ")" or other.bounds[1] == "]":
-                upper_side = other.upper == self.upper
-        return upper_side
+        uside = oupper > supper
+        if not uside:
+            if not supper_inc or oupper_inc:
+                uside = oupper == supper
+            if not uside:
+                step = self._get_discrete_step()
+                if step is not None:
+                    # Cover (x,2] vs [x,3) and (x,1] vs [x,2)
+                    if supper_inc and not oupper_inc and supper < oupper:
+                        uside = oupper == (supper + step)
+                    elif not supper_inc and oupper_inc and supper > oupper:
+                        uside = (oupper + step) == supper
+        return uside
 
     def contains(self, value: Union[_T, Range]) -> bool:
         "Determine whether this range contains `value`."
index 3c36d3b787b7cc622b91023390509a125afe066e..4ca4847691fde08a49d4cc51583f7e192fd92070 100644 (file)
@@ -1235,6 +1235,7 @@ class TestRange(fixtures.TestBase):
     @testing.combinations(
         (Range(empty=True), "empty"),
         (Range(1, 2, bounds="(]"), "(1,2]"),
+        (Range(1, 2, bounds="[)"), "[1,2)"),
         (Range(None, None, bounds="()"), "(,)"),
         (Range(None, 1, bounds="[)"), "[,1)"),
         (Range(1, None, bounds="[)"), "[1,)"),
@@ -1255,6 +1256,7 @@ class TestRange(fixtures.TestBase):
         (Range(1, 4, bounds="()"), "(1,4)"),
         (Range(-4, 1, bounds="[)"), "[-4,1)"),
         (Range(2, 3, bounds="[)"), "[2,3)"),
+        (Range(0, 1, bounds="(]"), "(0,1]"),
         (Range(0, 6, bounds="[)"), "[0,6)"),
         argnames="r1,r1repr",
     )