from __future__ import annotations
import dataclasses
+from datetime import date
+from datetime import datetime
+from datetime import timedelta
from typing import Any
from typing import Generic
from typing import Optional
from typing import TypeVar
+from typing import Union
from ... import types as sqltypes
from ...util import py310
def __bool__(self) -> bool:
return self.empty
+ def _contains_value(self, value: _T) -> bool:
+ "Check whether this range contains the given `value`."
+
+ if self.empty:
+ return False
+
+ if self.lower is None:
+ return self.upper is None or (
+ value < self.upper
+ if self.bounds[1] == ")"
+ else value <= self.upper
+ )
+
+ if self.upper is None:
+ return (
+ value > self.lower
+ if self.bounds[0] == "("
+ else value >= self.lower
+ )
+
+ return (
+ value > self.lower
+ if self.bounds[0] == "("
+ else value >= self.lower
+ ) and (
+ value < self.upper
+ if self.bounds[1] == ")"
+ else value <= self.upper
+ )
+
+ def _get_discrete_step(self):
+ "Determine the “step” for this range, if it is a discrete one."
+
+ # See
+ # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE
+ # for the rationale
+
+ if isinstance(self.lower, int) or isinstance(self.upper, int):
+ return 1
+ elif isinstance(self.lower, datetime) or isinstance(
+ self.upper, datetime
+ ):
+ # This is required, because a `isinstance(datetime.now(), date)`
+ # is True
+ return None
+ elif isinstance(self.lower, date) or isinstance(self.upper, date):
+ return timedelta(days=1)
+ else:
+ return None
+
+ def contained_by(self, other: Range) -> bool:
+ "Determine whether this range is a contained by `other`."
+
+ # Any range contains the empty one
+ if self.empty:
+ return True
+
+ # An empty range does not contain any range except the empty one
+ if other.empty:
+ return False
+
+ olower = other.lower
+ oupper = other.upper
+
+ # A bilateral unbound range contains any other range
+ if olower is oupper is None:
+ return True
+
+ slower = self.lower
+ supper = self.upper
+
+ # A lower-bound range cannot contain a lower-unbound range
+ if slower is None and olower is not None:
+ return False
+
+ # Likewise on the right side
+ if supper is None and oupper is not None:
+ return False
+
+ slower_inc = self.bounds[0] == "["
+ supper_inc = self.bounds[1] == "]"
+ olower_inc = other.bounds[0] == "["
+ oupper_inc = other.bounds[1] == "]"
+
+ # Check the lower end
+ step = -1
+ if slower is not None and olower is not None:
+ lside = olower < slower
+ if not lside:
+ if not slower_inc or olower_inc:
+ lside = olower == slower
+ if not lside:
+ # Cover (1,x] vs [2,x) and (0,x] vs [1,x)
+ if not slower_inc and olower_inc and slower < olower:
+ step = self._get_discrete_step()
+ if step is not None:
+ lside = olower == (slower + step)
+ elif slower_inc and not olower_inc and slower > olower:
+ step = self._get_discrete_step()
+ if step is not None:
+ lside = (olower + step) == slower
+ if not lside:
+ return False
+
+ # Lower end already considered, an upper-unbound range surely contains
+ # this
+ if oupper is None:
+ return True
+
+ # Check the upper end
+ uside = oupper > supper
+ if not uside:
+ if not supper_inc or oupper_inc:
+ uside = oupper == supper
+ if not uside:
+ # Cover (x,2] vs [x,3) and (x,1] vs [x,2)
+ if supper_inc and not oupper_inc and supper < oupper:
+ if step == -1:
+ step = self._get_discrete_step()
+ if step is not None:
+ uside = oupper == (supper + step)
+ elif not supper_inc and oupper_inc and supper > oupper:
+ if step == -1:
+ step = self._get_discrete_step()
+ if step is not None:
+ uside = (oupper + step) == supper
+ return uside
+
+ def contains(self, value: Union[_T, Range]) -> bool:
+ "Determine whether this range contains `value`."
+
+ if isinstance(value, Range):
+ return value.contained_by(self)
+ else:
+ return self._contains_value(value)
+
+ def overlaps(self, other):
+ """Boolean expression. Returns true if the column overlaps
+ (has points in common with) the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def strictly_left_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ left of the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ __lshift__ = strictly_left_of
+
+ def strictly_right_of(self, other):
+ """Boolean expression. Returns true if the column is strictly
+ right of the right hand operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ __rshift__ = strictly_right_of
+
+ def not_extend_right_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend right of the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def not_extend_left_of(self, other):
+ """Boolean expression. Returns true if the range in the column
+ does not extend left of the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def adjacent_to(self, other):
+ """Boolean expression. Returns true if the range in the column
+ is adjacent to the range in the operand.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def __add__(self, other):
+ """Range expression. Returns the union of the two ranges.
+ Will raise an exception if the resulting range is not
+ contiguous.
+ """
+ raise NotImplementedError("not yet implemented")
+
+ def __str__(self):
+ return self._stringify()
+
+ def _stringify(self):
+ if self.empty:
+ return "empty"
+
+ l, r = self.lower, self.upper
+ l = "" if l is None else l
+ r = "" if r is None else r
+
+ b0, b1 = self.bounds
+
+ return f"{b0}{l},{r}{b1}"
+
class AbstractRange(sqltypes.TypeEngine):
"""
render_bind_cast = True
+ __abstract__ = True
+
def adapt(self, impltype):
"""dynamically adapt a range type to an abstract impl.
class AbstractMultiRange(AbstractRange):
"""base for PostgreSQL MULTIRANGE types"""
+ __abstract__ = True
+
class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange):
"""marker for AbstractRange that will apply a subclass-specific
from sqlalchemy import inspect
from sqlalchemy import Integer
from sqlalchemy import literal
+from sqlalchemy import literal_column
from sqlalchemy import MetaData
from sqlalchemy import null
from sqlalchemy import Numeric
from sqlalchemy.sql import sqltypes
from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.testing import is_false
+from sqlalchemy.testing import is_true
from sqlalchemy.testing.assertions import assert_raises
from sqlalchemy.testing.assertions import assert_raises_message
from sqlalchemy.testing.assertions import AssertsCompiledSQL
)
-class _RangeTypeRoundTrip(fixtures.TablesTest):
+class _RangeComparisonFixtures:
+ def _data_str(self):
+ """return string form of a sample range"""
+ raise NotImplementedError()
+
+ def _data_obj(self):
+ """return Range form of the same range"""
+ raise NotImplementedError()
+
+ def _step_value_up(self, value):
+ """given a value, return a step up
+
+ this is a value that given the lower end of the sample range,
+ would be less than the upper value of the range
+
+ """
+ raise NotImplementedError()
+
+ def _step_value_down(self, value):
+ """given a value, return a step down
+
+ this is a value that given the upper end of the sample range,
+ would be greater than the lower value of the range
+
+ """
+ raise NotImplementedError()
+
+ def _value_values(self):
+ """Return a series of values related to the base range
+
+ le = left equal
+ ll = lower than left
+ re = right equal
+ rh = higher than right
+ il = inside lower
+ ih = inside higher
+
+ """
+ spec = self._data_obj()
+
+ le, re_ = spec.lower, spec.upper
+
+ ll = self._step_value_down(le)
+ il = self._step_value_up(le)
+ rh = self._step_value_up(re_)
+ ih = self._step_value_down(re_)
+
+ return {"le": le, "re_": re_, "ll": ll, "il": il, "rh": rh, "ih": ih}
+
+ @testing.fixture(
+ params=[
+ lambda **kw: Range(empty=True),
+ lambda **kw: Range(bounds="[)"),
+ lambda le, **kw: Range(upper=le, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[]"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="(]"),
+ lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="()"),
+ lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="[)"),
+ lambda il, ih, **kw: Range(lower=il, upper=ih, bounds="[)"),
+ lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="(]"),
+ lambda ll, rh, **kw: Range(lower=ll, upper=rh, bounds="[)"),
+ ]
+ )
+ def contains_range_obj_combinations(self, request):
+ """ranges that are used for range contains() contained_by() tests"""
+ data = self._value_values()
+
+ range_ = request.param(**data)
+ yield range_
+
+ @testing.fixture(
+ params=[
+ lambda l, r: Range(empty=True),
+ lambda l, r: Range(bounds="()"),
+ lambda l, r: Range(upper=r, bounds="(]"),
+ lambda l, r: Range(lower=l, bounds="[)"),
+ lambda l, r: Range(lower=l, upper=r, bounds="[)"),
+ lambda l, r: Range(lower=l, upper=r, bounds="[]"),
+ lambda l, r: Range(lower=l, upper=r, bounds="(]"),
+ lambda l, r: Range(lower=l, upper=r, bounds="()"),
+ ]
+ )
+ def bounds_obj_combinations(self, request):
+ """sample ranges used for value and range contains()/contained_by()
+ tests"""
+
+ obj = self._data_obj()
+ l, r = obj.lower, obj.upper
+
+ template = request.param
+ value = template(l=l, r=r)
+ yield value
+
+ @testing.fixture(params=["ll", "le", "il", "ih", "re_", "rh"])
+ def value_combinations(self, request):
+ """sample values used for value contains() tests"""
+ data = self._value_values()
+ return data[request.param]
+
+ def test_basic_py_sanity(self):
+ values = self._value_values()
+
+ range_ = self._data_obj()
+
+ is_true(range_.contains(Range(lower=values["il"], upper=values["ih"])))
+
+ is_true(
+ range_.contained_by(Range(lower=values["ll"], upper=values["rh"]))
+ )
+
+ is_true(range_.contains(values["il"]))
+
+ is_false(
+ range_.contains(Range(lower=values["ll"], upper=values["ih"]))
+ )
+
+ is_false(range_.contains(values["rh"]))
+
+ def test_contains_value(
+ self, connection, bounds_obj_combinations, value_combinations
+ ):
+ range_ = bounds_obj_combinations
+ range_typ = self._col_str
+
+ strvalue = range_._stringify()
+
+ v = value_combinations
+ RANGE = self._col_type
+
+ q = select(
+ literal_column(f"'{strvalue}'::{range_typ}", RANGE).label("r1"),
+ cast(range_, RANGE).label("r2"),
+ )
+ literal_range, cast_range = connection.execute(q).first()
+ eq_(literal_range, cast_range)
+
+ q = select(
+ cast(range_, RANGE),
+ cast(range_, RANGE).contains(v),
+ )
+ r, expected = connection.execute(q).first()
+ eq_(r.contains(v), expected)
+
+ def test_contains_range(
+ self,
+ connection,
+ bounds_obj_combinations,
+ contains_range_obj_combinations,
+ ):
+ r1repr = contains_range_obj_combinations._stringify()
+ r2repr = bounds_obj_combinations._stringify()
+
+ RANGE = self._col_type
+ range_typ = self._col_str
+
+ q = select(
+ cast(contains_range_obj_combinations, RANGE).label("r1"),
+ cast(bounds_obj_combinations, RANGE).label("r2"),
+ cast(contains_range_obj_combinations, RANGE).contains(
+ bounds_obj_combinations
+ ),
+ cast(contains_range_obj_combinations, RANGE).contained_by(
+ bounds_obj_combinations
+ ),
+ )
+ validate_q = select(
+ literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"),
+ literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"),
+ literal_column(
+ f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}"
+ ),
+ literal_column(
+ f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}"
+ ),
+ )
+ orig_row = connection.execute(q).first()
+ validate_row = connection.execute(validate_q).first()
+ eq_(orig_row, validate_row)
+
+ r1, r2, contains, contained = orig_row
+ eq_(r1.contains(r2), contains)
+ eq_(r1.contained_by(r2), contained)
+ eq_(r2.contains(r1), contained)
+
+
+class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest):
__requires__ = ("range_types",)
__backend__ = True
)
cls.col = table.c.range
+ def test_stringify(self):
+ eq_(str(self._data_obj()), self._data_str())
+
def test_auto_cast_back_to_type(self, connection):
"""test that a straight pass of the range type without any context
will send appropriate casting info so that the driver can round
_col_str = "INT4RANGE"
def _data_str(self):
- return "[1,2)"
+ return "[1,4)"
def _data_obj(self):
- return Range(1, 2)
+ return Range(1, 4)
+
+ def _step_value_up(self, value):
+ return value + 1
+
+ def _step_value_down(self, value):
+ return value - 1
class _Int8RangeTests:
_col_str = "INT8RANGE"
def _data_str(self):
- return "[9223372036854775806,9223372036854775807)"
+ return "[9223372036854775306,9223372036854775800)"
def _data_obj(self):
- return Range(9223372036854775806, 9223372036854775807)
+ return Range(9223372036854775306, 9223372036854775800)
+
+ def _step_value_up(self, value):
+ return value + 5
+
+ def _step_value_down(self, value):
+ return value - 5
class _NumRangeTests:
_col_str = "NUMRANGE"
def _data_str(self):
- return "[1.0,2.0)"
+ return "[1.0,9.0)"
def _data_obj(self):
- return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0"))
+ return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0"))
+
+ def _step_value_up(self, value):
+ return value + decimal.Decimal("1.8")
+
+ def _step_value_down(self, value):
+ return value - decimal.Decimal("1.8")
class _DateRangeTests:
_col_str = "DATERANGE"
def _data_str(self):
- return "[2013-03-23,2013-03-24)"
+ return "[2013-03-23,2013-03-30)"
def _data_obj(self):
- return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24))
+ return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30))
+
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
class _DateTimeRangeTests:
_col_str = "TSRANGE"
def _data_str(self):
- return "[2013-03-23 14:30,2013-03-23 23:30)"
+ return "[2013-03-23 14:30:00,2013-03-30 23:30:00)"
def _data_obj(self):
return Range(
datetime.datetime(2013, 3, 23, 14, 30),
- datetime.datetime(2013, 3, 23, 23, 30),
+ datetime.datetime(2013, 3, 30, 23, 30),
)
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
+
class _DateTimeTZRangeTests:
_col_type = TSTZRANGE
_col_str = "TSTZRANGE"
- # make sure we use one, steady timestamp with timezone pair
- # for all parts of all these tests
- _tstzs = None
-
def tstzs(self):
- if self._tstzs is None:
- with testing.db.connect() as connection:
- lower = connection.scalar(func.current_timestamp().select())
- upper = lower + datetime.timedelta(1)
- self._tstzs = (lower, upper)
- return self._tstzs
+ tz = datetime.timezone(-datetime.timedelta(hours=5, minutes=30))
+
+ return (
+ datetime.datetime(2013, 3, 23, 14, 30, tzinfo=tz),
+ datetime.datetime(2013, 3, 30, 23, 30, tzinfo=tz),
+ )
def _data_str(self):
- return "[%s,%s)" % self.tstzs()
+ l, r = self.tstzs()
+ return f"[{l},{r})"
def _data_obj(self):
return Range(*self.tstzs())
+ def _step_value_up(self, value):
+ return value + datetime.timedelta(days=1)
+
+ def _step_value_down(self, value):
+ return value - datetime.timedelta(days=1)
+
class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation):
pass