From: Mike Bayer Date: Tue, 2 Jan 2024 16:16:13 +0000 (-0500) Subject: refactor any_ / all_ X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=f3ca2350a5d0a34d86ceb934682798438f769e59;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git refactor any_ / all_ Improved compilation of :func:`_sql.any_` / :func:`_sql.all_` in the context of a negation of boolean comparison, will now render ``NOT (expr)`` rather than reversing the equality operator to not equals, allowing finer-grained control of negations for these non-typical operators. Fixes: #10817 Change-Id: If0b324b1220ad3c7f053af91e8a61c81015f312a --- diff --git a/doc/build/changelog/unreleased_20/10817.rst b/doc/build/changelog/unreleased_20/10817.rst new file mode 100644 index 0000000000..69634d06dc --- /dev/null +++ b/doc/build/changelog/unreleased_20/10817.rst @@ -0,0 +1,8 @@ +.. change:: + :tags: bug, sql + :tickets: 10817 + + Improved compilation of :func:`_sql.any_` / :func:`_sql.all_` in the + context of a negation of boolean comparison, will now render ``NOT (expr)`` + rather than reversing the equality operator to not equals, allowing + finer-grained control of negations for these non-typical operators. diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 939b14c5d4..072acafed3 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -56,7 +56,6 @@ def _boolean_compare( negate_op: Optional[OperatorType] = None, reverse: bool = False, _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), - _any_all_expr: bool = False, result_type: Optional[TypeEngine[bool]] = None, **kwargs: Any, ) -> OperatorExpression[bool]: @@ -90,7 +89,7 @@ def _boolean_compare( negate=negate_op, modifiers=kwargs, ) - elif _any_all_expr: + elif expr._is_collection_aggregate: obj = coercions.expect( roles.ConstExprRole, element=obj, operator=op, expr=expr ) diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index e6d7ad7da8..45eb8f3c55 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -803,6 +803,7 @@ class CompilerColumnElement( __slots__ = () _propagate_attrs = util.EMPTY_DICT + _is_collection_aggregate = False # SQLCoreOperations should be suiting the ExpressionElementRole @@ -1407,6 +1408,7 @@ class ColumnElement( _is_column_element = True _insert_sentinel: bool = False _omit_from_statements = False + _is_collection_aggregate = False foreign_keys: AbstractSet[ForeignKey] = frozenset() @@ -2361,6 +2363,8 @@ class TextClause( _omit_from_statements = False + _is_collection_aggregate = False + @property def _hide_froms(self) -> Iterable[FromClause]: return () @@ -2966,6 +2970,9 @@ class OperatorExpression(ColumnElement[_T]): *(left_flattened + right_flattened), ) + if right._is_collection_aggregate: + negate = None + return BinaryExpression( left, right, op, type_=type_, negate=negate, modifiers=modifiers ) @@ -3804,6 +3811,7 @@ class CollectionAggregate(UnaryExpression[_T]): """ inherit_cache = True + _is_collection_aggregate = True @classmethod def _create_any( @@ -3845,7 +3853,7 @@ class CollectionAggregate(UnaryExpression[_T]): raise exc.ArgumentError( "Only comparison operators may be used with ANY/ALL" ) - kwargs["reverse"] = kwargs["_any_all_expr"] = True + kwargs["reverse"] = True return self.comparator.operate(operators.mirror(op), *other, **kwargs) def reverse_operate(self, op, other, **kwargs): @@ -4033,7 +4041,7 @@ class BinaryExpression(OperatorExpression[_T]): modifiers=self.modifiers, ) else: - return super()._negate() + return self.self_group()._negate() class Slice(ColumnElement[Any]): diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py index d91f760706..53fad3ea21 100644 --- a/lib/sqlalchemy/sql/operators.py +++ b/lib/sqlalchemy/sql/operators.py @@ -1819,10 +1819,10 @@ class ColumnOperators(Operators): See the documentation for :func:`_sql.any_` for examples. .. note:: be sure to not confuse the newer - :meth:`_sql.ColumnOperators.any_` method with its older - :class:`_types.ARRAY`-specific counterpart, the - :meth:`_types.ARRAY.Comparator.any` method, which a different - calling syntax and usage pattern. + :meth:`_sql.ColumnOperators.any_` method with the **legacy** + version of this method, the :meth:`_types.ARRAY.Comparator.any` + method that's specific to :class:`_types.ARRAY`, which uses a + different calling style. """ return self.operate(any_op) @@ -1834,10 +1834,10 @@ class ColumnOperators(Operators): See the documentation for :func:`_sql.all_` for examples. .. note:: be sure to not confuse the newer - :meth:`_sql.ColumnOperators.all_` method with its older - :class:`_types.ARRAY`-specific counterpart, the - :meth:`_types.ARRAY.Comparator.all` method, which a different - calling syntax and usage pattern. + :meth:`_sql.ColumnOperators.all_` method with the **legacy** + version of this method, the :meth:`_types.ARRAY.Comparator.all` + method that's specific to :class:`_types.ARRAY`, which uses a + different calling style. """ return self.operate(all_op) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 91e382de69..0963e8ed20 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2924,7 +2924,7 @@ class ARRAY( def any(self, other, operator=None): """Return ``other operator ANY (array)`` clause. - .. note:: This method is an :class:`_types.ARRAY` - specific + .. legacy:: This method is an :class:`_types.ARRAY` - specific construct that is now superseded by the :func:`_sql.any_` function, which features a different calling style. The :func:`_sql.any_` function is also mirrored at the method level @@ -2958,9 +2958,8 @@ class ARRAY( arr_type = self.type - # send plain BinaryExpression so that negate remains at None, - # leading to NOT expr for negation. - return elements.BinaryExpression( + return elements.CollectionAggregate._create_any(self.expr).operate( + operators.mirror(operator), coercions.expect( roles.BinaryElementRole, element=other, @@ -2968,19 +2967,17 @@ class ARRAY( expr=self.expr, bindparam_type=arr_type.item_type, ), - elements.CollectionAggregate._create_any(self.expr), - operator, ) @util.preload_module("sqlalchemy.sql.elements") def all(self, other, operator=None): """Return ``other operator ALL (array)`` clause. - .. note:: This method is an :class:`_types.ARRAY` - specific - construct that is now superseded by the :func:`_sql.any_` + .. legacy:: This method is an :class:`_types.ARRAY` - specific + construct that is now superseded by the :func:`_sql.all_` function, which features a different calling style. The - :func:`_sql.any_` function is also mirrored at the method level - via the :meth:`_sql.ColumnOperators.any_` method. + :func:`_sql.all_` function is also mirrored at the method level + via the :meth:`_sql.ColumnOperators.all_` method. Usage of array-specific :meth:`_types.ARRAY.Comparator.all` is as follows:: @@ -3010,9 +3007,8 @@ class ARRAY( arr_type = self.type - # send plain BinaryExpression so that negate remains at None, - # leading to NOT expr for negation. - return elements.BinaryExpression( + return elements.CollectionAggregate._create_all(self.expr).operate( + operators.mirror(operator), coercions.expect( roles.BinaryElementRole, element=other, @@ -3020,8 +3016,6 @@ class ARRAY( expr=self.expr, bindparam_type=arr_type.item_type, ), - elements.CollectionAggregate._create_all(self.expr), - operator, ) comparator_factory = Comparator diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index af51010c76..7e61920aa2 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -4540,7 +4540,7 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): ) return t - @testing.combinations( + null_comparisons = testing.combinations( lambda col: any_(col) == None, lambda col: col.any_() == None, lambda col: any_(col) == null(), @@ -4551,12 +4551,23 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): lambda col: None == col.any_(), argnames="expr", ) + + @null_comparisons @testing.combinations("int", "array", argnames="datatype") def test_any_generic_null(self, datatype, expr, t_fixture): col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval self.assert_compile(expr(col), "NULL = ANY (tab1.%s)" % col.name) + @null_comparisons + @testing.combinations("int", "array", argnames="datatype") + def test_any_generic_null_negate(self, datatype, expr, t_fixture): + col = t_fixture.c.data if datatype == "int" else t_fixture.c.arrval + + self.assert_compile( + ~expr(col), "NOT (NULL = ANY (tab1.%s))" % col.name + ) + @testing.fixture( params=[ ("ANY", any_), @@ -4565,48 +4576,78 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): ("ALL", lambda x: x.all_()), ] ) - def operator(self, request): + def any_all_operators(self, request): return request.param + # test legacy array any() / all(). these are superseded by the + # any_() / all_() versions @testing.fixture( params=[ ("ANY", lambda x, *o: x.any(*o)), ("ALL", lambda x, *o: x.all(*o)), ] ) - def array_op(self, request): + def legacy_any_all_operators(self, request): return request.param - def test_array(self, t_fixture, operator): + def test_array(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(t.c.arrval), f":param_1 = {op} (tab1.arrval)", checkparams={"param_1": 5}, ) - def test_comparator_array(self, t_fixture, operator): + def test_comparator_inline_negate(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators + self.assert_compile( + 5 != fn(t.c.arrval), + f":param_1 != {op} (tab1.arrval)", + checkparams={"param_1": 5}, + ) + + @testing.combinations( + (operator.eq, "="), + (operator.ne, "!="), + (operator.gt, ">"), + (operator.le, "<="), + argnames="operator,opstring", + ) + def test_comparator_outer_negate( + self, t_fixture, any_all_operators, operator, opstring + ): + """test #10817""" + t = t_fixture + op, fn = any_all_operators + self.assert_compile( + ~(operator(5, fn(t.c.arrval))), + f"NOT (:param_1 {opstring} {op} (tab1.arrval))", + checkparams={"param_1": 5}, + ) + + def test_comparator_array(self, t_fixture, any_all_operators): + t = t_fixture + op, fn = any_all_operators self.assert_compile( 5 > fn(t.c.arrval), f":param_1 > {op} (tab1.arrval)", checkparams={"param_1": 5}, ) - def test_comparator_array_wexpr(self, t_fixture, operator): + def test_comparator_array_wexpr(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( t.c.data > fn(t.c.arrval), f"tab1.data > {op} (tab1.arrval)", checkparams={}, ) - def test_illegal_ops(self, t_fixture, operator): + def test_illegal_ops(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators assert_raises_message( exc.ArgumentError, @@ -4622,10 +4663,10 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): t.c.data + fn(t.c.arrval), f"tab1.data + {op} (tab1.arrval)" ) - def test_bindparam_coercion(self, t_fixture, array_op): + def test_bindparam_coercion(self, t_fixture, legacy_any_all_operators): """test #7979""" t = t_fixture - op, fn = array_op + op, fn = legacy_any_all_operators expr = fn(t.c.arrval, bindparam("param")) expected = f"%(param)s = {op} (tab1.arrval)" @@ -4633,9 +4674,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): self.assert_compile(expr, expected, dialect="postgresql") - def test_array_comparator_accessor(self, t_fixture, array_op): + def test_array_comparator_accessor( + self, t_fixture, legacy_any_all_operators + ): t = t_fixture - op, fn = array_op + op, fn = legacy_any_all_operators self.assert_compile( fn(t.c.arrval, 5, operator.gt), @@ -4643,9 +4686,11 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"arrval_1": 5}, ) - def test_array_comparator_negate_accessor(self, t_fixture, array_op): + def test_array_comparator_negate_accessor( + self, t_fixture, legacy_any_all_operators + ): t = t_fixture - op, fn = array_op + op, fn = legacy_any_all_operators self.assert_compile( ~fn(t.c.arrval, 5, operator.gt), @@ -4653,9 +4698,9 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"arrval_1": 5}, ) - def test_array_expression(self, t_fixture, operator): + def test_array_expression(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(t.c.arrval[5:6] + postgresql.array([3, 4])), @@ -4671,9 +4716,9 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): dialect="postgresql", ) - def test_subq(self, t_fixture, operator): + def test_subq(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(select(t.c.data).where(t.c.data < 10).scalar_subquery()), @@ -4682,9 +4727,9 @@ class AnyAllTest(fixtures.TestBase, testing.AssertsCompiledSQL): checkparams={"data_1": 10, "param_1": 5}, ) - def test_scalar_values(self, t_fixture, operator): + def test_scalar_values(self, t_fixture, any_all_operators): t = t_fixture - op, fn = operator + op, fn = any_all_operators self.assert_compile( 5 == fn(values(t.c.data).data([(1,), (42,)]).scalar_values()),