From: Federico Caselli Date: Fri, 26 May 2023 22:36:04 +0000 (+0200) Subject: Improve PostgreSQL custom operators X-Git-Tag: rel_2_0_16~23^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=cf4dec5ac476cf9a2179f9f2b16d46fe95d1f18d;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Improve PostgreSQL custom operators Unified the custom PostgreSQL operator definitions, since they are shared among multiple different data types. Use proper precedence on PostgreSQL specific operators, such as ``@>``. Previously the precedence was wrong, leasing to wrong parenthesis when rending against and ``ANY`` or ``ALL`` construct. Fixes: #9041 Fixes: #9836 Change-Id: I1c1d8b4c2d58d53c51c2e6d4934ac1ed83bda5d3 --- diff --git a/doc/build/changelog/unreleased_20/9041.rst b/doc/build/changelog/unreleased_20/9041.rst new file mode 100644 index 0000000000..80cff6f018 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9041.rst @@ -0,0 +1,6 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 9041 + + Unified the custom PostgreSQL operator definitions, since they are + shared among multiple different data types. diff --git a/doc/build/changelog/unreleased_20/9836.rst b/doc/build/changelog/unreleased_20/9836.rst new file mode 100644 index 0000000000..b6ad9b7bc9 --- /dev/null +++ b/doc/build/changelog/unreleased_20/9836.rst @@ -0,0 +1,7 @@ +.. change:: + :tags: bug, postgresql + :tickets: 9836 + + Use proper precedence on PostgreSQL specific operators, such as ``@>``. + Previously the precedence was wrong, leading to wrong parenthesis when + rendering against and ``ANY`` or ``ALL`` construct. diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index ba53bf665a..bbfcecdc91 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -14,6 +14,9 @@ from typing import Any from typing import Optional from typing import TypeVar +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import OVERLAP from ... import types as sqltypes from ... import util from ...sql import expression @@ -155,13 +158,6 @@ class array(expression.ExpressionClauseList[_T]): return self -CONTAINS = operators.custom_op("@>", precedence=5, is_comparison=True) - -CONTAINED_BY = operators.custom_op("<@", precedence=5, is_comparison=True) - -OVERLAP = operators.custom_op("&&", precedence=5, is_comparison=True) - - class ARRAY(sqltypes.ARRAY): """PostgreSQL ARRAY type. diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index dc7e4d40da..83c4932a6e 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -10,57 +10,18 @@ import re from .array import ARRAY +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import GETITEM +from .operators import HAS_ALL +from .operators import HAS_ANY +from .operators import HAS_KEY from ... import types as sqltypes from ...sql import functions as sqlfunc -from ...sql import operators __all__ = ("HSTORE", "hstore") -idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] - -GETITEM = operators.custom_op( - "->", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -HAS_KEY = operators.custom_op( - "?", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -HAS_ALL = operators.custom_op( - "?&", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -HAS_ANY = operators.custom_op( - "?|", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -CONTAINS = operators.custom_op( - "@>", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -CONTAINED_BY = operators.custom_op( - "<@", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): """Represent the PostgreSQL HSTORE type. diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index 232f058042..ee56a74504 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -9,86 +9,21 @@ from .array import ARRAY from .array import array as _pg_array +from .operators import ASTEXT +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import DELETE_PATH +from .operators import HAS_ALL +from .operators import HAS_ANY +from .operators import HAS_KEY +from .operators import JSONPATH_ASTEXT +from .operators import PATH_EXISTS +from .operators import PATH_MATCH from ... import types as sqltypes from ...sql import cast -from ...sql import operators - __all__ = ("JSON", "JSONB") -idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] - -ASTEXT = operators.custom_op( - "->>", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -JSONPATH_ASTEXT = operators.custom_op( - "#>>", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - - -HAS_KEY = operators.custom_op( - "?", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -HAS_ALL = operators.custom_op( - "?&", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -HAS_ANY = operators.custom_op( - "?|", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -CONTAINS = operators.custom_op( - "@>", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -CONTAINED_BY = operators.custom_op( - "<@", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -DELETE_PATH = operators.custom_op( - "#-", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -PATH_EXISTS = operators.custom_op( - "@?", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - -PATH_MATCH = operators.custom_op( - "@@", - precedence=idx_precedence, - natural_self_precedent=True, - eager_grouping=True, -) - class JSONPathType(sqltypes.JSON.JSONPathType): def _processor(self, dialect, super_proc): diff --git a/lib/sqlalchemy/dialects/postgresql/operators.py b/lib/sqlalchemy/dialects/postgresql/operators.py new file mode 100644 index 0000000000..f393451c6e --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/operators.py @@ -0,0 +1,129 @@ +# postgresql/operators.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors +from ...sql import operators + + +_getitem_precedence = operators._PRECEDENCE[operators.json_getitem_op] +_eq_precedence = operators._PRECEDENCE[operators.eq] + +# JSON + JSONB +ASTEXT = operators.custom_op( + "->>", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +JSONPATH_ASTEXT = operators.custom_op( + "#>>", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +# JSONB + HSTORE +HAS_KEY = operators.custom_op( + "?", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +HAS_ALL = operators.custom_op( + "?&", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +HAS_ANY = operators.custom_op( + "?|", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +# JSONB +DELETE_PATH = operators.custom_op( + "#-", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) + +PATH_EXISTS = operators.custom_op( + "@?", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +PATH_MATCH = operators.custom_op( + "@@", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +# JSONB + ARRAY + HSTORE + RANGE +CONTAINS = operators.custom_op( + "@>", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +CONTAINED_BY = operators.custom_op( + "<@", + precedence=_eq_precedence, + natural_self_precedent=True, + eager_grouping=True, + is_comparison=True, +) + +# ARRAY + RANGE +OVERLAP = operators.custom_op( + "&&", + precedence=_eq_precedence, + is_comparison=True, +) + +# RANGE +STRICTLY_LEFT_OF = operators.custom_op( + "<<", precedence=_eq_precedence, is_comparison=True +) + +STRICTLY_RIGHT_OF = operators.custom_op( + ">>", precedence=_eq_precedence, is_comparison=True +) + +NOT_EXTEND_RIGHT_OF = operators.custom_op( + "&<", precedence=_eq_precedence, is_comparison=True +) + +NOT_EXTEND_LEFT_OF = operators.custom_op( + "&>", precedence=_eq_precedence, is_comparison=True +) + +ADJACENT_TO = operators.custom_op( + "-|-", precedence=_eq_precedence, is_comparison=True +) + +# HSTORE +GETITEM = operators.custom_op( + "->", + precedence=_getitem_precedence, + natural_self_precedent=True, + eager_grouping=True, +) diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index 20006e7ab0..2cd1552a73 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -22,14 +22,23 @@ from typing import TYPE_CHECKING from typing import TypeVar from typing import Union +from .operators import ADJACENT_TO +from .operators import CONTAINED_BY +from .operators import CONTAINS +from .operators import NOT_EXTEND_LEFT_OF +from .operators import NOT_EXTEND_RIGHT_OF +from .operators import OVERLAP +from .operators import STRICTLY_LEFT_OF +from .operators import STRICTLY_RIGHT_OF from ... import types as sqltypes +from ...sql import operators +from ...sql.type_api import TypeEngine from ...util import py310 from ...util.typing import Literal if TYPE_CHECKING: from ...sql.elements import ColumnElement from ...sql.type_api import _TE - from ...sql.type_api import TypeEngine from ...sql.type_api import TypeEngineMixin _T = TypeVar("_T", bound=Any) @@ -766,16 +775,9 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): # empty Range, SQL datatype can't be determined here return sqltypes.NULLTYPE - class comparator_factory(sqltypes.Concatenable.Comparator[Range[Any]]): + class comparator_factory(TypeEngine.Comparator[Range[Any]]): """Define comparison operations for range types.""" - def __ne__(self, other: Any) -> ColumnElement[bool]: # type: ignore[override] # noqa: E501 - "Boolean expression. Returns true if two ranges are not equal" - if other is None: - return super().__ne__(other) # type: ignore - else: - return self.expr.op("<>", is_comparison=True)(other) # type: ignore # noqa: E501 - def contains(self, other: Any, **kw: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the right hand operand, which can be an element or a range, is contained within the @@ -784,25 +786,25 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): kwargs may be ignored by this operator but are required for API conformance. """ - return self.expr.op("@>", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(CONTAINS, other) def contained_by(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column is contained within the right hand operand. """ - return self.expr.op("<@", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(CONTAINED_BY, other) def overlaps(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column overlaps (has points in common with) the right hand operand. """ - return self.expr.op("&&", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(OVERLAP, other) def strictly_left_of(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the column is strictly left of the right hand operand. """ - return self.expr.op("<<", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(STRICTLY_LEFT_OF, other) __lshift__ = strictly_left_of @@ -810,7 +812,7 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): """Boolean expression. Returns true if the column is strictly right of the right hand operand. """ - return self.expr.op(">>", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(STRICTLY_RIGHT_OF, other) __rshift__ = strictly_right_of @@ -818,46 +820,40 @@ class AbstractRange(sqltypes.TypeEngine[Range[_T]]): """Boolean expression. Returns true if the range in the column does not extend right of the range in the operand. """ - return self.expr.op("&<", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(NOT_EXTEND_RIGHT_OF, other) def not_extend_left_of(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the range in the column does not extend left of the range in the operand. """ - return self.expr.op("&>", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(NOT_EXTEND_LEFT_OF, other) def adjacent_to(self, other: Any) -> ColumnElement[bool]: """Boolean expression. Returns true if the range in the column is adjacent to the range in the operand. """ - return self.expr.op("-|-", is_comparison=True)(other) # type: ignore # noqa: E501 + return self.expr.operate(ADJACENT_TO, other) def union(self, other: Any) -> ColumnElement[bool]: """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contiguous. """ - return self.expr.op("+")(other) # type: ignore - - __add__ = union + return self.expr.operate(operators.add, other) def difference(self, other: Any) -> ColumnElement[bool]: """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contiguous. """ - return self.expr.op("-")(other) # type: ignore - - __sub__ = difference + return self.expr.operate(operators.sub, other) def intersection(self, other: Any) -> ColumnElement[Range[_T]]: """Range expression. Returns the intersection of the two ranges. Will raise an exception if the resulting range is not contiguous. """ - return self.expr.op("*")(other) # type: ignore - - __mul__ = intersection + return self.expr.operate(operators.mul, other) class AbstractRangeImpl(AbstractRange[Range[_T]]): diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index e9e1d8ced8..0d72cde3a7 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -536,7 +536,7 @@ class ReflectedIndex(TypedDict): """whether or not the index has a unique flag""" duplicates_constraint: NotRequired[Optional[str]] - "Indicates if this index mirrors a unique constraint with this name" + "Indicates if this index mirrors a constraint with this name" include_columns: NotRequired[List[str]] """columns to include in the INCLUDE clause for supporting databases. diff --git a/setup.cfg b/setup.cfg index efeeee36d3..7477fec456 100644 --- a/setup.cfg +++ b/setup.cfg @@ -105,7 +105,7 @@ ignore = E203,E305,E711,E712,E721,E722,E741, N801,N802,N806, RST304,RST303,RST299,RST399, - W503,W504 + W503,W504,W601 exclude = .venv,.git,.tox,dist,doc,*egg,build import-order-style = google application-import-names = sqlalchemy,test diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 5df8bc0a5b..ca1b35a76a 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -1750,6 +1750,24 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): checkparams={"param_1": 4, "param_3": 6, "param_2": 5}, ) + def test_array_overlap_any(self): + col = column("x", postgresql.ARRAY(Integer)) + self.assert_compile( + select(col.overlap(any_(array([4, 5, 6])))), + "SELECT x && ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s]) " + "AS anon_1", + checkparams={"param_1": 4, "param_3": 6, "param_2": 5}, + ) + + def test_array_contains_any(self): + col = column("x", postgresql.ARRAY(Integer)) + self.assert_compile( + select(col.contains(any_(array([4, 5, 6])))), + "SELECT x @> ANY (ARRAY[%(param_1)s, %(param_2)s, %(param_3)s]) " + "AS anon_1", + checkparams={"param_1": 4, "param_3": 6, "param_2": 5}, + ) + def test_array_slice_index(self): col = column("x", postgresql.ARRAY(Integer)) self.assert_compile( @@ -3459,8 +3477,7 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): def test_where_has_key(self): self._test_where( - # hide from 2to3 - getattr(self.hashcol, "has_key")("foo"), + self.hashcol.has_key("foo"), "test_table.hash ? %(hash_1)s", ) @@ -3494,12 +3511,48 @@ class HStoreTest(AssertsCompiledSQL, fixtures.TestBase): "test_table.hash <@ %(hash_1)s", ) + def test_where_has_key_any(self): + self._test_where( + self.hashcol.has_key(any_(array(["foo"]))), + "test_table.hash ? ANY (ARRAY[%(param_1)s])", + ) + + def test_where_has_all_any(self): + self._test_where( + self.hashcol.has_all(any_(postgresql.array(["1", "2"]))), + "test_table.hash ?& ANY (ARRAY[%(param_1)s, %(param_2)s])", + ) + + def test_where_has_any_any(self): + self._test_where( + self.hashcol.has_any(any_(postgresql.array(["1", "2"]))), + "test_table.hash ?| ANY (ARRAY[%(param_1)s, %(param_2)s])", + ) + + def test_where_contains_any(self): + self._test_where( + self.hashcol.contains(any_(array(["foo"]))), + "test_table.hash @> ANY (ARRAY[%(param_1)s])", + ) + + def test_where_contained_by_any(self): + self._test_where( + self.hashcol.contained_by(any_(array(["foo"]))), + "test_table.hash <@ ANY (ARRAY[%(param_1)s])", + ) + def test_where_getitem(self): self._test_where( self.hashcol["bar"] == None, # noqa "(test_table.hash -> %(hash_1)s) IS NULL", ) + def test_where_getitem_any(self): + self._test_where( + self.hashcol["bar"] == any_(array(["foo"])), # noqa + "(test_table.hash -> %(hash_1)s) = ANY (ARRAY[%(param_1)s])", + ) + @testing.combinations( ( lambda self: self.hashcol["foo"], @@ -3812,6 +3865,10 @@ class _RangeTypeCompilation( ): __dialect__ = "postgresql" + @property + def _col_str_arr(self): + return self._col_str + # operator tests @classmethod @@ -3827,32 +3884,66 @@ class _RangeTypeCompilation( self.assert_compile(colclause, expected) is_(colclause.type._type_affinity, type_._type_affinity) - def test_where_equal(self): - self._test_clause( - self.col == self._data_str(), - "data_table.range = %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_where_equal_obj(self): + _comparisons = [ + (lambda col, other: col == other, "="), + (lambda col, other: col != other, "!="), + (lambda col, other: col > other, ">"), + (lambda col, other: col < other, "<"), + (lambda col, other: col >= other, ">="), + (lambda col, other: col <= other, "<="), + (lambda col, other: col.contains(other), "@>"), + (lambda col, other: col.contained_by(other), "<@"), + (lambda col, other: col.overlaps(other), "&&"), + (lambda col, other: col << other, "<<"), + (lambda col, other: col.strictly_left_of(other), "<<"), + (lambda col, other: col >> other, ">>"), + (lambda col, other: col.strictly_right_of(other), ">>"), + (lambda col, other: col.not_extend_left_of(other), "&>"), + (lambda col, other: col.not_extend_right_of(other), "&<"), + (lambda col, other: col.adjacent_to(other), "-|-"), + ] + + _operations = [ + (lambda col, other: col + other, "+"), + (lambda col, other: col.union(other), "+"), + (lambda col, other: col - other, "-"), + (lambda col, other: col.difference(other), "-"), + (lambda col, other: col * other, "*"), + (lambda col, other: col.intersection(other), "*"), + ] + + _all_fns = _comparisons + _operations + + _not_compare_op = ("+", "-", "*") + + @testing.combinations(*_all_fns, id_="as") + def test_data_str(self, fn, op): self._test_clause( - self.col == self._data_obj(), - f"data_table.range = %(range_1)s::{self._col_str}", - sqltypes.BOOLEANTYPE, + fn(self.col, self._data_str()), + f"data_table.range {op} %(range_1)s", + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE, ) - def test_where_not_equal(self): + @testing.combinations(*_all_fns, id_="as") + def test_data_obj(self, fn, op): self._test_clause( - self.col != self._data_str(), - "data_table.range <> %(range_1)s", - sqltypes.BOOLEANTYPE, + fn(self.col, self._data_obj()), + f"data_table.range {op} %(range_1)s::{self._col_str}", + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE, ) - def test_where_not_equal_obj(self): + @testing.combinations(*_comparisons, id_="as") + def test_data_str_any(self, fn, op): self._test_clause( - self.col != self._data_obj(), - f"data_table.range <> %(range_1)s::{self._col_str}", - sqltypes.BOOLEANTYPE, + fn(self.col, any_(array([self._data_str()]))), + f"data_table.range {op} ANY (ARRAY[%(param_1)s])", + self.col.type + if op in self._not_compare_op + else sqltypes.BOOLEANTYPE, ) def test_where_is_null(self): @@ -3867,140 +3958,6 @@ class _RangeTypeCompilation( sqltypes.BOOLEANTYPE, ) - def test_where_less_than(self): - self._test_clause( - self.col < self._data_str(), - "data_table.range < %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_where_greater_than(self): - self._test_clause( - self.col > self._data_str(), - "data_table.range > %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_where_less_than_or_equal(self): - self._test_clause( - self.col <= self._data_str(), - "data_table.range <= %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_where_greater_than_or_equal(self): - self._test_clause( - self.col >= self._data_str(), - "data_table.range >= %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_contains(self): - self._test_clause( - self.col.contains(self._data_str()), - "data_table.range @> %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_contains_obj(self): - self._test_clause( - self.col.contains(self._data_obj()), - f"data_table.range @> %(range_1)s::{self._col_str}", - sqltypes.BOOLEANTYPE, - ) - - def test_contained_by(self): - self._test_clause( - self.col.contained_by(self._data_str()), - "data_table.range <@ %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_overlaps(self): - self._test_clause( - self.col.overlaps(self._data_str()), - "data_table.range && %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_strictly_left_of(self): - self._test_clause( - self.col << self._data_str(), - "data_table.range << %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - self._test_clause( - self.col.strictly_left_of(self._data_str()), - "data_table.range << %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_strictly_right_of(self): - self._test_clause( - self.col >> self._data_str(), - "data_table.range >> %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - self._test_clause( - self.col.strictly_right_of(self._data_str()), - "data_table.range >> %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_not_extend_right_of(self): - self._test_clause( - self.col.not_extend_right_of(self._data_str()), - "data_table.range &< %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_not_extend_left_of(self): - self._test_clause( - self.col.not_extend_left_of(self._data_str()), - "data_table.range &> %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_adjacent_to(self): - self._test_clause( - self.col.adjacent_to(self._data_str()), - "data_table.range -|- %(range_1)s", - sqltypes.BOOLEANTYPE, - ) - - def test_union(self): - self._test_clause( - self.col + self.col, - "data_table.range + data_table.range", - self.col.type, - ) - - self._test_clause( - self.col.union(self._data_str()), - "data_table.range + %(range_1)s", - self.col.type, - ) - - def test_intersection(self): - self._test_clause( - self.col * self.col, - "data_table.range * data_table.range", - self.col.type, - ) - - def test_difference(self): - self._test_clause( - self.col - self.col, - "data_table.range - data_table.range", - self.col.type, - ) - - self._test_clause( - self.col.difference(self._data_str()), - "data_table.range - %(range_1)s", - self.col.type, - ) - class _RangeComparisonFixtures(_RangeTests): def _step_value_up(self, value): @@ -4768,6 +4725,7 @@ class _Int4RangeTests: _col_type = INT4RANGE _col_str = "INT4RANGE" + _col_str_arr = "INT8RANGE" def _data_str(self): return "[1,4)" @@ -4981,14 +4939,14 @@ class _MultiRangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase): def test_where_not_equal(self): self._test_clause( self.col != self._data_str(), - "data_table.multirange <> %(multirange_1)s", + "data_table.multirange != %(multirange_1)s", sqltypes.BOOLEANTYPE, ) def test_where_not_equal_obj(self): self._test_clause( self.col != self._data_obj(), - f"data_table.multirange <> %(multirange_1)s::{self._col_str}", + f"data_table.multirange != %(multirange_1)s::{self._col_str}", sqltypes.BOOLEANTYPE, ) @@ -5480,15 +5438,27 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): ) self.jsoncol = self.test_table.c.test_column + @property + def any_(self): + return any_(array([7])) + @testing.combinations( ( lambda self: self.jsoncol["bar"] == None, # noqa "(test_table.test_column -> %(test_column_1)s) IS NULL", ), + ( + lambda self: self.jsoncol["bar"] != None, # noqa + "(test_table.test_column -> %(test_column_1)s) IS NOT NULL", + ), ( lambda self: self.jsoncol[("foo", 1)] == None, # noqa "(test_table.test_column #> %(test_column_1)s) IS NULL", ), + ( + lambda self: self.jsoncol[("foo", 1)] != None, # noqa + "(test_table.test_column #> %(test_column_1)s) IS NOT NULL", + ), ( lambda self: self.jsoncol["bar"].astext == None, # noqa "(test_table.test_column ->> %(test_column_1)s) IS NULL", @@ -5507,6 +5477,45 @@ class JSONTest(AssertsCompiledSQL, fixtures.TestBase): lambda self: self.jsoncol[("foo", 1)].astext == None, # noqa "(test_table.test_column #>> %(test_column_1)s) IS NULL", ), + ( + lambda self: self.jsoncol["bar"] == 42, + "(test_table.test_column -> %(test_column_1)s) = %(param_1)s", + ), + ( + lambda self: self.jsoncol["bar"] != 42, + "(test_table.test_column -> %(test_column_1)s) != %(param_1)s", + ), + ( + lambda self: self.jsoncol["bar"] == self.any_, + "(test_table.test_column -> %(test_column_1)s) = " + "ANY (ARRAY[%(param_1)s])", + ), + ( + lambda self: self.jsoncol["bar"] != self.any_, + "(test_table.test_column -> %(test_column_1)s) != " + "ANY (ARRAY[%(param_1)s])", + ), + ( + lambda self: self.jsoncol["bar"].astext == self.any_, + "(test_table.test_column ->> %(test_column_1)s) = " + "ANY (ARRAY[%(param_1)s])", + ), + ( + lambda self: self.jsoncol["bar"].astext != self.any_, + "(test_table.test_column ->> %(test_column_1)s) != " + "ANY (ARRAY[%(param_1)s])", + ), + ( + lambda self: self.jsoncol[("foo", 1)] == self.any_, + "(test_table.test_column #> %(test_column_1)s) = " + "ANY (ARRAY[%(param_1)s])", + ), + ( + lambda self: self.jsoncol[("foo", 1)] != self.any_, + "(test_table.test_column #> %(test_column_1)s) != " + "ANY (ARRAY[%(param_1)s])", + ), + id_="as", ) def test_where(self, whereclause_fn, expected): whereclause = whereclause_fn(self) @@ -5832,30 +5841,49 @@ class JSONBTest(JSONTest): @testing.combinations( ( - # hide from 2to3 - lambda self: getattr(self.jsoncol, "has_key")("data"), + lambda self: self.jsoncol.has_key("data"), "test_table.test_column ? %(test_column_1)s", ), + ( + lambda self: self.jsoncol.has_key(self.any_), + "test_table.test_column ? ANY (ARRAY[%(param_1)s])", + ), ( lambda self: self.jsoncol.has_all( {"name": "r1", "data": {"k1": "r1v1", "k2": "r1v2"}} ), "test_table.test_column ?& %(test_column_1)s", ), + ( + lambda self: self.jsoncol.has_all(self.any_), + "test_table.test_column ?& ANY (ARRAY[%(param_1)s])", + ), ( lambda self: self.jsoncol.has_any( postgresql.array(["name", "data"]) ), "test_table.test_column ?| ARRAY[%(param_1)s, %(param_2)s]", ), + ( + lambda self: self.jsoncol.has_any(self.any_), + "test_table.test_column ?| ANY (ARRAY[%(param_1)s])", + ), ( lambda self: self.jsoncol.contains({"k1": "r1v1"}), "test_table.test_column @> %(test_column_1)s", ), + ( + lambda self: self.jsoncol.contains(self.any_), + "test_table.test_column @> ANY (ARRAY[%(param_1)s])", + ), ( lambda self: self.jsoncol.contained_by({"foo": "1", "bar": None}), "test_table.test_column <@ %(test_column_1)s", ), + ( + lambda self: self.jsoncol.contained_by(self.any_), + "test_table.test_column <@ ANY (ARRAY[%(param_1)s])", + ), ( lambda self: self.jsoncol.delete_path(["a", "b"]), "test_table.test_column #- CAST(ARRAY[%(param_1)s, " @@ -5870,12 +5898,21 @@ class JSONBTest(JSONTest): lambda self: self.jsoncol.path_exists("$.k1"), "test_table.test_column @? %(test_column_1)s", ), + ( + lambda self: self.jsoncol.path_exists(self.any_), + "test_table.test_column @? ANY (ARRAY[%(param_1)s])", + ), ( lambda self: self.jsoncol.path_match("$.k1[0] > 2"), "test_table.test_column @@ %(test_column_1)s", ), + ( + lambda self: self.jsoncol.path_match(self.any_), + "test_table.test_column @@ ANY (ARRAY[%(param_1)s])", + ), + id_="as", ) - def test_where(self, whereclause_fn, expected): + def test_where_jsonb(self, whereclause_fn, expected): super().test_where(whereclause_fn, expected)