from typing import TypeVar
from typing import Union
+from .operators import ADJACENT_TO
+from .operators import CONTAINED_BY
+from .operators import CONTAINS
+from .operators import NOT_EXTEND_LEFT_OF
+from .operators import NOT_EXTEND_RIGHT_OF
+from .operators import OVERLAP
+from .operators import STRICTLY_LEFT_OF
+from .operators import STRICTLY_RIGHT_OF
from ... import types as sqltypes
+from ...sql import operators
+from ...sql.type_api import TypeEngine
from ...util import py310
from ...util.typing import Literal
if TYPE_CHECKING:
from ...sql.elements import ColumnElement
from ...sql.type_api import _TE
- from ...sql.type_api import TypeEngine
from ...sql.type_api import TypeEngineMixin
_T = TypeVar("_T", bound=Any)
# empty Range, SQL datatype can't be determined here
return sqltypes.NULLTYPE
- class comparator_factory(sqltypes.Concatenable.Comparator[Range[Any]]):
+ class comparator_factory(TypeEngine.Comparator[Range[Any]]):
"""Define comparison operations for range types."""
- def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501
- "Boolean expression. Returns true if two ranges are not equal"
- if other is None:
- return super().__ne__(other) # type: ignore
- else:
- return self.expr.op("<>", is_comparison=True)(other) # type: ignore # noqa: E501
-
def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the right hand operand,
which can be an element or a range, is contained within the
kwargs may be ignored by this operator but are required for API
conformance.
"""
- return self.expr.op("@>", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(CONTAINS, other)
def contained_by(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column is contained
within the right hand operand.
"""
- return self.expr.op("<@", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(CONTAINED_BY, other)
def overlaps(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column overlaps
(has points in common with) the right hand operand.
"""
- return self.expr.op("&&", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(OVERLAP, other)
def strictly_left_of(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the column is strictly
left of the right hand operand.
"""
- return self.expr.op("<<", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(STRICTLY_LEFT_OF, other)
__lshift__ = strictly_left_of
"""Boolean expression. Returns true if the column is strictly
right of the right hand operand.
"""
- return self.expr.op(">>", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(STRICTLY_RIGHT_OF, other)
__rshift__ = strictly_right_of
"""Boolean expression. Returns true if the range in the column
does not extend right of the range in the operand.
"""
- return self.expr.op("&<", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(NOT_EXTEND_RIGHT_OF, other)
def not_extend_left_of(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the range in the column
does not extend left of the range in the operand.
"""
- return self.expr.op("&>", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(NOT_EXTEND_LEFT_OF, other)
def adjacent_to(self, other: Any) -> ColumnElement[bool]:
"""Boolean expression. Returns true if the range in the column
is adjacent to the range in the operand.
"""
- return self.expr.op("-|-", is_comparison=True)(other) # type: ignore # noqa: E501
+ return self.expr.operate(ADJACENT_TO, other)
def union(self, other: Any) -> ColumnElement[bool]:
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
- return self.expr.op("+")(other) # type: ignore
-
- __add__ = union
+ return self.expr.operate(operators.add, other)
def difference(self, other: Any) -> ColumnElement[bool]:
"""Range expression. Returns the union of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
- return self.expr.op("-")(other) # type: ignore
-
- __sub__ = difference
+ return self.expr.operate(operators.sub, other)
def intersection(self, other: Any) -> ColumnElement[Range[_T]]:
"""Range expression. Returns the intersection of the two ranges.
Will raise an exception if the resulting range is not
contiguous.
"""
- return self.expr.op("*")(other) # type: ignore
-
- __mul__ = intersection
+ return self.expr.operate(operators.mul, other)
class AbstractRangeImpl(AbstractRange[Range[_T]]):
checkparams={"param_1": 4, "param_3": 6, "param_2": 5},
)
+ def test_array_overlap_any(self):
+ col = column("x", postgresql.ARRAY(Integer))
+ self.assert_compile(
+ select(col.overlap(any_(array([4, 5, 6])))),
+ "SELECT x && ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s]) "
+ "AS anon_1",
+ checkparams={"param_1": 4, "param_3": 6, "param_2": 5},
+ )
+
+ def test_array_contains_any(self):
+ col = column("x", postgresql.ARRAY(Integer))
+ self.assert_compile(
+ select(col.contains(any_(array([4, 5, 6])))),
+ "SELECT x @> ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s]) "
+ "AS anon_1",
+ checkparams={"param_1": 4, "param_3": 6, "param_2": 5},
+ )
+
def test_array_slice_index(self):
col = column("x", postgresql.ARRAY(Integer))
self.assert_compile(
def test_where_has_key(self):
self._test_where(
- # hide from 2to3
- getattr(self.hashcol, "has_key")("foo"),
+ self.hashcol.has_key("foo"),
"test_table.hash ? %(hash_1)s",
)
"test_table.hash <@ %(hash_1)s",
)
+ def test_where_has_key_any(self):
+ self._test_where(
+ self.hashcol.has_key(any_(array(["foo"]))),
+ "test_table.hash ? ANY (ARRAY[%(param_1)s])",
+ )
+
+ def test_where_has_all_any(self):
+ self._test_where(
+ self.hashcol.has_all(any_(postgresql.array(["1", "2"]))),
+ "test_table.hash ?& ANY (ARRAY[%(param_1)s, %(param_2)s])",
+ )
+
+ def test_where_has_any_any(self):
+ self._test_where(
+ self.hashcol.has_any(any_(postgresql.array(["1", "2"]))),
+ "test_table.hash ?| ANY (ARRAY[%(param_1)s, %(param_2)s])",
+ )
+
+ def test_where_contains_any(self):
+ self._test_where(
+ self.hashcol.contains(any_(array(["foo"]))),
+ "test_table.hash @> ANY (ARRAY[%(param_1)s])",
+ )
+
+ def test_where_contained_by_any(self):
+ self._test_where(
+ self.hashcol.contained_by(any_(array(["foo"]))),
+ "test_table.hash <@ ANY (ARRAY[%(param_1)s])",
+ )
+
def test_where_getitem(self):
self._test_where(
self.hashcol["bar"] == None, # noqa
"(test_table.hash -> %(hash_1)s) IS NULL",
)
+ def test_where_getitem_any(self):
+ self._test_where(
+ self.hashcol["bar"] == any_(array(["foo"])), # noqa
+ "(test_table.hash -> %(hash_1)s) = ANY (ARRAY[%(param_1)s])",
+ )
+
@testing.combinations(
(
lambda self: self.hashcol["foo"],
):
__dialect__ = "postgresql"
+ @property
+ def _col_str_arr(self):
+ return self._col_str
+
# operator tests
@classmethod
self.assert_compile(colclause, expected)
is_(colclause.type._type_affinity, type_._type_affinity)
- def test_where_equal(self):
- self._test_clause(
- self.col == self._data_str(),
- "data_table.range = %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_where_equal_obj(self):
+ _comparisons = [
+ (lambda col, other: col == other, "="),
+ (lambda col, other: col != other, "!="),
+ (lambda col, other: col > other, ">"),
+ (lambda col, other: col < other, "<"),
+ (lambda col, other: col >= other, ">="),
+ (lambda col, other: col <= other, "<="),
+ (lambda col, other: col.contains(other), "@>"),
+ (lambda col, other: col.contained_by(other), "<@"),
+ (lambda col, other: col.overlaps(other), "&&"),
+ (lambda col, other: col << other, "<<"),
+ (lambda col, other: col.strictly_left_of(other), "<<"),
+ (lambda col, other: col >> other, ">>"),
+ (lambda col, other: col.strictly_right_of(other), ">>"),
+ (lambda col, other: col.not_extend_left_of(other), "&>"),
+ (lambda col, other: col.not_extend_right_of(other), "&<"),
+ (lambda col, other: col.adjacent_to(other), "-|-"),
+ ]
+
+ _operations = [
+ (lambda col, other: col + other, "+"),
+ (lambda col, other: col.union(other), "+"),
+ (lambda col, other: col - other, "-"),
+ (lambda col, other: col.difference(other), "-"),
+ (lambda col, other: col * other, "*"),
+ (lambda col, other: col.intersection(other), "*"),
+ ]
+
+ _all_fns = _comparisons + _operations
+
+ _not_compare_op = ("+", "-", "*")
+
+ @testing.combinations(*_all_fns, id_="as")
+ def test_data_str(self, fn, op):
self._test_clause(
- self.col == self._data_obj(),
- f"data_table.range = %(range_1)s::{self._col_str}",
- sqltypes.BOOLEANTYPE,
+ fn(self.col, self._data_str()),
+ f"data_table.range {op} %(range_1)s",
+ self.col.type
+ if op in self._not_compare_op
+ else sqltypes.BOOLEANTYPE,
)
- def test_where_not_equal(self):
+ @testing.combinations(*_all_fns, id_="as")
+ def test_data_obj(self, fn, op):
self._test_clause(
- self.col != self._data_str(),
- "data_table.range <> %(range_1)s",
- sqltypes.BOOLEANTYPE,
+ fn(self.col, self._data_obj()),
+ f"data_table.range {op} %(range_1)s::{self._col_str}",
+ self.col.type
+ if op in self._not_compare_op
+ else sqltypes.BOOLEANTYPE,
)
- def test_where_not_equal_obj(self):
+ @testing.combinations(*_comparisons, id_="as")
+ def test_data_str_any(self, fn, op):
self._test_clause(
- self.col != self._data_obj(),
- f"data_table.range <> %(range_1)s::{self._col_str}",
- sqltypes.BOOLEANTYPE,
+ fn(self.col, any_(array([self._data_str()]))),
+ f"data_table.range {op} ANY (ARRAY[%(param_1)s])",
+ self.col.type
+ if op in self._not_compare_op
+ else sqltypes.BOOLEANTYPE,
)
def test_where_is_null(self):
sqltypes.BOOLEANTYPE,
)
- def test_where_less_than(self):
- self._test_clause(
- self.col < self._data_str(),
- "data_table.range < %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_where_greater_than(self):
- self._test_clause(
- self.col > self._data_str(),
- "data_table.range > %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_where_less_than_or_equal(self):
- self._test_clause(
- self.col <= self._data_str(),
- "data_table.range <= %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_where_greater_than_or_equal(self):
- self._test_clause(
- self.col >= self._data_str(),
- "data_table.range >= %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_contains(self):
- self._test_clause(
- self.col.contains(self._data_str()),
- "data_table.range @> %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_contains_obj(self):
- self._test_clause(
- self.col.contains(self._data_obj()),
- f"data_table.range @> %(range_1)s::{self._col_str}",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_contained_by(self):
- self._test_clause(
- self.col.contained_by(self._data_str()),
- "data_table.range <@ %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_overlaps(self):
- self._test_clause(
- self.col.overlaps(self._data_str()),
- "data_table.range && %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_strictly_left_of(self):
- self._test_clause(
- self.col << self._data_str(),
- "data_table.range << %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
- self._test_clause(
- self.col.strictly_left_of(self._data_str()),
- "data_table.range << %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_strictly_right_of(self):
- self._test_clause(
- self.col >> self._data_str(),
- "data_table.range >> %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
- self._test_clause(
- self.col.strictly_right_of(self._data_str()),
- "data_table.range >> %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_not_extend_right_of(self):
- self._test_clause(
- self.col.not_extend_right_of(self._data_str()),
- "data_table.range &< %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_not_extend_left_of(self):
- self._test_clause(
- self.col.not_extend_left_of(self._data_str()),
- "data_table.range &> %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_adjacent_to(self):
- self._test_clause(
- self.col.adjacent_to(self._data_str()),
- "data_table.range -|- %(range_1)s",
- sqltypes.BOOLEANTYPE,
- )
-
- def test_union(self):
- self._test_clause(
- self.col + self.col,
- "data_table.range + data_table.range",
- self.col.type,
- )
-
- self._test_clause(
- self.col.union(self._data_str()),
- "data_table.range + %(range_1)s",
- self.col.type,
- )
-
- def test_intersection(self):
- self._test_clause(
- self.col * self.col,
- "data_table.range * data_table.range",
- self.col.type,
- )
-
- def test_difference(self):
- self._test_clause(
- self.col - self.col,
- "data_table.range - data_table.range",
- self.col.type,
- )
-
- self._test_clause(
- self.col.difference(self._data_str()),
- "data_table.range - %(range_1)s",
- self.col.type,
- )
-
class _RangeComparisonFixtures(_RangeTests):
def _step_value_up(self, value):
_col_type = INT4RANGE
_col_str = "INT4RANGE"
+ _col_str_arr = "INT8RANGE"
def _data_str(self):
return "[1,4)"
def test_where_not_equal(self):
self._test_clause(
self.col != self._data_str(),
- "data_table.multirange <> %(multirange_1)s",
+ "data_table.multirange != %(multirange_1)s",
sqltypes.BOOLEANTYPE,
)
def test_where_not_equal_obj(self):
self._test_clause(
self.col != self._data_obj(),
- f"data_table.multirange <> %(multirange_1)s::{self._col_str}",
+ f"data_table.multirange != %(multirange_1)s::{self._col_str}",
sqltypes.BOOLEANTYPE,
)
)
self.jsoncol = self.test_table.c.test_column
+ @property
+ def any_(self):
+ return any_(array([7]))
+
@testing.combinations(
(
lambda self: self.jsoncol["bar"] == None, # noqa
"(test_table.test_column -> %(test_column_1)s) IS NULL",
),
+ (
+ lambda self: self.jsoncol["bar"] != None, # noqa
+ "(test_table.test_column -> %(test_column_1)s) IS NOT NULL",
+ ),
(
lambda self: self.jsoncol[("foo", 1)] == None, # noqa
"(test_table.test_column #> %(test_column_1)s) IS NULL",
),
+ (
+ lambda self: self.jsoncol[("foo", 1)] != None, # noqa
+ "(test_table.test_column #> %(test_column_1)s) IS NOT NULL",
+ ),
(
lambda self: self.jsoncol["bar"].astext == None, # noqa
"(test_table.test_column ->> %(test_column_1)s) IS NULL",
lambda self: self.jsoncol[("foo", 1)].astext == None, # noqa
"(test_table.test_column #>> %(test_column_1)s) IS NULL",
),
+ (
+ lambda self: self.jsoncol["bar"] == 42,
+ "(test_table.test_column -> %(test_column_1)s) = %(param_1)s",
+ ),
+ (
+ lambda self: self.jsoncol["bar"] != 42,
+ "(test_table.test_column -> %(test_column_1)s) != %(param_1)s",
+ ),
+ (
+ lambda self: self.jsoncol["bar"] == self.any_,
+ "(test_table.test_column -> %(test_column_1)s) = "
+ "ANY (ARRAY[%(param_1)s])",
+ ),
+ (
+ lambda self: self.jsoncol["bar"] != self.any_,
+ "(test_table.test_column -> %(test_column_1)s) != "
+ "ANY (ARRAY[%(param_1)s])",
+ ),
+ (
+ lambda self: self.jsoncol["bar"].astext == self.any_,
+ "(test_table.test_column ->> %(test_column_1)s) = "
+ "ANY (ARRAY[%(param_1)s])",
+ ),
+ (
+ lambda self: self.jsoncol["bar"].astext != self.any_,
+ "(test_table.test_column ->> %(test_column_1)s) != "
+ "ANY (ARRAY[%(param_1)s])",
+ ),
+ (
+ lambda self: self.jsoncol[("foo", 1)] == self.any_,
+ "(test_table.test_column #> %(test_column_1)s) = "
+ "ANY (ARRAY[%(param_1)s])",
+ ),
+ (
+ lambda self: self.jsoncol[("foo", 1)] != self.any_,
+ "(test_table.test_column #> %(test_column_1)s) != "
+ "ANY (ARRAY[%(param_1)s])",
+ ),
+ id_="as",
)
def test_where(self, whereclause_fn, expected):
whereclause = whereclause_fn(self)
@testing.combinations(
(
- # hide from 2to3
- lambda self: getattr(self.jsoncol, "has_key")("data"),
+ lambda self: self.jsoncol.has_key("data"),
"test_table.test_column ? %(test_column_1)s",
),
+ (
+ lambda self: self.jsoncol.has_key(self.any_),
+ "test_table.test_column ? ANY (ARRAY[%(param_1)s])",
+ ),
(
lambda self: self.jsoncol.has_all(
{"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}}
),
"test_table.test_column ?& %(test_column_1)s",
),
+ (
+ lambda self: self.jsoncol.has_all(self.any_),
+ "test_table.test_column ?& ANY (ARRAY[%(param_1)s])",
+ ),
(
lambda self: self.jsoncol.has_any(
postgresql.array(["name", "data"])
),
"test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]",
),
+ (
+ lambda self: self.jsoncol.has_any(self.any_),
+ "test_table.test_column ?| ANY (ARRAY[%(param_1)s])",
+ ),
(
lambda self: self.jsoncol.contains({"k1": "r1v1"}),
"test_table.test_column @> %(test_column_1)s",
),
+ (
+ lambda self: self.jsoncol.contains(self.any_),
+ "test_table.test_column @> ANY (ARRAY[%(param_1)s])",
+ ),
(
lambda self: self.jsoncol.contained_by({"foo": "1", "bar": None}),
"test_table.test_column <@ %(test_column_1)s",
),
+ (
+ lambda self: self.jsoncol.contained_by(self.any_),
+ "test_table.test_column <@ ANY (ARRAY[%(param_1)s])",
+ ),
(
lambda self: self.jsoncol.delete_path(["a", "b"]),
"test_table.test_column #- CAST(ARRAY[%(param_1)s, "
lambda self: self.jsoncol.path_exists("$.k1"),
"test_table.test_column @? %(test_column_1)s",
),
+ (
+ lambda self: self.jsoncol.path_exists(self.any_),
+ "test_table.test_column @? ANY (ARRAY[%(param_1)s])",
+ ),
(
lambda self: self.jsoncol.path_match("$.k1[0] > 2"),
"test_table.test_column @@ %(test_column_1)s",
),
+ (
+ lambda self: self.jsoncol.path_match(self.any_),
+ "test_table.test_column @@ ANY (ARRAY[%(param_1)s])",
+ ),
+ id_="as",
)
- def test_where(self, whereclause_fn, expected):
+ def test_where_jsonb(self, whereclause_fn, expected):
super().test_where(whereclause_fn, expected)