]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add compatibility properties to Range; implement pep-484
authorLele Gaifax <lele@metapensiero.it>
Mon, 5 Dec 2022 01:51:13 +0000 (20:51 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 5 Dec 2022 16:51:26 +0000 (11:51 -0500)
This adds a bunch of properties to new PG Range class for compatibility
with other implementations, providing a more similar API to access
emptiness and bounds status.

The naming conventions here derive from PostgreSQL itself,
see https://www.postgresql.org/docs/9.3/functions-range.html .

pep-484 also implemented by Mike, as this is a pretty type-intensive
module.

Co-authored-by: Mike Bayer <mike_mp@zzzcomputing.com>
Closes: #8927
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8927
Pull-request-sha: 8b9e7b7e3345673b43aeabd7ec88b88dc3cfa7eb

Change-Id: I0b1d49311517ee1cc1377a974ed0a860ea5756e4

lib/sqlalchemy/dialects/postgresql/ranges.py
test/dialect/postgresql/test_types.py

index e772777bf31c5090b065b93d639f3272dce8abdf..9b5834ccd80a2b017442cc871f3107a5d56b4afa 100644 (file)
@@ -3,7 +3,6 @@
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
 
 from __future__ import annotations
 
@@ -13,8 +12,13 @@ from datetime import datetime
 from datetime import timedelta
 from decimal import Decimal
 from typing import Any
+from typing import cast
 from typing import Generic
 from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
 from typing import TypeVar
 from typing import Union
 
@@ -22,8 +26,15 @@ from ... import types as sqltypes
 from ...util import py310
 from ...util.typing import Literal
 
+if TYPE_CHECKING:
+    from ...sql.elements import ColumnElement
+    from ...sql.type_api import _TE
+    from ...sql.type_api import TypeEngine
+    from ...sql.type_api import TypeEngineMixin
+
 _T = TypeVar("_T", bound=Any)
 
+_BoundsType = Literal["()", "[)", "(]", "[]"]
 
 if py310:
     dc_slots = {"slots": True}
@@ -62,15 +73,22 @@ class Range(Generic[_T]):
     upper: Optional[_T] = None
     """the upper bound"""
 
-    bounds: Literal["()", "[)", "(]", "[]"] = dataclasses.field(
-        default="[)", **dc_kwonly
-    )
-    empty: bool = dataclasses.field(default=False, **dc_kwonly)
+    if TYPE_CHECKING:
+        bounds: _BoundsType = dataclasses.field(default="[)")
+        empty: bool = dataclasses.field(default=False)
+    else:
+        bounds: _BoundsType = dataclasses.field(default="[)", **dc_kwonly)
+        empty: bool = dataclasses.field(default=False, **dc_kwonly)
 
     if not py310:
 
         def __init__(
-            self, lower=None, upper=None, *, bounds="[)", empty=False
+            self,
+            lower: Optional[_T] = None,
+            upper: Optional[_T] = None,
+            *,
+            bounds: _BoundsType = "[)",
+            empty: bool = False,
         ):
             # no __slots__ either so we can update dict
             self.__dict__.update(
@@ -86,11 +104,49 @@ class Range(Generic[_T]):
         return not self.empty
 
     @property
-    def __sa_type_engine__(self):
+    def isempty(self) -> bool:
+        "A synonym for the 'empty' attribute."
+
+        return self.empty
+
+    @property
+    def is_empty(self) -> bool:
+        "A synonym for the 'empty' attribute."
+
+        return self.empty
+
+    @property
+    def lower_inc(self) -> bool:
+        """Return True if the lower bound is inclusive."""
+
+        return self.bounds[0] == "["
+
+    @property
+    def lower_inf(self) -> bool:
+        """Return True if this range is non-empty and lower bound is
+        infinite."""
+
+        return not self.empty and self.lower is None
+
+    @property
+    def upper_inc(self) -> bool:
+        """Return True if the upper bound is inclusive."""
+
+        return self.bounds[1] == "]"
+
+    @property
+    def upper_inf(self) -> bool:
+        """Return True if this range is non-empty and the upper bound is
+        infinite."""
+
+        return not self.empty and self.upper is None
+
+    @property
+    def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
         return AbstractRange()
 
     def _contains_value(self, value: _T) -> bool:
-        "Check whether this range contains the given `value`."
+        """return True if this range contains the given value."""
 
         if self.empty:
             return False
@@ -103,13 +159,13 @@ class Range(Generic[_T]):
             )
 
         if self.upper is None:
-            return (
+            return (  # type: ignore
                 value > self.lower
                 if self.bounds[0] == "("
                 else value >= self.lower
             )
 
-        return (
+        return (  # type: ignore
             value > self.lower
             if self.bounds[0] == "("
             else value >= self.lower
@@ -119,7 +175,7 @@ class Range(Generic[_T]):
             else value <= self.upper
         )
 
-    def _get_discrete_step(self):
+    def _get_discrete_step(self) -> Any:
         "Determine the “step” for this range, if it is a discrete one."
 
         # See
@@ -203,9 +259,9 @@ class Range(Generic[_T]):
                     value2 += step
                     value2_inc = False
 
-        if value1 < value2:
+        if value1 < value2:  # type: ignore
             return -1
-        elif value1 > value2:
+        elif value1 > value2:  # type: ignore
             return 1
         elif only_values:
             return 0
@@ -228,7 +284,7 @@ class Range(Generic[_T]):
             else:
                 return 0
 
-    def __eq__(self, other: Range) -> bool:
+    def __eq__(self, other: Any) -> bool:  # type: ignore[override]  # noqa: E501
         """Compare this range to the `other` taking into account
         bounds inclusivity, returning ``True`` if they are equal.
         """
@@ -252,7 +308,7 @@ class Range(Generic[_T]):
             and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0
         )
 
-    def contained_by(self, other: Range) -> bool:
+    def contained_by(self, other: Range[_T]) -> bool:
         "Determine whether this range is a contained by `other`."
 
         # Any range contains the empty one
@@ -281,7 +337,7 @@ class Range(Generic[_T]):
 
         return True
 
-    def contains(self, value: Union[_T, Range]) -> bool:
+    def contains(self, value: Union[_T, Range[_T]]) -> bool:
         "Determine whether this range contains `value`."
 
         if isinstance(value, Range):
@@ -289,7 +345,7 @@ class Range(Generic[_T]):
         else:
             return self._contains_value(value)
 
-    def overlaps(self, other: Range) -> bool:
+    def overlaps(self, other: Range[_T]) -> bool:
         "Determine whether this range overlaps with `other`."
 
         # Empty ranges never overlap with any other range
@@ -321,7 +377,7 @@ class Range(Generic[_T]):
 
         return False
 
-    def strictly_left_of(self, other: Range) -> bool:
+    def strictly_left_of(self, other: Range[_T]) -> bool:
         "Determine whether this range is completely to the left of `other`."
 
         # Empty ranges are neither to left nor to the right of any other range
@@ -338,7 +394,7 @@ class Range(Generic[_T]):
 
     __lshift__ = strictly_left_of
 
-    def strictly_right_of(self, other: Range) -> bool:
+    def strictly_right_of(self, other: Range[_T]) -> bool:
         "Determine whether this range is completely to the right of `other`."
 
         # Empty ranges are neither to left nor to the right of any other range
@@ -355,7 +411,7 @@ class Range(Generic[_T]):
 
     __rshift__ = strictly_right_of
 
-    def not_extend_left_of(self, other: Range) -> bool:
+    def not_extend_left_of(self, other: Range[_T]) -> bool:
         "Determine whether this does not extend to the left of `other`."
 
         # Empty ranges are neither to left nor to the right of any other range
@@ -370,7 +426,7 @@ class Range(Generic[_T]):
         # Check whether this lower edge is not less than other's lower end
         return self._compare_edges(slower, slower_b, olower, olower_b) >= 0
 
-    def not_extend_right_of(self, other: Range) -> bool:
+    def not_extend_right_of(self, other: Range[_T]) -> bool:
         "Determine whether this does not extend to the right of `other`."
 
         # Empty ranges are neither to left nor to the right of any other range
@@ -404,14 +460,14 @@ class Range(Generic[_T]):
                 return False
             if bound1 == "]":
                 if bound2 == "[":
-                    return value1 == value2 - step
+                    return value1 == value2 - step  # type: ignore
                 else:
                     return value1 == value2
             else:
                 if bound2 == "[":
                     return value1 == value2
                 else:
-                    return value1 == value2 - step
+                    return value1 == value2 - step  # type: ignore
         elif res == 0:
             # Cover cases like [0,0] -|- [1,] and [0,2) -|- (1,3]
             if (
@@ -432,7 +488,7 @@ class Range(Generic[_T]):
         else:
             return False
 
-    def adjacent_to(self, other: Range) -> bool:
+    def adjacent_to(self, other: Range[_T]) -> bool:
         "Determine whether this range is adjacent to the `other`."
 
         # Empty ranges are not adjacent to any other range
@@ -454,7 +510,7 @@ class Range(Generic[_T]):
             oupper, oupper_b, slower, slower_b
         )
 
-    def union(self, other: Range) -> Range:
+    def union(self, other: Range[_T]) -> Range[_T]:
         """Compute the union of this range with the `other`.
 
         This raises a ``ValueError`` exception if the two ranges are
@@ -496,11 +552,14 @@ class Range(Generic[_T]):
             rupper = oupper
             rupper_b = oupper_b
 
-        return Range(rlower, rupper, bounds=rlower_b + rupper_b)
+        return Range(
+            rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b)
+        )
 
-    __add__ = union
+    def __add__(self, other: Range[_T]) -> Range[_T]:
+        return self.union(other)
 
-    def difference(self, other: Range) -> Range:
+    def difference(self, other: Range[_T]) -> Range[_T]:
         """Compute the difference between this range and the `other`.
 
         This raises a ``ValueError`` exception if the two ranges are
@@ -550,7 +609,11 @@ class Range(Generic[_T]):
             ):
                 return Range(None, None, empty=True)
             else:
-                return Range(slower, olower, bounds=slower_b + rupper_b)
+                return Range(
+                    slower,
+                    olower,
+                    bounds=cast(_BoundsType, slower_b + rupper_b),
+                )
 
         # If this range starts in the middle of the other and extends to its
         # right
@@ -564,29 +627,34 @@ class Range(Generic[_T]):
             ):
                 return Range(None, None, empty=True)
             else:
-                return Range(oupper, supper, bounds=rlower_b + supper_b)
+                return Range(
+                    oupper,
+                    supper,
+                    bounds=cast(_BoundsType, rlower_b + supper_b),
+                )
 
         assert False, f"Unhandled case computing {self} - {other}"
 
-    __sub__ = difference
+    def __sub__(self, other: Range[_T]) -> Range[_T]:
+        return self.difference(other)
 
-    def __str__(self):
+    def __str__(self) -> str:
         return self._stringify()
 
-    def _stringify(self):
+    def _stringify(self) -> str:
         if self.empty:
             return "empty"
 
         l, r = self.lower, self.upper
-        l = "" if l is None else l
-        r = "" if r is None else r
+        l = "" if l is None else l  # type: ignore
+        r = "" if r is None else r  # type: ignore
 
-        b0, b1 = self.bounds
+        b0, b1 = cast("Tuple[str, str]", self.bounds)
 
         return f"{b0}{l},{r}{b1}"
 
 
-class AbstractRange(sqltypes.TypeEngine):
+class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
     """
     Base for PostgreSQL RANGE types.
 
@@ -600,7 +668,19 @@ class AbstractRange(sqltypes.TypeEngine):
 
     __abstract__ = True
 
-    def adapt(self, impltype):
+    @overload
+    def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+        ...
+
+    @overload
+    def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+        ...
+
+    def adapt(
+        self,
+        cls: Type[Union[TypeEngine[Any], TypeEngineMixin]],
+        **kw: Any,
+    ) -> TypeEngine[Any]:
         """dynamically adapt a range type to an abstract impl.
 
         For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should
@@ -608,7 +688,7 @@ class AbstractRange(sqltypes.TypeEngine):
         and also render as ``INT4RANGE`` in SQL and DDL.
 
         """
-        if issubclass(impltype, AbstractRangeImpl):
+        if issubclass(cls, AbstractRangeImpl):
             # two ways to do this are:  1. create a new type on the fly
             # or 2. have AbstractRangeImpl(visit_name) constructor and a
             # visit_abstract_range_impl() method in the PG compiler.
@@ -619,15 +699,15 @@ class AbstractRange(sqltypes.TypeEngine):
             # The adapt() operation here is cached per type-class-per-dialect,
             # so is not much of a performance concern
             visit_name = self.__visit_name__
-            return type(
+            return type(  # type: ignore
                 f"{visit_name}RangeImpl",
-                (impltype, self.__class__),
+                (cls, self.__class__),
                 {"__visit_name__": visit_name},
             )()
         else:
-            return super().adapt(impltype)
+            return super().adapt(cls)
 
-    def _resolve_for_literal(self, value):
+    def _resolve_for_literal(self, value: Any) -> Any:
         spec = value.lower if value.lower is not None else value.upper
 
         if isinstance(spec, int):
@@ -642,17 +722,17 @@ class AbstractRange(sqltypes.TypeEngine):
             # empty Range, SQL datatype can't be determined here
             return sqltypes.NULLTYPE
 
-    class comparator_factory(sqltypes.Concatenable.Comparator):
+    class comparator_factory(sqltypes.Concatenable.Comparator[Range[Any]]):
         """Define comparison operations for range types."""
 
-        def __ne__(self, other):
+        def __ne__(self, other: Any) -> ColumnElement[bool]:  # type: ignore[override]  # noqa: E501
             "Boolean expression. Returns true if two ranges are not equal"
             if other is None:
-                return super().__ne__(other)
+                return super().__ne__(other)  # type: ignore
             else:
-                return self.expr.op("<>", is_comparison=True)(other)
+                return self.expr.op("<>", is_comparison=True)(other)  # type: ignore # noqa: E501
 
-        def contains(self, other, **kw):
+        def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the right hand operand,
             which can be an element or a range, is contained within the
             column.
@@ -660,156 +740,158 @@ class AbstractRange(sqltypes.TypeEngine):
             kwargs may be ignored by this operator but are required for API
             conformance.
             """
-            return self.expr.op("@>", is_comparison=True)(other)
+            return self.expr.op("@>", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
-        def contained_by(self, other):
+        def contained_by(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the column is contained
             within the right hand operand.
             """
-            return self.expr.op("<@", is_comparison=True)(other)
+            return self.expr.op("<@", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
-        def overlaps(self, other):
+        def overlaps(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the column overlaps
             (has points in common with) the right hand operand.
             """
-            return self.expr.op("&&", is_comparison=True)(other)
+            return self.expr.op("&&", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
-        def strictly_left_of(self, other):
+        def strictly_left_of(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the column is strictly
             left of the right hand operand.
             """
-            return self.expr.op("<<", is_comparison=True)(other)
+            return self.expr.op("<<", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
         __lshift__ = strictly_left_of
 
-        def strictly_right_of(self, other):
+        def strictly_right_of(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the column is strictly
             right of the right hand operand.
             """
-            return self.expr.op(">>", is_comparison=True)(other)
+            return self.expr.op(">>", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
         __rshift__ = strictly_right_of
 
-        def not_extend_right_of(self, other):
+        def not_extend_right_of(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the range in the column
             does not extend right of the range in the operand.
             """
-            return self.expr.op("&<", is_comparison=True)(other)
+            return self.expr.op("&<", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
-        def not_extend_left_of(self, other):
+        def not_extend_left_of(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the range in the column
             does not extend left of the range in the operand.
             """
-            return self.expr.op("&>", is_comparison=True)(other)
+            return self.expr.op("&>", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
-        def adjacent_to(self, other):
+        def adjacent_to(self, other: Any) -> ColumnElement[bool]:
             """Boolean expression. Returns true if the range in the column
             is adjacent to the range in the operand.
             """
-            return self.expr.op("-|-", is_comparison=True)(other)
+            return self.expr.op("-|-", is_comparison=True)(other)  # type: ignore  # noqa: E501
 
-        def union(self, other):
+        def union(self, other: Any) -> ColumnElement[bool]:
             """Range expression. Returns the union of the two ranges.
             Will raise an exception if the resulting range is not
             contiguous.
             """
-            return self.expr.op("+")(other)
+            return self.expr.op("+")(other)  # type: ignore
 
         __add__ = union
 
-        def difference(self, other):
+        def difference(self, other: Any) -> ColumnElement[bool]:
             """Range expression. Returns the union of the two ranges.
             Will raise an exception if the resulting range is not
             contiguous.
             """
-            return self.expr.op("-")(other)
+            return self.expr.op("-")(other)  # type: ignore
 
         __sub__ = difference
 
 
-class AbstractRangeImpl(AbstractRange):
+class AbstractRangeImpl(AbstractRange[Range[_T]]):
     """marker for AbstractRange that will apply a subclass-specific
     adaptation"""
 
 
-class AbstractMultiRange(AbstractRange):
+class AbstractMultiRange(AbstractRange[Range[_T]]):
     """base for PostgreSQL MULTIRANGE types"""
 
     __abstract__ = True
 
 
-class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
+class AbstractMultiRangeImpl(
+    AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
+):
     """marker for AbstractRange that will apply a subclass-specific
     adaptation"""
 
 
-class INT4RANGE(AbstractRange):
+class INT4RANGE(AbstractRange[Range[int]]):
     """Represent the PostgreSQL INT4RANGE type."""
 
     __visit_name__ = "INT4RANGE"
 
 
-class INT8RANGE(AbstractRange):
+class INT8RANGE(AbstractRange[Range[int]]):
     """Represent the PostgreSQL INT8RANGE type."""
 
     __visit_name__ = "INT8RANGE"
 
 
-class NUMRANGE(AbstractRange):
+class NUMRANGE(AbstractRange[Range[Decimal]]):
     """Represent the PostgreSQL NUMRANGE type."""
 
     __visit_name__ = "NUMRANGE"
 
 
-class DATERANGE(AbstractRange):
+class DATERANGE(AbstractRange[Range[date]]):
     """Represent the PostgreSQL DATERANGE type."""
 
     __visit_name__ = "DATERANGE"
 
 
-class TSRANGE(AbstractRange):
+class TSRANGE(AbstractRange[Range[datetime]]):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSRANGE"
 
 
-class TSTZRANGE(AbstractRange):
+class TSTZRANGE(AbstractRange[Range[datetime]]):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZRANGE"
 
 
-class INT4MULTIRANGE(AbstractMultiRange):
+class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
     """Represent the PostgreSQL INT4MULTIRANGE type."""
 
     __visit_name__ = "INT4MULTIRANGE"
 
 
-class INT8MULTIRANGE(AbstractMultiRange):
+class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
     """Represent the PostgreSQL INT8MULTIRANGE type."""
 
     __visit_name__ = "INT8MULTIRANGE"
 
 
-class NUMMULTIRANGE(AbstractMultiRange):
+class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
     """Represent the PostgreSQL NUMMULTIRANGE type."""
 
     __visit_name__ = "NUMMULTIRANGE"
 
 
-class DATEMULTIRANGE(AbstractMultiRange):
+class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
     """Represent the PostgreSQL DATEMULTIRANGE type."""
 
     __visit_name__ = "DATEMULTIRANGE"
 
 
-class TSMULTIRANGE(AbstractMultiRange):
+class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
     """Represent the PostgreSQL TSRANGE type."""
 
     __visit_name__ = "TSMULTIRANGE"
 
 
-class TSTZMULTIRANGE(AbstractMultiRange):
+class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
     """Represent the PostgreSQL TSTZRANGE type."""
 
     __visit_name__ = "TSTZMULTIRANGE"
index ec9bcbae92a104b9c7d94d77f8572598da6cac6d..dcde497b46326292f269deea13802db91b3d2ec2 100644 (file)
@@ -3996,6 +3996,24 @@ class _RangeComparisonFixtures(_RangeTests):
 
         is_false(range_.contains(values["rh"]))
 
+    def test_compatibility_accessors(self):
+        range_ = self._data_obj()
+
+        is_true(range_.lower_inc)
+        is_false(range_.upper_inc)
+        is_false(Range(lower=range_.lower, bounds="()").lower_inc)
+        is_true(Range(upper=range_.upper, bounds="(]").upper_inc)
+
+        is_false(range_.lower_inf)
+        is_false(range_.upper_inf)
+        is_false(Range(empty=True).lower_inf)
+        is_false(Range(empty=True).upper_inf)
+        is_true(Range().lower_inf)
+        is_true(Range().upper_inf)
+
+        is_false(range_.isempty)
+        is_true(Range(empty=True).isempty)
+
     def test_contains_value(
         self, connection, bounds_obj_combinations, value_combinations
     ):