#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
-# mypy: ignore-errors
from __future__ import annotations
from datetime import timedelta
from decimal import Decimal
from typing import Any
+from typing import cast
from typing import Generic
from typing import Optional
+from typing import overload
+from typing import Tuple
+from typing import Type
+from typing import TYPE_CHECKING
from typing import TypeVar
from typing import Union
from ...util import py310
from ...util.typing import Literal
+if TYPE_CHECKING:
+ from ...sql.elements import ColumnElement
+ from ...sql.type_api import _TE
+ from ...sql.type_api import TypeEngine
+ from ...sql.type_api import TypeEngineMixin
+
_T = TypeVar("_T", bound=Any)
+_BoundsType = Literal["()", "[)", "(]", "[]"]
if py310:
dc_slots = {"slots": True}
upper: Optional[_T] = None
"""the upper bound"""
- bounds: Literal["()", "[)", "(]", "[]"] = dataclasses.field(
- default="[)", **dc_kwonly
- )
- empty: bool = dataclasses.field(default=False, **dc_kwonly)
+ if TYPE_CHECKING:
+ bounds: _BoundsType = dataclasses.field(default="[)")
+ empty: bool = dataclasses.field(default=False)
+ else:
+ bounds: _BoundsType = dataclasses.field(default="[)", **dc_kwonly)
+ empty: bool = dataclasses.field(default=False, **dc_kwonly)
if not py310:
def __init__(
- self, lower=None, upper=None, *, bounds="[)", empty=False
+ self,
+ lower: Optional[_T] = None,
+ upper: Optional[_T] = None,
+ *,
+ bounds: _BoundsType = "[)",
+ empty: bool = False,
):
# no __slots__ either so we can update dict
self.__dict__.update(
return not self.empty
@property
- def __sa_type_engine__(self):
+ def isempty(self) -> bool:
+ "A synonym for the 'empty' attribute."
+
+ return self.empty
+
+ @property
+ def is_empty(self) -> bool:
+ "A synonym for the 'empty' attribute."
+
+ return self.empty
+
+ @property
+ def lower_inc(self) -> bool:
+ """Return True if the lower bound is inclusive."""
+
+ return self.bounds[0] == "["
+
+ @property
+ def lower_inf(self) -> bool:
+ """Return True if this range is non-empty and lower bound is
+ infinite."""
+
+ return not self.empty and self.lower is None
+
+ @property
+ def upper_inc(self) -> bool:
+ """Return True if the upper bound is inclusive."""
+
+ return self.bounds[1] == "]"
+
+ @property
+ def upper_inf(self) -> bool:
+ """Return True if this range is non-empty and the upper bound is
+ infinite."""
+
+ return not self.empty and self.upper is None
+
+ @property
+ def __sa_type_engine__(self) -> AbstractRange[Range[_T]]:
return AbstractRange()
def _contains_value(self, value: _T) -> bool:
- "Check whether this range contains the given `value`."
+ """return True if this range contains the given value."""
if self.empty:
return False
)
if self.upper is None:
- return (
+ return ( # type: ignore
value > self.lower
if self.bounds[0] == "("
else value >= self.lower
)
- return (
+ return ( # type: ignore
value > self.lower
if self.bounds[0] == "("
else value >= self.lower
else value <= self.upper
)
- def _get_discrete_step(self):
+ def _get_discrete_step(self) -> Any:
"Determine the “step” for this range, if it is a discrete one."
# See
value2 += step
value2_inc = False
- if value1 < value2:
+ if value1 < value2: # type: ignore
return -1
- elif value1 > value2:
+ elif value1 > value2: # type: ignore
return 1
elif only_values:
return 0
else:
return 0
- def __eq__(self, other: Range) -> bool:
+ def __eq__(self, other: Any) -> bool: # type: ignore[override] # noqa: E501
"""Compare this range to the `other` taking into account
bounds inclusivity, returning ``True`` if they are equal.
"""
and self._compare_edges(supper, supper_b, oupper, oupper_b) == 0
)
- def contained_by(self, other: Range) -> bool:
+ def contained_by(self, other: Range[_T]) -> bool:
"Determine whether this range is a contained by `other`."
# Any range contains the empty one
return True
- def contains(self, value: Union[_T, Range]) -> bool:
+ def contains(self, value: Union[_T, Range[_T]]) -> bool:
"Determine whether this range contains `value`."
if isinstance(value, Range):
else:
return self._contains_value(value)
- def overlaps(self, other: Range) -> bool:
+ def overlaps(self, other: Range[_T]) -> bool:
"Determine whether this range overlaps with `other`."
# Empty ranges never overlap with any other range
return False
- def strictly_left_of(self, other: Range) -> bool:
+ def strictly_left_of(self, other: Range[_T]) -> bool:
"Determine whether this range is completely to the left of `other`."
# Empty ranges are neither to left nor to the right of any other range
__lshift__ = strictly_left_of
- def strictly_right_of(self, other: Range) -> bool:
+ def strictly_right_of(self, other: Range[_T]) -> bool:
"Determine whether this range is completely to the right of `other`."
# Empty ranges are neither to left nor to the right of any other range
__rshift__ = strictly_right_of
- def not_extend_left_of(self, other: Range) -> bool:
+ def not_extend_left_of(self, other: Range[_T]) -> bool:
"Determine whether this does not extend to the left of `other`."
# Empty ranges are neither to left nor to the right of any other range
# Check whether this lower edge is not less than other's lower end
return self._compare_edges(slower, slower_b, olower, olower_b) >= 0
- def not_extend_right_of(self, other: Range) -> bool:
+ def not_extend_right_of(self, other: Range[_T]) -> bool:
"Determine whether this does not extend to the right of `other`."
# Empty ranges are neither to left nor to the right of any other range
return False
if bound1 == "]":
if bound2 == "[":
- return value1 == value2 - step
+ return value1 == value2 - step # type: ignore
else:
return value1 == value2
else:
if bound2 == "[":
return value1 == value2
else:
- return value1 == value2 - step
+ return value1 == value2 - step # type: ignore
elif res == 0:
# Cover cases like [0,0] -|- [1,] and [0,2) -|- (1,3]
if (
else:
return False
- def adjacent_to(self, other: Range) -> bool:
+ def adjacent_to(self, other: Range[_T]) -> bool:
"Determine whether this range is adjacent to the `other`."
# Empty ranges are not adjacent to any other range
oupper, oupper_b, slower, slower_b
)
- def union(self, other: Range) -> Range:
+ def union(self, other: Range[_T]) -> Range[_T]:
"""Compute the union of this range with the `other`.
This raises a ``ValueError`` exception if the two ranges are
rupper = oupper
rupper_b = oupper_b
- return Range(rlower, rupper, bounds=rlower_b + rupper_b)
+ return Range(
+ rlower, rupper, bounds=cast(_BoundsType, rlower_b + rupper_b)
+ )
- __add__ = union
+ def __add__(self, other: Range[_T]) -> Range[_T]:
+ return self.union(other)
- def difference(self, other: Range) -> Range:
+ def difference(self, other: Range[_T]) -> Range[_T]:
"""Compute the difference between this range and the `other`.
This raises a ``ValueError`` exception if the two ranges are
):
return Range(None, None, empty=True)
else:
- return Range(slower, olower, bounds=slower_b + rupper_b)
+ return Range(
+ slower,
+ olower,
+ bounds=cast(_BoundsType, slower_b + rupper_b),
+ )
# If this range starts in the middle of the other and extends to its
# right
):
return Range(None, None, empty=True)
else:
- return Range(oupper, supper, bounds=rlower_b + supper_b)
+ return Range(
+ oupper,
+ supper,
+ bounds=cast(_BoundsType, rlower_b + supper_b),
+ )
assert False, f"Unhandled case computing {self} - {other}"
- __sub__ = difference
+ def __sub__(self, other: Range[_T]) -> Range[_T]:
+ return self.difference(other)
- def __str__(self):
+ def __str__(self) -> str:
return self._stringify()
- def _stringify(self):
+ def _stringify(self) -> str:
if self.empty:
return "empty"
l, r = self.lower, self.upper
- l = "" if l is None else l
- r = "" if r is None else r
+ l = "" if l is None else l # type: ignore
+ r = "" if r is None else r # type: ignore
- b0, b1 = self.bounds
+ b0, b1 = cast("Tuple[str, str]", self.bounds)
return f"{b0}{l},{r}{b1}"
-class AbstractRange(sqltypes.TypeEngine):
+class AbstractRange(sqltypes.TypeEngine[Range[_T]]):
"""
Base for PostgreSQL RANGE types.
__abstract__ = True
- def adapt(self, impltype):
+ @overload
+ def adapt(self, cls: Type[_TE], **kw: Any) -> _TE:
+ ...
+
+ @overload
+ def adapt(self, cls: Type[TypeEngineMixin], **kw: Any) -> TypeEngine[Any]:
+ ...
+
+ def adapt(
+ self,
+ cls: Type[Union[TypeEngine[Any], TypeEngineMixin]],
+ **kw: Any,
+ ) -> TypeEngine[Any]:
"""dynamically adapt a range type to an abstract impl.
For example ``INT4RANGE().adapt(_Psycopg2NumericRange)`` should
and also render as ``INT4RANGE`` in SQL and DDL.
"""
- if issubclass(impltype, AbstractRangeImpl):
+ if issubclass(cls, AbstractRangeImpl):
# two ways to do this are: 1. create a new type on the fly
# or 2. have AbstractRangeImpl(visit_name) constructor and a
# visit_abstract_range_impl() method in the PG compiler.
# The adapt() operation here is cached per type-class-per-dialect,
# so is not much of a performance concern
visit_name = self.__visit_name__
- return type(
+ return type( # type: ignore
f"{visit_name}RangeImpl",
- (impltype, self.__class__),
+ (cls, self.__class__),
{"__visit_name__": visit_name},
)()
else:
- return super().adapt(impltype)
+ return super().adapt(cls)
- def _resolve_for_literal(self, value):
+ def _resolve_for_literal(self, value: Any) -> Any:
spec = value.lower if value.lower is not None else value.upper
if isinstance(spec, int):
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE
- class comparator_factory(sqltypes.Concatenable.Comparator):
+ class comparator_factory(sqltypes.Concatenable.Comparator[Range[Any]]):
"""Define comparison operations for range types."""
- def __ne__(self, other):
+ def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
"Boolean expression. Returns true if two ranges are not equal"
if other is None:
- return super().__ne__(other)
+ return super().__ne__(other) # type: ignore
else:
- return self.expr.op("<>", is_comparison=True)(other)
+ return self.expr.op("<>", is_comparison=True)(other) # type: ignore # noqa: E501
- def contains(self, other, **kw):
+ def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the right hand operand,
which can be an element or a range, is contained within the
column.
kwargs may be ignored by this operator but are required for API
conformance.
"""
- return self.expr.op("@>", is_comparison=True)(other)
+ return self.expr.op("@>", is_comparison=True)(other) # type: ignore # noqa: E501
- def contained_by(self, other):
+ def contained_by(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column is contained
within the right hand operand.
"""
- return self.expr.op("<@", is_comparison=True)(other)
+ return self.expr.op("<@", is_comparison=True)(other) # type: ignore # noqa: E501
- def overlaps(self, other):
+ def overlaps(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column overlaps
(has points in common with) the right hand operand.
"""
- return self.expr.op("&&", is_comparison=True)(other)
+ return self.expr.op("&&", is_comparison=True)(other) # type: ignore # noqa: E501
- def strictly_left_of(self, other):
+ def strictly_left_of(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column is strictly
left of the right hand operand.
"""
- return self.expr.op("<<", is_comparison=True)(other)
+ return self.expr.op("<<", is_comparison=True)(other) # type: ignore # noqa: E501
__lshift__ = strictly_left_of
- def strictly_right_of(self, other):
+ def strictly_right_of(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column is strictly
right of the right hand operand.
"""
- return self.expr.op(">>", is_comparison=True)(other)
+ return self.expr.op(">>", is_comparison=True)(other) # type: ignore # noqa: E501
__rshift__ = strictly_right_of
- def not_extend_right_of(self, other):
+ def not_extend_right_of(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the range in the column
does not extend right of the range in the operand.
"""
- return self.expr.op("&<", is_comparison=True)(other)
+ return self.expr.op("&<", is_comparison=True)(other) # type: ignore # noqa: E501
- def not_extend_left_of(self, other):
+ def not_extend_left_of(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the range in the column
does not extend left of the range in the operand.
"""
- return self.expr.op("&>", is_comparison=True)(other)
+ return self.expr.op("&>", is_comparison=True)(other) # type: ignore # noqa: E501
- def adjacent_to(self, other):
+ def adjacent_to(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the range in the column
is adjacent to the range in the operand.
"""
- return self.expr.op("-|-", is_comparison=True)(other)
+ return self.expr.op("-|-", is_comparison=True)(other) # type: ignore # noqa: E501
- def union(self, other):
+ def union(self, other: Any) -> ColumnElement[bool]:
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
- return self.expr.op("+")(other)
+ return self.expr.op("+")(other) # type: ignore
__add__ = union
- def difference(self, other):
+ def difference(self, other: Any) -> ColumnElement[bool]:
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
- return self.expr.op("-")(other)
+ return self.expr.op("-")(other) # type: ignore
__sub__ = difference
-class AbstractRangeImpl(AbstractRange):
+class AbstractRangeImpl(AbstractRange[Range[_T]]):
"""marker for AbstractRange that will apply a subclass-specific
adaptation"""
-class AbstractMultiRange(AbstractRange):
+class AbstractMultiRange(AbstractRange[Range[_T]]):
"""base for PostgreSQL MULTIRANGE types"""
__abstract__ = True
-class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
+class AbstractMultiRangeImpl(
+ AbstractRangeImpl[Range[_T]], AbstractMultiRange[Range[_T]]
+):
"""marker for AbstractRange that will apply a subclass-specific
adaptation"""
-class INT4RANGE(AbstractRange):
+class INT4RANGE(AbstractRange[Range[int]]):
"""Represent the PostgreSQL INT4RANGE type."""
__visit_name__ = "INT4RANGE"
-class INT8RANGE(AbstractRange):
+class INT8RANGE(AbstractRange[Range[int]]):
"""Represent the PostgreSQL INT8RANGE type."""
__visit_name__ = "INT8RANGE"
-class NUMRANGE(AbstractRange):
+class NUMRANGE(AbstractRange[Range[Decimal]]):
"""Represent the PostgreSQL NUMRANGE type."""
__visit_name__ = "NUMRANGE"
-class DATERANGE(AbstractRange):
+class DATERANGE(AbstractRange[Range[date]]):
"""Represent the PostgreSQL DATERANGE type."""
__visit_name__ = "DATERANGE"
-class TSRANGE(AbstractRange):
+class TSRANGE(AbstractRange[Range[datetime]]):
"""Represent the PostgreSQL TSRANGE type."""
__visit_name__ = "TSRANGE"
-class TSTZRANGE(AbstractRange):
+class TSTZRANGE(AbstractRange[Range[datetime]]):
"""Represent the PostgreSQL TSTZRANGE type."""
__visit_name__ = "TSTZRANGE"
-class INT4MULTIRANGE(AbstractMultiRange):
+class INT4MULTIRANGE(AbstractMultiRange[Range[int]]):
"""Represent the PostgreSQL INT4MULTIRANGE type."""
__visit_name__ = "INT4MULTIRANGE"
-class INT8MULTIRANGE(AbstractMultiRange):
+class INT8MULTIRANGE(AbstractMultiRange[Range[int]]):
"""Represent the PostgreSQL INT8MULTIRANGE type."""
__visit_name__ = "INT8MULTIRANGE"
-class NUMMULTIRANGE(AbstractMultiRange):
+class NUMMULTIRANGE(AbstractMultiRange[Range[Decimal]]):
"""Represent the PostgreSQL NUMMULTIRANGE type."""
__visit_name__ = "NUMMULTIRANGE"
-class DATEMULTIRANGE(AbstractMultiRange):
+class DATEMULTIRANGE(AbstractMultiRange[Range[date]]):
"""Represent the PostgreSQL DATEMULTIRANGE type."""
__visit_name__ = "DATEMULTIRANGE"
-class TSMULTIRANGE(AbstractMultiRange):
+class TSMULTIRANGE(AbstractMultiRange[Range[datetime]]):
"""Represent the PostgreSQL TSRANGE type."""
__visit_name__ = "TSMULTIRANGE"
-class TSTZMULTIRANGE(AbstractMultiRange):
+class TSTZMULTIRANGE(AbstractMultiRange[Range[datetime]]):
"""Represent the PostgreSQL TSTZRANGE type."""
__visit_name__ = "TSTZMULTIRANGE"