From: Lele Gaifax Date: Mon, 5 Dec 2022 01:51:13 +0000 (-0500) Subject: Add compatibility properties to Range; implement pep-484 X-Git-Tag: rel_2_0_0b4~4^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=3da784be647125f8727a92d1e386155e1f53c671;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Add compatibility properties to Range; implement pep-484 This adds a bunch of properties to new PG Range class for compatibility with other implementations, providing a more similar API to access emptiness and bounds status. The naming conventions here derive from PostgreSQL itself, see https://www.postgresql.org/docs/9.3/functions-range.html . pep-484 also implemented by Mike, as this is a pretty type-intensive module. Co-authored-by: Mike Bayer Closes: #8927 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8927 Pull-request-sha: 8b9e7b7e3345673b43aeabd7ec88b88dc3cfa7eb Change-Id: I0b1d49311517ee1cc1377a974ed0a860ea5756e4 --- diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index e772777bf3..9b5834ccd8 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -3,7 +3,6 @@ # # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php -# mypy: ignore-errors from __future__ import annotations @@ -13,8 +12,13 @@ from datetime import datetime from datetime import timedelta from decimal import Decimal from typing import Any +from typing import cast from typing import Generic from typing import Optional +from typing import overload +from typing import Tuple +from typing import Type +from typing import TYPE_CHECKING from typing import TypeVar from typing import Union @@ -22,8 +26,15 @@ from ... import types as sqltypes from ...util import py310 from ...util.typing import Literal +if TYPE_CHECKING: + from ...sql.elements import ColumnElement + from ...sql.type_api import _TE + from ...sql.type_api import TypeEngine + from ...sql.type_api import TypeEngineMixin + _T = TypeVar("_T", bound=Any) +_BoundsType = Literal["()", "[)", "(]", "[]"] if py310: dc_slots = {"slots": True} @@ -62,15 +73,22 @@ class Range(Generic[_T]): upper: Optional[_T] = None """the upper bound""" - bounds: Literal["()", "[)", "(]", "[]"] = dataclasses.field( - default="[)", **dc_kwonly - ) - empty: bool = dataclasses.field(default=False, **dc_kwonly) + if TYPE_CHECKING: + bounds: _BoundsType = dataclasses.field(default="[)") + empty: bool = dataclasses.field(default=False) + else: + bounds: _BoundsType = dataclasses.field(default="[)", **dc_kwonly) + empty: bool = dataclasses.field(default=False, **dc_kwonly) if not py310: def __init__( - self, lower=None, upper=None, *, bounds="[)", empty=False + self, + lower: Optional[_T] = None, + upper: Optional[_T] = None, + *, + bounds: _BoundsType = "[)", + empty: bool = False, ): # no __slots__ either so we can update dict self.__dict__.update( @@ -86,11 +104,49 @@ class Range(Generic[_T]): return not self.empty @property - def __sa_type_engine__(self): + def isempty(self) -> bool: + "A synonym for the 'empty' attribute." + + return self.empty + + @property + def is_empty(self) -> bool: + "A synonym for the 'empty' attribute." + + return self.empty + + @property + def lower_inc(self) -> bool: + """Return True if the lower bound is inclusive.""" + + return self.bounds[0] == "[" + + @property + def lower_inf(self) -> bool: + """Return True if this range is non-empty and lower bound is + infinite.""" + + return not self.empty and self.lower is None + + @property + def upper_inc(self) -> bool: + """Return True if the upper bound is inclusive.""" + + return self.bounds[1] == "]" + + @property + def upper_inf(self) -> bool: + """Return True if this range is non-empty and the upper bound is + infinite.""" + + return not self.empty and self.upper is None + + @property + def __sa_type_engine__(self) -> AbstractRange[Range[_T]]: return AbstractRange() def _contains_value(self, value: _T) -> bool: - "Check whether this range contains the given `value`." + """return True if this range contains the given value.""" if self.empty: return False @@ -103,13 +159,13 @@ class Range(Generic[_T]): ) if self.upper is None: - return ( + return ( # type: ignore value > self.lower if self.bounds[0] == "(" else value >= self.lower ) - return ( + return ( # type: ignore value > self.lower if self.bounds[0] == "(" else value >= self.lower @@ -119,7 +175,7 @@ class Range(Generic[_T]): else value <= self.upper ) - def _get_discrete_step(self): + def _get_discrete_step(self) -> Any: "Determine the “step” for this range, if it is a discrete one." # See @@ -203,9 +259,9 @@ class Range(Generic[_T]): value2 += step value2_inc = False - if value1 < value2: + if value1 < value2: # type: ignore return -1 - elif value1 > value2: + elif value1 > value2: # type: ignore return 1 elif only_values: return 0 @@ -228,7 +284,7 @@ class Range(Generic[_T]): else: return 0 - def __eq__(self, other: Range) -> bool: + def __eq__(self, other: Any) -> bool: # type: ignore[override] # noqa: E501 """Compare this range to the `other` taking into account bounds inclusivity, returning ``True`` if they are equal. """ @@ -252,7 +308,7 @@ class Range(Generic[_T]): and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0 ) - def contained_by(self, other: Range) -> bool: + def contained_by(self, other: Range[_T]) -> bool: "Determine whether this range is a contained by `other`." # Any range contains the empty one @@ -281,7 +337,7 @@ class Range(Generic[_T]): return True - def contains(self, value: Union[_T, Range]) -> bool: + def contains(self, value: Union[_T, Range[_T]]) -> bool: "Determine whether this range contains `value`." if isinstance(value, Range): @@ -289,7 +345,7 @@ class Range(Generic[_T]): else: return self._contains_value(value) - def overlaps(self, other: Range) -> bool: + def overlaps(self, other: Range[_T]) -> bool: "Determine whether this range overlaps with `other`." # Empty ranges never overlap with any other range @@ -321,7 +377,7 @@ class Range(Generic[_T]): return False - def strictly_left_of(self, other: Range) -> bool: + def strictly_left_of(self, other: Range[_T]) -> bool: "Determine whether this range is completely to the left of `other`." # Empty ranges are neither to left nor to the right of any other range @@ -338,7 +394,7 @@ class Range(Generic[_T]): __lshift__ = strictly_left_of - def strictly_right_of(self, other: Range) -> bool: + def strictly_right_of(self, other: Range[_T]) -> bool: "Determine whether this range is completely to the right of `other`." # Empty ranges are neither to left nor to the right of any other range @@ -355,7 +411,7 @@ class Range(Generic[_T]): __rshift__ = strictly_right_of - def not_extend_left_of(self, other: Range) -> bool: + def not_extend_left_of(self, other: Range[_T]) -> bool: "Determine whether this does not extend to the left of `other`." # Empty ranges are neither to left nor to the right of any other range @@ -370,7 +426,7 @@ class Range(Generic[_T]): # Check whether this lower edge is not less than other's lower end return self._compare_edges(slower, slower_b, olower, olower_b) >= 0 - def not_extend_right_of(self, other: Range) -> bool: + def not_extend_right_of(self, other: Range[_T]) -> bool: "Determine whether this does not extend to the right of `other`." # Empty ranges are neither to left nor to the right of any other range @@ -404,14 +460,14 @@ class Range(Generic[_T]): return False if bound1 == "]": if bound2 == "[": - return value1 == value2 - step + return value1 == value2 - step # type: ignore else: return value1 == value2 else: if bound2 == "[": return value1 == value2 else: - return value1 == value2 - step + return value1 == value2 - step # type: ignore elif res == 0: # Cover cases like [0,0] -|- [1,] and [0,2) -|- (1,3] if ( @@ -432,7 +488,7 @@ class Range(Generic[_T]): else: return False - def adjacent_to(self, other: Range) -> bool: + def adjacent_to(self, other: Range[_T]) -> bool: "Determine whether this range is adjacent to the `other`." # Empty ranges are not adjacent to any other range @@ -454,7 +510,7 @@ class Range(Generic[_T]): oupper, oupper_b, slower, slower_b ) - def union(self, other: Range) -> Range: + def union(self, other: Range[_T]) -> Range[_T]: """Compute the union of this range with the `other`. This raises a ``ValueError`` exception if the two ranges are @@ -496,11 +552,14 @@ class Range(Generic[_T]): rupper = oupper rupper_b = oupper_b - return Range(rlower, rupper, bounds=rlower_b + rupper_b) + return Range( + rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b) + ) - __add__ = union + def __add__(self, other: Range[_T]) -> Range[_T]: + return self.union(other) - def difference(self, other: Range) -> Range: + def difference(self, other: Range[_T]) -> Range[_T]: """Compute the difference between this range and the `other`. This raises a ``ValueError`` exception if the two ranges are @@ -550,7 +609,11 @@ class Range(Generic[_T]): ): return Range(None, None, empty=True) else: - return Range(slower, olower, bounds=slower_b + rupper_b) + return Range( + slower, + olower, + bounds=cast(_BoundsType, slower_b + rupper_b), + ) # If this range starts in the middle of the other and extends to its # right @@ -564,29 +627,34 @@ class Range(Generic[_T]): ): return Range(None, None, empty=True) else: - return Range(oupper, supper, bounds=rlower_b + supper_b) + return Range( + oupper, + supper, + bounds=cast(_BoundsType, rlower_b + supper_b), + ) assert False, f"Unhandled case computing {self} - {other}" - __sub__ = difference + def __sub__(self, other: Range[_T]) -> Range[_T]: + return self.difference(other) - def __str__(self): + def __str__(self) -> str: return self._stringify() - def _stringify(self): + def _stringify(self) -> str: if self.empty: return "empty" l, r = self.lower, self.upper - l = "" if l is None else l - r = "" if r is None else r + l = "" if l is None else l # type: ignore + r = "" if r is None else r # type: ignore - b0, b1 = self.bounds + b0, b1 = cast("Tuple[str, str]", self.bounds) return f"{b0}{l},{r}{b1}" -class AbstractRange(sqltypes.TypeEngine): +class AbstractRange(sqltypes.TypeEngine[Range[_T]]): """ Base for PostgreSQL RANGE types. @@ -600,7 +668,19 @@ class AbstractRange(sqltypes.TypeEngine): __abstract__ = True - def adapt(self, impltype): + @overload + def adapt(self, cls: Type[_TE], **kw: Any) -> _TE: + ... + + @overload + def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]: + ... + + def adapt( + self, + cls: Type[Union[TypeEngine[Any], TypeEngineMixin]], + **kw: Any, + ) -> TypeEngine[Any]: """dynamically adapt a range type to an abstract impl. For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should @@ -608,7 +688,7 @@ class AbstractRange(sqltypes.TypeEngine): and also render as ``INT4RANGE`` in SQL and DDL. """ - if issubclass(impltype, AbstractRangeImpl): + if issubclass(cls, AbstractRangeImpl): # two ways to do this are: 1. create a new type on the fly # or 2. have AbstractRangeImpl(visit_name) constructor and a # visit_abstract_range_impl() method in the PG compiler. @@ -619,15 +699,15 @@ class AbstractRange(sqltypes.TypeEngine): # The adapt() operation here is cached per type-class-per-dialect, # so is not much of a performance concern visit_name = self.__visit_name__ - return type( + return type( # type: ignore f"{visit_name}RangeImpl", - (impltype, self.__class__), + (cls, self.__class__), {"__visit_name__": visit_name}, )() else: - return super().adapt(impltype) + return super().adapt(cls) - def _resolve_for_literal(self, value): + def _resolve_for_literal(self, value: Any) -> Any: spec = value.lower if value.lower is not None else value.upper if isinstance(spec, int): @@ -642,17 +722,17 @@ class AbstractRange(sqltypes.TypeEngine): # empty Range, SQL datatype can't be determined here return sqltypes.NULLTYPE - class comparator_factory(sqltypes.Concatenable.Comparator): + class comparator_factory(sqltypes.Concatenable.Comparator[Range[Any]]): """Define comparison operations for range types.""" - def __ne__(self, other): + def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 "Boolean expression. Returns true if two ranges are not equal" if other is None: - return super().__ne__(other) + return super().__ne__(other) # type: ignore else: - return self.expr.op("<>", is_comparison=True)(other) + return self.expr.op("<>", is_comparison=True)(other) # type: ignore # noqa: E501 - def contains(self, other, **kw): + def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the right hand operand, which can be an element or a range, is contained within the column. @@ -660,156 +740,158 @@ class AbstractRange(sqltypes.TypeEngine): kwargs may be ignored by this operator but are required for API conformance. """ - return self.expr.op("@>", is_comparison=True)(other) + return self.expr.op("@>", is_comparison=True)(other) # type: ignore # noqa: E501 - def contained_by(self, other): + def contained_by(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column is contained within the right hand operand. """ - return self.expr.op("<@", is_comparison=True)(other) + return self.expr.op("<@", is_comparison=True)(other) # type: ignore # noqa: E501 - def overlaps(self, other): + def overlaps(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column overlaps (has points in common with) the right hand operand. """ - return self.expr.op("&&", is_comparison=True)(other) + return self.expr.op("&&", is_comparison=True)(other) # type: ignore # noqa: E501 - def strictly_left_of(self, other): + def strictly_left_of(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column is strictly left of the right hand operand. """ - return self.expr.op("<<", is_comparison=True)(other) + return self.expr.op("<<", is_comparison=True)(other) # type: ignore # noqa: E501 __lshift__ = strictly_left_of - def strictly_right_of(self, other): + def strictly_right_of(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column is strictly right of the right hand operand. """ - return self.expr.op(">>", is_comparison=True)(other) + return self.expr.op(">>", is_comparison=True)(other) # type: ignore # noqa: E501 __rshift__ = strictly_right_of - def not_extend_right_of(self, other): + def not_extend_right_of(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the range in the column does not extend right of the range in the operand. """ - return self.expr.op("&<", is_comparison=True)(other) + return self.expr.op("&<", is_comparison=True)(other) # type: ignore # noqa: E501 - def not_extend_left_of(self, other): + def not_extend_left_of(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the range in the column does not extend left of the range in the operand. """ - return self.expr.op("&>", is_comparison=True)(other) + return self.expr.op("&>", is_comparison=True)(other) # type: ignore # noqa: E501 - def adjacent_to(self, other): + def adjacent_to(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the range in the column is adjacent to the range in the operand. """ - return self.expr.op("-|-", is_comparison=True)(other) + return self.expr.op("-|-", is_comparison=True)(other) # type: ignore # noqa: E501 - def union(self, other): + def union(self, other: Any) -> ColumnElement[bool]: """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contiguous. """ - return self.expr.op("+")(other) + return self.expr.op("+")(other) # type: ignore __add__ = union - def difference(self, other): + def difference(self, other: Any) -> ColumnElement[bool]: """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contiguous. """ - return self.expr.op("-")(other) + return self.expr.op("-")(other) # type: ignore __sub__ = difference -class AbstractRangeImpl(AbstractRange): +class AbstractRangeImpl(AbstractRange[Range[_T]]): """marker for AbstractRange that will apply a subclass-specific adaptation""" -class AbstractMultiRange(AbstractRange): +class AbstractMultiRange(AbstractRange[Range[_T]]): """base for PostgreSQL MULTIRANGE types""" __abstract__ = True -class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange): +class AbstractMultiRangeImpl( + AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]] +): """marker for AbstractRange that will apply a subclass-specific adaptation""" -class INT4RANGE(AbstractRange): +class INT4RANGE(AbstractRange[Range[int]]): """Represent the PostgreSQL INT4RANGE type.""" __visit_name__ = "INT4RANGE" -class INT8RANGE(AbstractRange): +class INT8RANGE(AbstractRange[Range[int]]): """Represent the PostgreSQL INT8RANGE type.""" __visit_name__ = "INT8RANGE" -class NUMRANGE(AbstractRange): +class NUMRANGE(AbstractRange[Range[Decimal]]): """Represent the PostgreSQL NUMRANGE type.""" __visit_name__ = "NUMRANGE" -class DATERANGE(AbstractRange): +class DATERANGE(AbstractRange[Range[date]]): """Represent the PostgreSQL DATERANGE type.""" __visit_name__ = "DATERANGE" -class TSRANGE(AbstractRange): +class TSRANGE(AbstractRange[Range[datetime]]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSRANGE" -class TSTZRANGE(AbstractRange): +class TSTZRANGE(AbstractRange[Range[datetime]]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZRANGE" -class INT4MULTIRANGE(AbstractMultiRange): +class INT4MULTIRANGE(AbstractMultiRange[Range[int]]): """Represent the PostgreSQL INT4MULTIRANGE type.""" __visit_name__ = "INT4MULTIRANGE" -class INT8MULTIRANGE(AbstractMultiRange): +class INT8MULTIRANGE(AbstractMultiRange[Range[int]]): """Represent the PostgreSQL INT8MULTIRANGE type.""" __visit_name__ = "INT8MULTIRANGE" -class NUMMULTIRANGE(AbstractMultiRange): +class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]): """Represent the PostgreSQL NUMMULTIRANGE type.""" __visit_name__ = "NUMMULTIRANGE" -class DATEMULTIRANGE(AbstractMultiRange): +class DATEMULTIRANGE(AbstractMultiRange[Range[date]]): """Represent the PostgreSQL DATEMULTIRANGE type.""" __visit_name__ = "DATEMULTIRANGE" -class TSMULTIRANGE(AbstractMultiRange): +class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]): """Represent the PostgreSQL TSRANGE type.""" __visit_name__ = "TSMULTIRANGE" -class TSTZMULTIRANGE(AbstractMultiRange): +class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]): """Represent the PostgreSQL TSTZRANGE type.""" __visit_name__ = "TSTZMULTIRANGE" diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index ec9bcbae92..dcde497b46 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -3996,6 +3996,24 @@ class _RangeComparisonFixtures(_RangeTests): is_false(range_.contains(values["rh"])) + def test_compatibility_accessors(self): + range_ = self._data_obj() + + is_true(range_.lower_inc) + is_false(range_.upper_inc) + is_false(Range(lower=range_.lower, bounds="()").lower_inc) + is_true(Range(upper=range_.upper, bounds="(]").upper_inc) + + is_false(range_.lower_inf) + is_false(range_.upper_inf) + is_false(Range(empty=True).lower_inf) + is_false(Range(empty=True).upper_inf) + is_true(Range().lower_inf) + is_true(Range().upper_inf) + + is_false(range_.isempty) + is_true(Range(empty=True).isempty) + def test_contains_value( self, connection, bounds_obj_combinations, value_combinations ):