]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement contains_value(), issubset() and issuperset() on PG Range
authorLele Gaifax <lele@metapensiero.it>
Mon, 24 Oct 2022 16:30:13 +0000 (18:30 +0200)
committerLele Gaifax <lele@metapensiero.it>
Mon, 24 Oct 2022 19:18:33 +0000 (21:18 +0200)
Fixes issue #8706.

lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_dialect.py

index 327feb4092b0e700c073e5d132957d7e7cb03582..dd4383b8b1c6f9b34fba36f24654df15d92739bc 100644 (file)
@@ -80,6 +80,77 @@ class Range(Generic[_T]):
     def __bool__(self) -> bool:
         return self.empty
 
+    def contains_value(self, value: T) -> bool:
+        "Check whether this range contains the given `value`."
+
+        if self.empty:
+            return False
+
+        if self.lower is None:
+            if self.upper is None:
+                return True
+            return (self.upper is None or
+                    (value < self.upper if self.bounds[1] == ")"
+                     else value <= self.upper))
+
+        if self.upper is None:
+            return (value > self.lower if self.bounds[0] == "("
+                    else value >= self.lower)
+
+        return ((value > self.lower if self.bounds[0] == "("
+                 else value >= self.lower)
+                and
+                (value < self.upper if self.bounds[1] == ")"
+                 else value <= self.upper))
+
+    def issubset(self, other) -> bool:
+        "Determine whether this range is a contained by `other`."
+
+        # Any range contains the empty one
+        if self.empty:
+            return True
+
+        # An empty range does not contain any range except the empty one
+        if other.empty:
+            return False
+
+        # A bilateral unbound range contains any other range
+        if other.lower is other.upper is None:
+            return True
+
+        # A lower-bound range cannot contain a lower-unbound range
+        if self.lower is None and other.lower is not None:
+            return False
+
+        # Likewise on the right side
+        if self.upper is None and other.upper is not None:
+            return False
+
+        # Check the lower end
+        if self.lower is not None and other.lower is not None:
+            lower_side = other.lower < self.lower
+            if not lower_side:
+                if self.bounds[0] == '(' or other.bounds[0] == '[':
+                    lower_side = other.lower == self.lower
+            if not lower_side:
+                return False
+
+        # Lower end already considered, an upper-unbound range surely contains this
+        if other.upper is None:
+            return True
+
+        # Check the upper end
+        upper_side = other.upper > self.upper
+        if not upper_side:
+            if self.bounds[1] == ')' or other.bounds[1] == ']':
+                upper_side = other.upper == self.upper
+        return upper_side
+
+    def issuperset(self, other) -> bool:
+        "Determine whether this range contains `other`."
+
+        return other.issubset(self)
+
 
 class AbstractRange(sqltypes.TypeEngine):
     """
index 27d4a4cf99366f0125a932d7406fb5afc1c03c33..cc8df8e010577a446502b39f8aa32260e3a20f00 100644 (file)
@@ -1210,3 +1210,56 @@ class Psycopg3Test(fixtures.TestBase):
     def test_async_version(self):
         e = create_engine("postgresql+psycopg_async://")
         is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg))
+
+
+class TestRange(fixtures.TestBase):
+    __only_on__ = "postgresql"
+    __backend__ = True
+
+    @testing.combinations(0, 1, 2, 4, 5, argnames="v")
+    @testing.combinations(
+        (Range(empty=True), 'empty'),
+        (Range(None, None, bounds='()'), '(,)'),
+        (Range(None, 4, bounds='(]'), '(,4]'),
+        (Range(1, None, bounds='[)'), '[1,)'),
+        (Range(1, 4, bounds='[)'), '[1,4)'),
+        (Range(1, 4, bounds='[]'), '[1,4]'),
+        (Range(1, 4, bounds='(]'), '(1,4]'),
+        (Range(1, 4, bounds='()'), '(1,4)'),
+        argnames="r,rrepr",
+    )
+    def test_range_contains_value(self, connection, r, rrepr, v):
+        q = text(f"select {v} <@ '{rrepr}'::int4range")
+        eq_(r.contains_value(v), connection.scalar(q))
+
+    @testing.combinations(
+        (Range(empty=True), 'empty'),
+        (Range(1, 2, bounds='()'), '(1,2)'),
+        (Range(None, None, bounds='()'), '(,)'),
+        (Range(None, 1, bounds='[)'), '[,1)'),
+        (Range(1, None, bounds='[)'), '[1,)'),
+        (Range(1, 4, bounds='[)'), '[1,4)'),
+        (Range(1, 4, bounds='[]'), '[1,4]'),
+        (Range(1, 4, bounds='(]'), '(1,4]'),
+        (Range(1, 4, bounds='()'), '(1,4)'),
+        argnames="r2,r2repr",
+    )
+    @testing.combinations(
+        (Range(empty=True), 'empty'),
+        (Range(None, None, bounds='[)'), '[,)'),
+        (Range(None, 1, bounds='[)'), '[,1)'),
+        (Range(1, None, bounds='[)'), '[1,)'),
+        (Range(1, 4, bounds='[)'), '[1,4)'),
+        (Range(1, 4, bounds='[]'), '[1,4]'),
+        (Range(1, 4, bounds='(]'), '(1,4]'),
+        (Range(1, 4, bounds='()'), '(1,4)'),
+        (Range(-4, 1, bounds='[)'), '[-4,1)'),
+        (Range(2, 3, bounds='[)'), '[2,3)'),
+        (Range(0, 6, bounds='[)'), '[0,6)'),
+        argnames="r1,r1repr",
+    )
+    def test_is_sub_super_set(self, connection, r1, r1repr, r2, r2repr):
+        q = text(f"select '{r1repr}'::int4range <@ '{r2repr}'::int4range")
+        eq_(r1.issubset(r2), connection.scalar(q))
+        q = text(f"select '{r2repr}'::int4range @> '{r1repr}'::int4range")
+        eq_(r2.issuperset(r1), connection.scalar(q))