]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implement contains_value(), issubset() and issuperset() on PG Range
authorLele Gaifax <lele@metapensiero.it>
Wed, 2 Nov 2022 12:33:41 +0000 (08:33 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 3 Nov 2022 13:30:38 +0000 (09:30 -0400)
Added new methods :meth:`_postgresql.Range.contains` and
:meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data
object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@``
operators, as well as the
:meth:`_postgresql.AbstractRange.comparator_factory.contains` and
:meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL
operator methods. Pull request courtesy Lele Gaifax.

Fixes: #8706
Closes: #8707
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8707
Pull-request-sha: 3a74a0d93e63032ebee02992977498c717a077ff

Change-Id: Ief81ca5c31448640b26dfbc3defd4dde1d51e366

doc/build/changelog/unreleased_20/8706.rst [new file with mode: 0644]
doc/build/changelog/whatsnew_20.rst
doc/build/dialects/postgresql.rst
lib/sqlalchemy/dialects/postgresql/__init__.py
lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

diff --git a/doc/build/changelog/unreleased_20/8706.rst b/doc/build/changelog/unreleased_20/8706.rst
new file mode 100644 (file)
index 0000000..a6f3321
--- /dev/null
@@ -0,0 +1,11 @@
+.. change::
+    :tags: feature, postgresql
+    :tickets: 8706
+
+    Added new methods :meth:`_postgresql.Range.contains` and
+    :meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data
+    object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@``
+    operators, as well as the
+    :meth:`_postgresql.AbstractRange.comparator_factory.contains` and
+    :meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL
+    operator methods. Pull request courtesy Lele Gaifax.
index 98865f233d270e9656528a7dd1d3f5343124bf17..3d4eca6b2b5a25e9b67353ce0b8699f7a84d60d2 100644 (file)
@@ -1883,6 +1883,12 @@ objects are used.
 Code that used the previous psycopg2-specific types should be modified
 to use :class:`_postgresql.Range`, which presents a compatible interface.
 
+The :class:`_postgresql.Range` object also features comparison support which
+mirrors that of PostgreSQL.  Implemented so far are :meth:`_postgresql.Range.contains`
+and :meth:`_postgresql.Range.contained_by` methods which work in the same way as
+the PostgreSQL ``@>`` and ``<@``.  Additional operator support may be added
+in future releases.
+
 See the documentation at :ref:`postgresql_ranges` for background on
 using the new feature.
 
@@ -1891,6 +1897,9 @@ using the new feature.
 
     :ref:`postgresql_ranges`
 
+:ticket:`7156`
+:ticket:`8706`
+
 .. _change_7086:
 
 ``match()`` operator on PostgreSQL uses ``plainto_tsquery()`` rather than ``to_tsquery()``
index 411037f87cf9648de8fda68732bcb3485355406a..ef853b68326ff5446cd299f2ab3b3008dbeda19a 100644 (file)
@@ -216,6 +216,7 @@ The available range datatypes are as follows:
 * :class:`_postgresql.TSTZRANGE`
 
 .. autoclass:: sqlalchemy.dialects.postgresql.Range
+    :members:
 
 Multiranges
 ^^^^^^^^^^^
@@ -350,6 +351,12 @@ construction arguments, are as follows:
 
 .. currentmodule:: sqlalchemy.dialects.postgresql
 
+.. autoclass:: sqlalchemy.dialects.postgresql.AbstractRange
+    :members: comparator_factory
+
+.. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange
+    :members: comparator_factory
+
 .. autoclass:: aggregate_order_by
 
 .. autoclass:: array
index 8dbee1f7f473d39e78ebfe41e26de494de50d226..7890541ffde05c3afdae4ad7d6bebea54ae0dbe9 100644 (file)
@@ -48,6 +48,8 @@ from .named_types import DropDomainType
 from .named_types import DropEnumType
 from .named_types import ENUM
 from .named_types import NamedType
+from .ranges import AbstractMultiRange
+from .ranges import AbstractRange
 from .ranges import DATEMULTIRANGE
 from .ranges import DATERANGE
 from .ranges import INT4MULTIRANGE
index 327feb4092b0e700c073e5d132957d7e7cb03582..6729f3785f538ee28d396e73aca6e94e1d64e433 100644 (file)
@@ -8,10 +8,14 @@
 from __future__ import annotations
 
 import dataclasses
+from datetime import date
+from datetime import datetime
+from datetime import timedelta
 from typing import Any
 from typing import Generic
 from typing import Optional
 from typing import TypeVar
+from typing import Union
 
 from ... import types as sqltypes
 from ...util import py310
@@ -80,6 +84,204 @@ class Range(Generic[_T]):
     def __bool__(self) -> bool:
         return self.empty
 
+    def _contains_value(self, value: _T) -> bool:
+        "Check whether this range contains the given `value`."
+
+        if self.empty:
+            return False
+
+        if self.lower is None:
+            return self.upper is None or (
+                value < self.upper
+                if self.bounds[1] == ")"
+                else value <= self.upper
+            )
+
+        if self.upper is None:
+            return (
+                value > self.lower
+                if self.bounds[0] == "("
+                else value >= self.lower
+            )
+
+        return (
+            value > self.lower
+            if self.bounds[0] == "("
+            else value >= self.lower
+        ) and (
+            value < self.upper
+            if self.bounds[1] == ")"
+            else value <= self.upper
+        )
+
+    def _get_discrete_step(self):
+        "Determine the “step” for this range, if it is a discrete one."
+
+        # See
+        # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE
+        # for the rationale
+
+        if isinstance(self.lower, int) or isinstance(self.upper, int):
+            return 1
+        elif isinstance(self.lower, datetime) or isinstance(
+            self.upper, datetime
+        ):
+            # This is required, because a `isinstance(datetime.now(), date)`
+            # is True
+            return None
+        elif isinstance(self.lower, date) or isinstance(self.upper, date):
+            return timedelta(days=1)
+        else:
+            return None
+
+    def contained_by(self, other: Range) -> bool:
+        "Determine whether this range is a contained by `other`."
+
+        # Any range contains the empty one
+        if self.empty:
+            return True
+
+        # An empty range does not contain any range except the empty one
+        if other.empty:
+            return False
+
+        olower = other.lower
+        oupper = other.upper
+
+        # A bilateral unbound range contains any other range
+        if olower is oupper is None:
+            return True
+
+        slower = self.lower
+        supper = self.upper
+
+        # A lower-bound range cannot contain a lower-unbound range
+        if slower is None and olower is not None:
+            return False
+
+        # Likewise on the right side
+        if supper is None and oupper is not None:
+            return False
+
+        slower_inc = self.bounds[0] == "["
+        supper_inc = self.bounds[1] == "]"
+        olower_inc = other.bounds[0] == "["
+        oupper_inc = other.bounds[1] == "]"
+
+        # Check the lower end
+        step = -1
+        if slower is not None and olower is not None:
+            lside = olower < slower
+            if not lside:
+                if not slower_inc or olower_inc:
+                    lside = olower == slower
+            if not lside:
+                # Cover (1,x] vs [2,x) and (0,x] vs [1,x)
+                if not slower_inc and olower_inc and slower < olower:
+                    step = self._get_discrete_step()
+                    if step is not None:
+                        lside = olower == (slower + step)
+                elif slower_inc and not olower_inc and slower > olower:
+                    step = self._get_discrete_step()
+                    if step is not None:
+                        lside = (olower + step) == slower
+            if not lside:
+                return False
+
+        # Lower end already considered, an upper-unbound range surely contains
+        # this
+        if oupper is None:
+            return True
+
+        # Check the upper end
+        uside = oupper > supper
+        if not uside:
+            if not supper_inc or oupper_inc:
+                uside = oupper == supper
+            if not uside:
+                # Cover (x,2] vs [x,3) and (x,1] vs [x,2)
+                if supper_inc and not oupper_inc and supper < oupper:
+                    if step == -1:
+                        step = self._get_discrete_step()
+                    if step is not None:
+                        uside = oupper == (supper + step)
+                elif not supper_inc and oupper_inc and supper > oupper:
+                    if step == -1:
+                        step = self._get_discrete_step()
+                    if step is not None:
+                        uside = (oupper + step) == supper
+        return uside
+
+    def contains(self, value: Union[_T, Range]) -> bool:
+        "Determine whether this range contains `value`."
+
+        if isinstance(value, Range):
+            return value.contained_by(self)
+        else:
+            return self._contains_value(value)
+
+    def overlaps(self, other):
+        """Boolean expression. Returns true if the column overlaps
+        (has points in common with) the right hand operand.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    def strictly_left_of(self, other):
+        """Boolean expression. Returns true if the column is strictly
+        left of the right hand operand.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    __lshift__ = strictly_left_of
+
+    def strictly_right_of(self, other):
+        """Boolean expression. Returns true if the column is strictly
+        right of the right hand operand.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    __rshift__ = strictly_right_of
+
+    def not_extend_right_of(self, other):
+        """Boolean expression. Returns true if the range in the column
+        does not extend right of the range in the operand.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    def not_extend_left_of(self, other):
+        """Boolean expression. Returns true if the range in the column
+        does not extend left of the range in the operand.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    def adjacent_to(self, other):
+        """Boolean expression. Returns true if the range in the column
+        is adjacent to the range in the operand.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    def __add__(self, other):
+        """Range expression. Returns the union of the two ranges.
+        Will raise an exception if the resulting range is not
+        contiguous.
+        """
+        raise NotImplementedError("not yet implemented")
+
+    def __str__(self):
+        return self._stringify()
+
+    def _stringify(self):
+        if self.empty:
+            return "empty"
+
+        l, r = self.lower, self.upper
+        l = "" if l is None else l
+        r = "" if r is None else r
+
+        b0, b1 = self.bounds
+
+        return f"{b0}{l},{r}{b1}"
+
 
 class AbstractRange(sqltypes.TypeEngine):
     """
@@ -93,6 +295,8 @@ class AbstractRange(sqltypes.TypeEngine):
 
     render_bind_cast = True
 
+    __abstract__ = True
+
     def adapt(self, impltype):
         """dynamically adapt a range type to an abstract impl.
 
@@ -202,6 +406,8 @@ class AbstractRangeImpl(AbstractRange):
 class AbstractMultiRange(AbstractRange):
     """base for PostgreSQL MULTIRANGE types"""
 
+    __abstract__ = True
+
 
 class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
     """marker for AbstractRange that will apply a subclass-specific
index 91eada9a81d66724134abac73d46e353997c8ec3..83cea8f15a2f5c79ac3b27c13971ab3c9f315273 100644 (file)
@@ -20,6 +20,7 @@ from sqlalchemy import func
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import literal
+from sqlalchemy import literal_column
 from sqlalchemy import MetaData
 from sqlalchemy import null
 from sqlalchemy import Numeric
@@ -66,6 +67,8 @@ from sqlalchemy.sql import operators
 from sqlalchemy.sql import sqltypes
 from sqlalchemy.testing import expect_raises_message
 from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
 from sqlalchemy.testing.assertions import assert_raises
 from sqlalchemy.testing.assertions import assert_raises_message
 from sqlalchemy.testing.assertions import AssertsCompiledSQL
@@ -3846,7 +3849,194 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
         )
 
 
-class _RangeTypeRoundTrip(fixtures.TablesTest):
+class _RangeComparisonFixtures:
+    def _data_str(self):
+        """return string form of a sample range"""
+        raise NotImplementedError()
+
+    def _data_obj(self):
+        """return Range form of the same range"""
+        raise NotImplementedError()
+
+    def _step_value_up(self, value):
+        """given a value, return a step up
+
+        this is a value that given the lower end of the sample range,
+        would be less than the upper value of the range
+
+        """
+        raise NotImplementedError()
+
+    def _step_value_down(self, value):
+        """given a value, return a step down
+
+        this is a value that given the upper end of the sample range,
+        would be greater than the lower value of the range
+
+        """
+        raise NotImplementedError()
+
+    def _value_values(self):
+        """Return a series of values related to the base range
+
+        le = left equal
+        ll = lower than left
+        re = right equal
+        rh = higher than right
+        il = inside lower
+        ih = inside higher
+
+        """
+        spec = self._data_obj()
+
+        le, re_ = spec.lower, spec.upper
+
+        ll = self._step_value_down(le)
+        il = self._step_value_up(le)
+        rh = self._step_value_up(re_)
+        ih = self._step_value_down(re_)
+
+        return {"le": le, "re_": re_, "ll": ll, "il": il, "rh": rh, "ih": ih}
+
+    @testing.fixture(
+        params=[
+            lambda **kw: Range(empty=True),
+            lambda **kw: Range(bounds="[)"),
+            lambda le, **kw: Range(upper=le, bounds="[)"),
+            lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+            lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+            lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[]"),
+            lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="(]"),
+            lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="()"),
+            lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="[)"),
+            lambda il, ih, **kw: Range(lower=il, upper=ih, bounds="[)"),
+            lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="(]"),
+            lambda ll, rh, **kw: Range(lower=ll, upper=rh, bounds="[)"),
+        ]
+    )
+    def contains_range_obj_combinations(self, request):
+        """ranges that are used for range contains() contained_by() tests"""
+        data = self._value_values()
+
+        range_ = request.param(**data)
+        yield range_
+
+    @testing.fixture(
+        params=[
+            lambda l, r: Range(empty=True),
+            lambda l, r: Range(bounds="()"),
+            lambda l, r: Range(upper=r, bounds="(]"),
+            lambda l, r: Range(lower=l, bounds="[)"),
+            lambda l, r: Range(lower=l, upper=r, bounds="[)"),
+            lambda l, r: Range(lower=l, upper=r, bounds="[]"),
+            lambda l, r: Range(lower=l, upper=r, bounds="(]"),
+            lambda l, r: Range(lower=l, upper=r, bounds="()"),
+        ]
+    )
+    def bounds_obj_combinations(self, request):
+        """sample ranges used for value and range contains()/contained_by()
+        tests"""
+
+        obj = self._data_obj()
+        l, r = obj.lower, obj.upper
+
+        template = request.param
+        value = template(l=l, r=r)
+        yield value
+
+    @testing.fixture(params=["ll", "le", "il", "ih", "re_", "rh"])
+    def value_combinations(self, request):
+        """sample values used for value contains() tests"""
+        data = self._value_values()
+        return data[request.param]
+
+    def test_basic_py_sanity(self):
+        values = self._value_values()
+
+        range_ = self._data_obj()
+
+        is_true(range_.contains(Range(lower=values["il"], upper=values["ih"])))
+
+        is_true(
+            range_.contained_by(Range(lower=values["ll"], upper=values["rh"]))
+        )
+
+        is_true(range_.contains(values["il"]))
+
+        is_false(
+            range_.contains(Range(lower=values["ll"], upper=values["ih"]))
+        )
+
+        is_false(range_.contains(values["rh"]))
+
+    def test_contains_value(
+        self, connection, bounds_obj_combinations, value_combinations
+    ):
+        range_ = bounds_obj_combinations
+        range_typ = self._col_str
+
+        strvalue = range_._stringify()
+
+        v = value_combinations
+        RANGE = self._col_type
+
+        q = select(
+            literal_column(f"'{strvalue}'::{range_typ}", RANGE).label("r1"),
+            cast(range_, RANGE).label("r2"),
+        )
+        literal_range, cast_range = connection.execute(q).first()
+        eq_(literal_range, cast_range)
+
+        q = select(
+            cast(range_, RANGE),
+            cast(range_, RANGE).contains(v),
+        )
+        r, expected = connection.execute(q).first()
+        eq_(r.contains(v), expected)
+
+    def test_contains_range(
+        self,
+        connection,
+        bounds_obj_combinations,
+        contains_range_obj_combinations,
+    ):
+        r1repr = contains_range_obj_combinations._stringify()
+        r2repr = bounds_obj_combinations._stringify()
+
+        RANGE = self._col_type
+        range_typ = self._col_str
+
+        q = select(
+            cast(contains_range_obj_combinations, RANGE).label("r1"),
+            cast(bounds_obj_combinations, RANGE).label("r2"),
+            cast(contains_range_obj_combinations, RANGE).contains(
+                bounds_obj_combinations
+            ),
+            cast(contains_range_obj_combinations, RANGE).contained_by(
+                bounds_obj_combinations
+            ),
+        )
+        validate_q = select(
+            literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+            literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+            literal_column(
+                f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}"
+            ),
+            literal_column(
+                f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}"
+            ),
+        )
+        orig_row = connection.execute(q).first()
+        validate_row = connection.execute(validate_q).first()
+        eq_(orig_row, validate_row)
+
+        r1, r2, contains, contained = orig_row
+        eq_(r1.contains(r2), contains)
+        eq_(r1.contained_by(r2), contained)
+        eq_(r2.contains(r1), contained)
+
+
+class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
     __requires__ = ("range_types",)
     __backend__ = True
 
@@ -3861,6 +4051,9 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
         )
         cls.col = table.c.range
 
+    def test_stringify(self):
+        eq_(str(self._data_obj()), self._data_str())
+
     def test_auto_cast_back_to_type(self, connection):
         """test that a straight pass of the range type without any context
         will send appropriate casting info so that the driver can round
@@ -3995,10 +4188,16 @@ class _Int4RangeTests:
     _col_str = "INT4RANGE"
 
     def _data_str(self):
-        return "[1,2)"
+        return "[1,4)"
 
     def _data_obj(self):
-        return Range(1, 2)
+        return Range(1, 4)
+
+    def _step_value_up(self, value):
+        return value + 1
+
+    def _step_value_down(self, value):
+        return value - 1
 
 
 class _Int8RangeTests:
@@ -4007,10 +4206,16 @@ class _Int8RangeTests:
     _col_str = "INT8RANGE"
 
     def _data_str(self):
-        return "[9223372036854775806,9223372036854775807)"
+        return "[9223372036854775306,9223372036854775800)"
 
     def _data_obj(self):
-        return Range(9223372036854775806, 9223372036854775807)
+        return Range(9223372036854775306, 9223372036854775800)
+
+    def _step_value_up(self, value):
+        return value + 5
+
+    def _step_value_down(self, value):
+        return value - 5
 
 
 class _NumRangeTests:
@@ -4019,10 +4224,16 @@ class _NumRangeTests:
     _col_str = "NUMRANGE"
 
     def _data_str(self):
-        return "[1.0,2.0)"
+        return "[1.0,9.0)"
 
     def _data_obj(self):
-        return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0"))
+        return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0"))
+
+    def _step_value_up(self, value):
+        return value + decimal.Decimal("1.8")
+
+    def _step_value_down(self, value):
+        return value - decimal.Decimal("1.8")
 
 
 class _DateRangeTests:
@@ -4031,10 +4242,16 @@ class _DateRangeTests:
     _col_str = "DATERANGE"
 
     def _data_str(self):
-        return "[2013-03-23,2013-03-24)"
+        return "[2013-03-23,2013-03-30)"
 
     def _data_obj(self):
-        return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24))
+        return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30))
+
+    def _step_value_up(self, value):
+        return value + datetime.timedelta(days=1)
+
+    def _step_value_down(self, value):
+        return value - datetime.timedelta(days=1)
 
 
 class _DateTimeRangeTests:
@@ -4043,38 +4260,47 @@ class _DateTimeRangeTests:
     _col_str = "TSRANGE"
 
     def _data_str(self):
-        return "[2013-03-23 14:30,2013-03-23 23:30)"
+        return "[2013-03-23 14:30:00,2013-03-30 23:30:00)"
 
     def _data_obj(self):
         return Range(
             datetime.datetime(2013, 3, 23, 14, 30),
-            datetime.datetime(2013, 3, 23, 23, 30),
+            datetime.datetime(2013, 3, 30, 23, 30),
         )
 
+    def _step_value_up(self, value):
+        return value + datetime.timedelta(days=1)
+
+    def _step_value_down(self, value):
+        return value - datetime.timedelta(days=1)
+
 
 class _DateTimeTZRangeTests:
 
     _col_type = TSTZRANGE
     _col_str = "TSTZRANGE"
 
-    # make sure we use one, steady timestamp with timezone pair
-    # for all parts of all these tests
-    _tstzs = None
-
     def tstzs(self):
-        if self._tstzs is None:
-            with testing.db.connect() as connection:
-                lower = connection.scalar(func.current_timestamp().select())
-                upper = lower + datetime.timedelta(1)
-                self._tstzs = (lower, upper)
-        return self._tstzs
+        tz = datetime.timezone(-datetime.timedelta(hours=5, minutes=30))
+
+        return (
+            datetime.datetime(2013, 3, 23, 14, 30, tzinfo=tz),
+            datetime.datetime(2013, 3, 30, 23, 30, tzinfo=tz),
+        )
 
     def _data_str(self):
-        return "[%s,%s)" % self.tstzs()
+        l, r = self.tstzs()
+        return f"[{l},{r})"
 
     def _data_obj(self):
         return Range(*self.tstzs())
 
+    def _step_value_up(self, value):
+        return value + datetime.timedelta(days=1)
+
+    def _step_value_down(self, value):
+        return value - datetime.timedelta(days=1)
+
 
 class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation):
     pass