From: Lele Gaifax Date: Wed, 2 Nov 2022 12:33:41 +0000 (-0400) Subject: Implement contains_value(), issubset() and issuperset() on PG Range X-Git-Tag: rel_2_0_0b3~10^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=e8124b29b07fd17ab2f2b6892534dcc4b0797ab4;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Implement contains_value(), issubset() and issuperset() on PG Range Added new methods :meth:`_postgresql.Range.contains` and :meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@`` operators, as well as the :meth:`_postgresql.AbstractRange.comparator_factory.contains` and :meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL operator methods. Pull request courtesy Lele Gaifax. Fixes: #8706 Closes: #8707 Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/8707 Pull-request-sha: 3a74a0d93e63032ebee02992977498c717a077ff Change-Id: Ief81ca5c31448640b26dfbc3defd4dde1d51e366 --- diff --git a/doc/build/changelog/unreleased_20/8706.rst b/doc/build/changelog/unreleased_20/8706.rst new file mode 100644 index 0000000000..a6f3321b6d --- /dev/null +++ b/doc/build/changelog/unreleased_20/8706.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: feature, postgresql + :tickets: 8706 + + Added new methods :meth:`_postgresql.Range.contains` and + :meth:`_postgresql.Range.contained_by` to the new :class:`.Range` data + object, which mirror the behavior of the PostgreSQL ``@>`` and ``<@`` + operators, as well as the + :meth:`_postgresql.AbstractRange.comparator_factory.contains` and + :meth:`_postgresql.AbstractRange.comparator_factory.contained_by` SQL + operator methods. Pull request courtesy Lele Gaifax. diff --git a/doc/build/changelog/whatsnew_20.rst b/doc/build/changelog/whatsnew_20.rst index 98865f233d..3d4eca6b2b 100644 --- a/doc/build/changelog/whatsnew_20.rst +++ b/doc/build/changelog/whatsnew_20.rst @@ -1883,6 +1883,12 @@ objects are used. Code that used the previous psycopg2-specific types should be modified to use :class:`_postgresql.Range`, which presents a compatible interface. +The :class:`_postgresql.Range` object also features comparison support which +mirrors that of PostgreSQL. Implemented so far are :meth:`_postgresql.Range.contains` +and :meth:`_postgresql.Range.contained_by` methods which work in the same way as +the PostgreSQL ``@>`` and ``<@``. Additional operator support may be added +in future releases. + See the documentation at :ref:`postgresql_ranges` for background on using the new feature. @@ -1891,6 +1897,9 @@ using the new feature. :ref:`postgresql_ranges` +:ticket:`7156` +:ticket:`8706` + .. _change_7086: ``match()`` operator on PostgreSQL uses ``plainto_tsquery()`` rather than ``to_tsquery()`` diff --git a/doc/build/dialects/postgresql.rst b/doc/build/dialects/postgresql.rst index 411037f87c..ef853b6832 100644 --- a/doc/build/dialects/postgresql.rst +++ b/doc/build/dialects/postgresql.rst @@ -216,6 +216,7 @@ The available range datatypes are as follows: * :class:`_postgresql.TSTZRANGE` .. autoclass:: sqlalchemy.dialects.postgresql.Range + :members: Multiranges ^^^^^^^^^^^ @@ -350,6 +351,12 @@ construction arguments, are as follows: .. currentmodule:: sqlalchemy.dialects.postgresql +.. autoclass:: sqlalchemy.dialects.postgresql.AbstractRange + :members: comparator_factory + +.. autoclass:: sqlalchemy.dialects.postgresql.AbstractMultiRange + :members: comparator_factory + .. autoclass:: aggregate_order_by .. autoclass:: array diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 8dbee1f7f4..7890541ffd 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -48,6 +48,8 @@ from .named_types import DropDomainType from .named_types import DropEnumType from .named_types import ENUM from .named_types import NamedType +from .ranges import AbstractMultiRange +from .ranges import AbstractRange from .ranges import DATEMULTIRANGE from .ranges import DATERANGE from .ranges import INT4MULTIRANGE diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 327feb4092..6729f3785f 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -8,10 +8,14 @@ from __future__ import annotations import dataclasses +from datetime import date +from datetime import datetime +from datetime import timedelta from typing import Any from typing import Generic from typing import Optional from typing import TypeVar +from typing import Union from ... import types as sqltypes from ...util import py310 @@ -80,6 +84,204 @@ class Range(Generic[_T]): def __bool__(self) -> bool: return self.empty + def _contains_value(self, value: _T) -> bool: + "Check whether this range contains the given `value`." + + if self.empty: + return False + + if self.lower is None: + return self.upper is None or ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + if self.upper is None: + return ( + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) + + return ( + value > self.lower + if self.bounds[0] == "(" + else value >= self.lower + ) and ( + value < self.upper + if self.bounds[1] == ")" + else value <= self.upper + ) + + def _get_discrete_step(self): + "Determine the “step” for this range, if it is a discrete one." + + # See + # https://www.postgresql.org/docs/current/rangetypes.html#RANGETYPES-DISCRETE + # for the rationale + + if isinstance(self.lower, int) or isinstance(self.upper, int): + return 1 + elif isinstance(self.lower, datetime) or isinstance( + self.upper, datetime + ): + # This is required, because a `isinstance(datetime.now(), date)` + # is True + return None + elif isinstance(self.lower, date) or isinstance(self.upper, date): + return timedelta(days=1) + else: + return None + + def contained_by(self, other: Range) -> bool: + "Determine whether this range is a contained by `other`." + + # Any range contains the empty one + if self.empty: + return True + + # An empty range does not contain any range except the empty one + if other.empty: + return False + + olower = other.lower + oupper = other.upper + + # A bilateral unbound range contains any other range + if olower is oupper is None: + return True + + slower = self.lower + supper = self.upper + + # A lower-bound range cannot contain a lower-unbound range + if slower is None and olower is not None: + return False + + # Likewise on the right side + if supper is None and oupper is not None: + return False + + slower_inc = self.bounds[0] == "[" + supper_inc = self.bounds[1] == "]" + olower_inc = other.bounds[0] == "[" + oupper_inc = other.bounds[1] == "]" + + # Check the lower end + step = -1 + if slower is not None and olower is not None: + lside = olower < slower + if not lside: + if not slower_inc or olower_inc: + lside = olower == slower + if not lside: + # Cover (1,x] vs [2,x) and (0,x] vs [1,x) + if not slower_inc and olower_inc and slower < olower: + step = self._get_discrete_step() + if step is not None: + lside = olower == (slower + step) + elif slower_inc and not olower_inc and slower > olower: + step = self._get_discrete_step() + if step is not None: + lside = (olower + step) == slower + if not lside: + return False + + # Lower end already considered, an upper-unbound range surely contains + # this + if oupper is None: + return True + + # Check the upper end + uside = oupper > supper + if not uside: + if not supper_inc or oupper_inc: + uside = oupper == supper + if not uside: + # Cover (x,2] vs [x,3) and (x,1] vs [x,2) + if supper_inc and not oupper_inc and supper < oupper: + if step == -1: + step = self._get_discrete_step() + if step is not None: + uside = oupper == (supper + step) + elif not supper_inc and oupper_inc and supper > oupper: + if step == -1: + step = self._get_discrete_step() + if step is not None: + uside = (oupper + step) == supper + return uside + + def contains(self, value: Union[_T, Range]) -> bool: + "Determine whether this range contains `value`." + + if isinstance(value, Range): + return value.contained_by(self) + else: + return self._contains_value(value) + + def overlaps(self, other): + """Boolean expression. Returns true if the column overlaps + (has points in common with) the right hand operand. + """ + raise NotImplementedError("not yet implemented") + + def strictly_left_of(self, other): + """Boolean expression. Returns true if the column is strictly + left of the right hand operand. + """ + raise NotImplementedError("not yet implemented") + + __lshift__ = strictly_left_of + + def strictly_right_of(self, other): + """Boolean expression. Returns true if the column is strictly + right of the right hand operand. + """ + raise NotImplementedError("not yet implemented") + + __rshift__ = strictly_right_of + + def not_extend_right_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend right of the range in the operand. + """ + raise NotImplementedError("not yet implemented") + + def not_extend_left_of(self, other): + """Boolean expression. Returns true if the range in the column + does not extend left of the range in the operand. + """ + raise NotImplementedError("not yet implemented") + + def adjacent_to(self, other): + """Boolean expression. Returns true if the range in the column + is adjacent to the range in the operand. + """ + raise NotImplementedError("not yet implemented") + + def __add__(self, other): + """Range expression. Returns the union of the two ranges. + Will raise an exception if the resulting range is not + contiguous. + """ + raise NotImplementedError("not yet implemented") + + def __str__(self): + return self._stringify() + + def _stringify(self): + if self.empty: + return "empty" + + l, r = self.lower, self.upper + l = "" if l is None else l + r = "" if r is None else r + + b0, b1 = self.bounds + + return f"{b0}{l},{r}{b1}" + class AbstractRange(sqltypes.TypeEngine): """ @@ -93,6 +295,8 @@ class AbstractRange(sqltypes.TypeEngine): render_bind_cast = True + __abstract__ = True + def adapt(self, impltype): """dynamically adapt a range type to an abstract impl. @@ -202,6 +406,8 @@ class AbstractRangeImpl(AbstractRange): class AbstractMultiRange(AbstractRange): """base for PostgreSQL MULTIRANGE types""" + __abstract__ = True + class AbstractMultiRangeImpl(AbstractRangeImpl, AbstractMultiRange): """marker for AbstractRange that will apply a subclass-specific diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 91eada9a81..83cea8f15a 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -20,6 +20,7 @@ from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer from sqlalchemy import literal +from sqlalchemy import literal_column from sqlalchemy import MetaData from sqlalchemy import null from sqlalchemy import Numeric @@ -66,6 +67,8 @@ from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_false +from sqlalchemy.testing import is_true from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message from sqlalchemy.testing.assertions import AssertsCompiledSQL @@ -3846,7 +3849,194 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): ) -class _RangeTypeRoundTrip(fixtures.TablesTest): +class _RangeComparisonFixtures: + def _data_str(self): + """return string form of a sample range""" + raise NotImplementedError() + + def _data_obj(self): + """return Range form of the same range""" + raise NotImplementedError() + + def _step_value_up(self, value): + """given a value, return a step up + + this is a value that given the lower end of the sample range, + would be less than the upper value of the range + + """ + raise NotImplementedError() + + def _step_value_down(self, value): + """given a value, return a step down + + this is a value that given the upper end of the sample range, + would be greater than the lower value of the range + + """ + raise NotImplementedError() + + def _value_values(self): + """Return a series of values related to the base range + + le = left equal + ll = lower than left + re = right equal + rh = higher than right + il = inside lower + ih = inside higher + + """ + spec = self._data_obj() + + le, re_ = spec.lower, spec.upper + + ll = self._step_value_down(le) + il = self._step_value_up(le) + rh = self._step_value_up(re_) + ih = self._step_value_down(re_) + + return {"le": le, "re_": re_, "ll": ll, "il": il, "rh": rh, "ih": ih} + + @testing.fixture( + params=[ + lambda **kw: Range(empty=True), + lambda **kw: Range(bounds="[)"), + lambda le, **kw: Range(upper=le, bounds="[)"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[)"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="[]"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="(]"), + lambda le, re_, **kw: Range(lower=le, upper=re_, bounds="()"), + lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="[)"), + lambda il, ih, **kw: Range(lower=il, upper=ih, bounds="[)"), + lambda ll, le, **kw: Range(lower=ll, upper=le, bounds="(]"), + lambda ll, rh, **kw: Range(lower=ll, upper=rh, bounds="[)"), + ] + ) + def contains_range_obj_combinations(self, request): + """ranges that are used for range contains() contained_by() tests""" + data = self._value_values() + + range_ = request.param(**data) + yield range_ + + @testing.fixture( + params=[ + lambda l, r: Range(empty=True), + lambda l, r: Range(bounds="()"), + lambda l, r: Range(upper=r, bounds="(]"), + lambda l, r: Range(lower=l, bounds="[)"), + lambda l, r: Range(lower=l, upper=r, bounds="[)"), + lambda l, r: Range(lower=l, upper=r, bounds="[]"), + lambda l, r: Range(lower=l, upper=r, bounds="(]"), + lambda l, r: Range(lower=l, upper=r, bounds="()"), + ] + ) + def bounds_obj_combinations(self, request): + """sample ranges used for value and range contains()/contained_by() + tests""" + + obj = self._data_obj() + l, r = obj.lower, obj.upper + + template = request.param + value = template(l=l, r=r) + yield value + + @testing.fixture(params=["ll", "le", "il", "ih", "re_", "rh"]) + def value_combinations(self, request): + """sample values used for value contains() tests""" + data = self._value_values() + return data[request.param] + + def test_basic_py_sanity(self): + values = self._value_values() + + range_ = self._data_obj() + + is_true(range_.contains(Range(lower=values["il"], upper=values["ih"]))) + + is_true( + range_.contained_by(Range(lower=values["ll"], upper=values["rh"])) + ) + + is_true(range_.contains(values["il"])) + + is_false( + range_.contains(Range(lower=values["ll"], upper=values["ih"])) + ) + + is_false(range_.contains(values["rh"])) + + def test_contains_value( + self, connection, bounds_obj_combinations, value_combinations + ): + range_ = bounds_obj_combinations + range_typ = self._col_str + + strvalue = range_._stringify() + + v = value_combinations + RANGE = self._col_type + + q = select( + literal_column(f"'{strvalue}'::{range_typ}", RANGE).label("r1"), + cast(range_, RANGE).label("r2"), + ) + literal_range, cast_range = connection.execute(q).first() + eq_(literal_range, cast_range) + + q = select( + cast(range_, RANGE), + cast(range_, RANGE).contains(v), + ) + r, expected = connection.execute(q).first() + eq_(r.contains(v), expected) + + def test_contains_range( + self, + connection, + bounds_obj_combinations, + contains_range_obj_combinations, + ): + r1repr = contains_range_obj_combinations._stringify() + r2repr = bounds_obj_combinations._stringify() + + RANGE = self._col_type + range_typ = self._col_str + + q = select( + cast(contains_range_obj_combinations, RANGE).label("r1"), + cast(bounds_obj_combinations, RANGE).label("r2"), + cast(contains_range_obj_combinations, RANGE).contains( + bounds_obj_combinations + ), + cast(contains_range_obj_combinations, RANGE).contained_by( + bounds_obj_combinations + ), + ) + validate_q = select( + literal_column(f"'{r1repr}'::{range_typ}", RANGE).label("r1"), + literal_column(f"'{r2repr}'::{range_typ}", RANGE).label("r2"), + literal_column( + f"'{r1repr}'::{range_typ} @> '{r2repr}'::{range_typ}" + ), + literal_column( + f"'{r1repr}'::{range_typ} <@ '{r2repr}'::{range_typ}" + ), + ) + orig_row = connection.execute(q).first() + validate_row = connection.execute(validate_q).first() + eq_(orig_row, validate_row) + + r1, r2, contains, contained = orig_row + eq_(r1.contains(r2), contains) + eq_(r1.contained_by(r2), contained) + eq_(r2.contains(r1), contained) + + +class _RangeTypeRoundTrip(_RangeComparisonFixtures, fixtures.TablesTest): __requires__ = ("range_types",) __backend__ = True @@ -3861,6 +4051,9 @@ class _RangeTypeRoundTrip(fixtures.TablesTest): ) cls.col = table.c.range + def test_stringify(self): + eq_(str(self._data_obj()), self._data_str()) + def test_auto_cast_back_to_type(self, connection): """test that a straight pass of the range type without any context will send appropriate casting info so that the driver can round @@ -3995,10 +4188,16 @@ class _Int4RangeTests: _col_str = "INT4RANGE" def _data_str(self): - return "[1,2)" + return "[1,4)" def _data_obj(self): - return Range(1, 2) + return Range(1, 4) + + def _step_value_up(self, value): + return value + 1 + + def _step_value_down(self, value): + return value - 1 class _Int8RangeTests: @@ -4007,10 +4206,16 @@ class _Int8RangeTests: _col_str = "INT8RANGE" def _data_str(self): - return "[9223372036854775806,9223372036854775807)" + return "[9223372036854775306,9223372036854775800)" def _data_obj(self): - return Range(9223372036854775806, 9223372036854775807) + return Range(9223372036854775306, 9223372036854775800) + + def _step_value_up(self, value): + return value + 5 + + def _step_value_down(self, value): + return value - 5 class _NumRangeTests: @@ -4019,10 +4224,16 @@ class _NumRangeTests: _col_str = "NUMRANGE" def _data_str(self): - return "[1.0,2.0)" + return "[1.0,9.0)" def _data_obj(self): - return Range(decimal.Decimal("1.0"), decimal.Decimal("2.0")) + return Range(decimal.Decimal("1.0"), decimal.Decimal("9.0")) + + def _step_value_up(self, value): + return value + decimal.Decimal("1.8") + + def _step_value_down(self, value): + return value - decimal.Decimal("1.8") class _DateRangeTests: @@ -4031,10 +4242,16 @@ class _DateRangeTests: _col_str = "DATERANGE" def _data_str(self): - return "[2013-03-23,2013-03-24)" + return "[2013-03-23,2013-03-30)" def _data_obj(self): - return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 24)) + return Range(datetime.date(2013, 3, 23), datetime.date(2013, 3, 30)) + + def _step_value_up(self, value): + return value + datetime.timedelta(days=1) + + def _step_value_down(self, value): + return value - datetime.timedelta(days=1) class _DateTimeRangeTests: @@ -4043,38 +4260,47 @@ class _DateTimeRangeTests: _col_str = "TSRANGE" def _data_str(self): - return "[2013-03-23 14:30,2013-03-23 23:30)" + return "[2013-03-23 14:30:00,2013-03-30 23:30:00)" def _data_obj(self): return Range( datetime.datetime(2013, 3, 23, 14, 30), - datetime.datetime(2013, 3, 23, 23, 30), + datetime.datetime(2013, 3, 30, 23, 30), ) + def _step_value_up(self, value): + return value + datetime.timedelta(days=1) + + def _step_value_down(self, value): + return value - datetime.timedelta(days=1) + class _DateTimeTZRangeTests: _col_type = TSTZRANGE _col_str = "TSTZRANGE" - # make sure we use one, steady timestamp with timezone pair - # for all parts of all these tests - _tstzs = None - def tstzs(self): - if self._tstzs is None: - with testing.db.connect() as connection: - lower = connection.scalar(func.current_timestamp().select()) - upper = lower + datetime.timedelta(1) - self._tstzs = (lower, upper) - return self._tstzs + tz = datetime.timezone(-datetime.timedelta(hours=5, minutes=30)) + + return ( + datetime.datetime(2013, 3, 23, 14, 30, tzinfo=tz), + datetime.datetime(2013, 3, 30, 23, 30, tzinfo=tz), + ) def _data_str(self): - return "[%s,%s)" % self.tstzs() + l, r = self.tstzs() + return f"[{l},{r})" def _data_obj(self): return Range(*self.tstzs()) + def _step_value_up(self, value): + return value + datetime.timedelta(days=1) + + def _step_value_down(self, value): + return value - datetime.timedelta(days=1) + class Int4RangeCompilationTest(_Int4RangeTests, _RangeTypeCompilation): pass