]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implement multi-element expression constructs
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 12 Apr 2022 17:52:31 +0000 (13:52 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Wed, 13 Apr 2022 21:19:31 +0000 (17:19 -0400)
Improved the construction of SQL binary expressions to allow for very long
expressions against the same associative operator without special steps
needed in order to avoid high memory use and excess recursion depth. A
particular binary operation ``A op B`` can now be joined against another
element ``op C`` and the resulting structure will be "flattened" so that
the representation as well as SQL compilation does not require recursion.

To implement this more cleanly, the biggest change here is that
column-oriented lists of things are broken away from ClauseList
in a new class ExpressionClauseList, that also forms the basis
of BooleanClauseList. ClauseList is still used for the generic
"comma-separated list" of things such as Tuple and things like
ORDER BY, as well as in some API endpoints.

Also adds __slots__ to the TypeEngine-bound Comparator
classes.   Still can't really do __slots__ on ClauseElement.

Fixes: #7744
Change-Id: I81a8ceb6f8f3bb0fe52d58f3cb42e4b6c2bc9018

15 files changed:
doc/build/changelog/unreleased_20/7744.rst [new file with mode: 0644]
lib/sqlalchemy/dialects/mssql/base.py
lib/sqlalchemy/dialects/mysql/base.py
lib/sqlalchemy/dialects/postgresql/array.py
lib/sqlalchemy/orm/evaluator.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/default_comparator.py
lib/sqlalchemy/sql/elements.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/sqltypes.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/testing/suite/test_types.py
test/dialect/mysql/test_compiler.py
test/sql/test_operators.py

diff --git a/doc/build/changelog/unreleased_20/7744.rst b/doc/build/changelog/unreleased_20/7744.rst
new file mode 100644 (file)
index 0000000..b69ee7a
--- /dev/null
@@ -0,0 +1,18 @@
+.. change::
+    :tags: bug, sql
+    :tickets: 7744
+
+    Improved the construction of SQL binary expressions to allow for very long
+    expressions against the same associative operator without special steps
+    needed in order to avoid high memory use and excess recursion depth. A
+    particular binary operation ``A op B`` can now be joined against another
+    element ``op C`` and the resulting structure will be "flattened" so that
+    the representation as well as SQL compilation does not require recursion.
+
+    One effect of this change is that string concatenation expressions which
+    use SQL functions come out as "flat", e.g. MySQL will now render
+    ``concat('x', 'y', 'z', ...)``` rather than nesting together two-element
+    functions like ``concat(concat('x', 'y'), 'z')``.  Third-party dialects
+    which override the string concatenation operator will need to implement
+    a new method ``def visit_concat_op_expression_clauselist()`` to
+    accompany the existing ``def visit_concat_op_binary()`` method.
index 35428b659a86972c848b48fff560edf2d8690fb4..2bacaaf3338123822dce20a490c7b5da067eae9d 100644 (file)
@@ -1833,6 +1833,11 @@ class MSSQLCompiler(compiler.SQLCompiler):
     def visit_char_length_func(self, fn, **kw):
         return "LEN%s" % self.function_argspec(fn, **kw)
 
+    def visit_concat_op_expression_clauselist(
+        self, clauselist, operator, **kw
+    ):
+        return " + ".join(self.process(elem, **kw) for elem in clauselist)
+
     def visit_concat_op_binary(self, binary, operator, **kw):
         return "%s + %s" % (
             self.process(binary.left, **kw),
index 25f4c6945aa8561f9744a1dc617e03bd75b968f5..b53e55abf2533e73f8d624243cb7b435b7ae4e1f 100644 (file)
@@ -1322,6 +1322,13 @@ class MySQLCompiler(compiler.SQLCompiler):
 
         return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses)
 
+    def visit_concat_op_expression_clauselist(
+        self, clauselist, operator, **kw
+    ):
+        return "concat(%s)" % (
+            ", ".join(self.process(elem, **kw) for elem in clauselist.clauses)
+        )
+
     def visit_concat_op_binary(self, binary, operator, **kw):
         return "concat(%s, %s)" % (
             self.process(binary.left, **kw),
index 74643c4d92120dac604067e4ab0214938a3d3be8..7eec7b86fb7c909901267c7d88f75d695e201f95 100644 (file)
@@ -5,14 +5,19 @@
 # This module is part of SQLAlchemy and is released under
 # the MIT License: https://www.opensource.org/licenses/mit-license.php
 
+from __future__ import annotations
+
 import re
+from typing import Any
+from typing import TypeVar
 
 from ... import types as sqltypes
 from ... import util
-from ...sql import coercions
 from ...sql import expression
 from ...sql import operators
-from ...sql import roles
+
+
+_T = TypeVar("_T", bound=Any)
 
 
 def Any(other, arrexpr, operator=operators.eq):
@@ -33,7 +38,7 @@ def All(other, arrexpr, operator=operators.eq):
     return arrexpr.all(other, operator)
 
 
-class array(expression.ClauseList, expression.ColumnElement):
+class array(expression.ExpressionClauseList[_T]):
 
     """A PostgreSQL ARRAY literal.
 
@@ -90,16 +95,19 @@ class array(expression.ClauseList, expression.ColumnElement):
     inherit_cache = True
 
     def __init__(self, clauses, **kw):
-        clauses = [
-            coercions.expect(roles.ExpressionElementRole, c) for c in clauses
-        ]
-
-        self._type_tuple = [arg.type for arg in clauses]
-        main_type = kw.pop(
-            "type_",
-            self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE,
+
+        type_arg = kw.pop("type_", None)
+        super(array, self).__init__(operators.comma_op, *clauses, **kw)
+
+        self._type_tuple = [arg.type for arg in self.clauses]
+
+        main_type = (
+            type_arg
+            if type_arg is not None
+            else self._type_tuple[0]
+            if self._type_tuple
+            else sqltypes.NULLTYPE
         )
-        super(array, self).__init__(*clauses, **kw)
 
         if isinstance(main_type, ARRAY):
             self.type = ARRAY(
index 453fc8903cb06e6a5cb1baafd51e320d2364f2cd..1b3340dc5d90a4dba8830f94230af6ad507637a7 100644 (file)
@@ -94,6 +94,9 @@ class EvaluatorCompiler:
     def visit_tuple(self, clause):
         return self.visit_clauselist(clause)
 
+    def visit_expression_clauselist(self, clause):
+        return self.visit_clauselist(clause)
+
     def visit_clauselist(self, clause):
         evaluators = [self.process(clause) for clause in clause.clauses]
 
index 93b49ab254895280ead2d258e32e2f5924928f44..7298d3630ef4f75675f123a76723717968c14849 100644 (file)
@@ -884,12 +884,12 @@ def _emit_update_statements(
         clauses = BooleanClauseList._construct_raw(operators.and_)
 
         for col in mapper._pks_by_table[table]:
-            clauses.clauses.append(
+            clauses._append_inplace(
                 col == sql.bindparam(col._label, type_=col.type)
             )
 
         if needs_version_id:
-            clauses.clauses.append(
+            clauses._append_inplace(
                 mapper.version_id_col
                 == sql.bindparam(
                     mapper.version_id_col._label,
@@ -1316,12 +1316,12 @@ def _emit_post_update_statements(
         clauses = BooleanClauseList._construct_raw(operators.and_)
 
         for col in mapper._pks_by_table[table]:
-            clauses.clauses.append(
+            clauses._append_inplace(
                 col == sql.bindparam(col._label, type_=col.type)
             )
 
         if needs_version_id:
-            clauses.clauses.append(
+            clauses._append_inplace(
                 mapper.version_id_col
                 == sql.bindparam(
                     mapper.version_id_col._label,
@@ -1437,12 +1437,12 @@ def _emit_delete_statements(
         clauses = BooleanClauseList._construct_raw(operators.and_)
 
         for col in mapper._pks_by_table[table]:
-            clauses.clauses.append(
+            clauses._append_inplace(
                 col == sql.bindparam(col.key, type_=col.type)
             )
 
         if need_version_id:
-            clauses.clauses.append(
+            clauses._append_inplace(
                 mapper.version_id_col
                 == sql.bindparam(
                     mapper.version_id_col.key, type_=mapper.version_id_col.type
index 522a0bd4a0830b6e6b135897c69a5259bf071ba3..9c074db33e5e975f96b82f815fc66bd3dfdc037b 100644 (file)
@@ -2013,6 +2013,24 @@ class SQLCompiler(Compiled):
 
         return self._generate_delimited_list(clauselist.clauses, sep, **kw)
 
+    def visit_expression_clauselist(self, clauselist, **kw):
+        operator_ = clauselist.operator
+
+        disp = self._get_operator_dispatch(
+            operator_, "expression_clauselist", None
+        )
+        if disp:
+            return disp(clauselist, operator_, **kw)
+
+        try:
+            opstring = OPERATORS[operator_]
+        except KeyError as err:
+            raise exc.UnsupportedCompilationError(self, operator_) from err
+        else:
+            return self._generate_delimited_list(
+                clauselist.clauses, opstring, **kw
+            )
+
     def visit_case(self, clause, **kwargs):
         x = "CASE "
         if clause.value is not None:
index 944a0a5ce6270479ab952be051c04c1efdfb40ce..512fca8d0939c6894b1919b97fc8802ec7a2b5ce 100644 (file)
@@ -27,11 +27,12 @@ from . import type_api
 from .elements import and_
 from .elements import BinaryExpression
 from .elements import ClauseElement
-from .elements import ClauseList
 from .elements import CollationClause
 from .elements import CollectionAggregate
+from .elements import ExpressionClauseList
 from .elements import False_
 from .elements import Null
+from .elements import OperatorExpression
 from .elements import or_
 from .elements import True_
 from .elements import UnaryExpression
@@ -56,11 +57,9 @@ def _boolean_compare(
     reverse: bool = False,
     _python_is_types: Tuple[Type[Any], ...] = (type(None), bool),
     _any_all_expr: bool = False,
-    result_type: Optional[
-        Union[Type[TypeEngine[bool]], TypeEngine[bool]]
-    ] = None,
+    result_type: Optional[TypeEngine[bool]] = None,
     **kwargs: Any,
-) -> BinaryExpression[bool]:
+) -> OperatorExpression[bool]:
     if result_type is None:
         result_type = type_api.BOOLEANTYPE
 
@@ -71,7 +70,7 @@ def _boolean_compare(
         if op in (operators.eq, operators.ne) and isinstance(
             obj, (bool, True_, False_)
         ):
-            return BinaryExpression(
+            return OperatorExpression._construct_for_op(
                 expr,
                 coercions.expect(roles.ConstExprRole, obj),
                 op,
@@ -83,7 +82,7 @@ def _boolean_compare(
             operators.is_distinct_from,
             operators.is_not_distinct_from,
         ):
-            return BinaryExpression(
+            return OperatorExpression._construct_for_op(
                 expr,
                 coercions.expect(roles.ConstExprRole, obj),
                 op,
@@ -98,7 +97,7 @@ def _boolean_compare(
         else:
             # all other None uses IS, IS NOT
             if op in (operators.eq, operators.is_):
-                return BinaryExpression(
+                return OperatorExpression._construct_for_op(
                     expr,
                     coercions.expect(roles.ConstExprRole, obj),
                     operators.is_,
@@ -106,7 +105,7 @@ def _boolean_compare(
                     type_=result_type,
                 )
             elif op in (operators.ne, operators.is_not):
-                return BinaryExpression(
+                return OperatorExpression._construct_for_op(
                     expr,
                     coercions.expect(roles.ConstExprRole, obj),
                     operators.is_not,
@@ -125,7 +124,7 @@ def _boolean_compare(
         )
 
     if reverse:
-        return BinaryExpression(
+        return OperatorExpression._construct_for_op(
             obj,
             expr,
             op,
@@ -134,7 +133,7 @@ def _boolean_compare(
             modifiers=kwargs,
         )
     else:
-        return BinaryExpression(
+        return OperatorExpression._construct_for_op(
             expr,
             obj,
             op,
@@ -169,11 +168,9 @@ def _binary_operate(
     obj: roles.BinaryElementRole[Any],
     *,
     reverse: bool = False,
-    result_type: Optional[
-        Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"]
-    ] = None,
+    result_type: Optional[TypeEngine[_T]] = None,
     **kw: Any,
-) -> BinaryExpression[_T]:
+) -> OperatorExpression[_T]:
 
     coerced_obj = coercions.expect(
         roles.BinaryElementRole, obj, expr=expr, operator=op
@@ -189,7 +186,9 @@ def _binary_operate(
             op, right.comparator
         )
 
-    return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
+    return OperatorExpression._construct_for_op(
+        left, right, op, type_=result_type, modifiers=kw
+    )
 
 
 def _conjunction_operate(
@@ -311,7 +310,9 @@ def _between_impl(
     """See :meth:`.ColumnOperators.between`."""
     return BinaryExpression(
         expr,
-        ClauseList(
+        ExpressionClauseList._construct_for_list(
+            operators.and_,
+            type_api.NULLTYPE,
             coercions.expect(
                 roles.BinaryElementRole,
                 cleft,
@@ -324,9 +325,7 @@ def _between_impl(
                 expr=expr,
                 operator=operators.and_,
             ),
-            operator=operators.and_,
             group=False,
-            group_contents=False,
         ),
         op,
         negate=operators.not_between_op
index 805758283538e8a40d9212f6680ab6666e9ce185..d47d138f7cc3af2c77e539ceeb3c5b3ac80186a6 100644 (file)
@@ -1323,7 +1323,11 @@ class ColumnElement(
         if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
             return AsBoolean(self, operators.is_false, operators.is_true)
         else:
-            return cast("UnaryExpression[_T]", super()._negate())
+            grouped = self.self_group(against=operators.inv)
+            assert isinstance(grouped, ColumnElement)
+            return UnaryExpression(
+                grouped, operator=operators.inv, wraps_column_expression=True
+            )
 
     type: TypeEngine[_T]
 
@@ -2501,6 +2505,8 @@ class ClauseList(
 
     __visit_name__ = "clauselist"
 
+    # this is used only by the ORM in a legacy use case for
+    # composite attributes
     _is_clause_list = True
 
     _traverse_internals: _TraverseInternalsType = [
@@ -2516,18 +2522,14 @@ class ClauseList(
         operator: OperatorType = operators.comma_op,
         group: bool = True,
         group_contents: bool = True,
-        _flatten_sub_clauses: bool = False,
         _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole,
     ):
         self.operator = operator
         self.group = group
         self.group_contents = group_contents
         clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses
-        if _flatten_sub_clauses:
-            clauses_iterator = util.flatten_iterator(clauses_iterator)
-
-        self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
         text_converter_role: Type[roles.SQLRole] = _literal_as_text_role
+        self._text_converter_role = text_converter_role
 
         if self.group_contents:
             self.clauses = [
@@ -2594,8 +2596,176 @@ class ClauseList(
             return self
 
 
-class BooleanClauseList(ClauseList, ColumnElement[bool]):
-    __visit_name__ = "clauselist"
+class OperatorExpression(ColumnElement[_T]):
+    """base for expressions that contain an operator and operands
+
+    .. versionadded:: 2.0
+
+    """
+
+    operator: OperatorType
+    type: TypeEngine[_T]
+
+    group: bool = True
+
+    @property
+    def is_comparison(self):
+        return operators.is_comparison(self.operator)
+
+    def self_group(self, against=None):
+        if (
+            self.group
+            and operators.is_precedent(self.operator, against)
+            or (
+                # a negate against a non-boolean operator
+                # doesn't make too much sense but we should
+                # group for that
+                against is operators.inv
+                and not operators.is_boolean(self.operator)
+            )
+        ):
+            return Grouping(self)
+        else:
+            return self
+
+    @property
+    def _flattened_operator_clauses(
+        self,
+    ) -> typing_Tuple[ColumnElement[Any], ...]:
+        raise NotImplementedError()
+
+    @classmethod
+    def _construct_for_op(
+        cls,
+        left: ColumnElement[Any],
+        right: ColumnElement[Any],
+        op: OperatorType,
+        *,
+        type_: TypeEngine[_T],
+        negate: Optional[OperatorType] = None,
+        modifiers: Optional[Mapping[str, Any]] = None,
+    ) -> OperatorExpression[_T]:
+
+        if operators.is_associative(op):
+            assert (
+                negate is None
+            ), f"negate not supported for associative operator {op}"
+
+            multi = False
+            if getattr(
+                left, "operator", None
+            ) is op and type_._compare_type_affinity(left.type):
+                multi = True
+                left_flattened = left._flattened_operator_clauses
+            else:
+                left_flattened = (left,)
+
+            if getattr(
+                right, "operator", None
+            ) is op and type_._compare_type_affinity(right.type):
+                multi = True
+                right_flattened = right._flattened_operator_clauses
+            else:
+                right_flattened = (right,)
+
+            if multi:
+                return ExpressionClauseList._construct_for_list(
+                    op, type_, *(left_flattened + right_flattened)
+                )
+
+        return BinaryExpression(
+            left, right, op, type_=type_, negate=negate, modifiers=modifiers
+        )
+
+
+class ExpressionClauseList(OperatorExpression[_T]):
+    """Describe a list of clauses, separated by an operator,
+    in a column expression context.
+
+    :class:`.ExpressionClauseList` differs from :class:`.ClauseList` in that
+    it represents a column-oriented DQL expression only, not an open ended
+    list of anything comma separated.
+
+    .. versionadded:: 2.0
+
+    """
+
+    __visit_name__ = "expression_clauselist"
+
+    _traverse_internals: _TraverseInternalsType = [
+        ("clauses", InternalTraversal.dp_clauseelement_tuple),
+        ("operator", InternalTraversal.dp_operator),
+    ]
+
+    clauses: typing_Tuple[ColumnElement[Any], ...]
+
+    group: bool
+
+    def __init__(
+        self,
+        operator: OperatorType,
+        *clauses: _ColumnExpressionArgument[Any],
+        type_: Optional[_TypeEngineArgument[_T]] = None,
+    ):
+        self.operator = operator
+
+        self.clauses = tuple(
+            coercions.expect(
+                roles.ExpressionElementRole, clause, apply_propagate_attrs=self
+            )
+            for clause in clauses
+        )
+        self._is_implicitly_boolean = operators.is_boolean(self.operator)
+        self.type = type_api.to_instance(type_)  # type: ignore
+
+    @property
+    def _flattened_operator_clauses(
+        self,
+    ) -> typing_Tuple[ColumnElement[Any], ...]:
+        return self.clauses
+
+    def __iter__(self) -> Iterator[ColumnElement[Any]]:
+        return iter(self.clauses)
+
+    def __len__(self) -> int:
+        return len(self.clauses)
+
+    @property
+    def _select_iterable(self) -> _SelectIterable:
+        return (self,)
+
+    @util.ro_non_memoized_property
+    def _from_objects(self) -> List[FromClause]:
+        return list(itertools.chain(*[c._from_objects for c in self.clauses]))
+
+    def _append_inplace(self, clause: ColumnElement[Any]) -> None:
+        self.clauses += (clause,)
+
+    @classmethod
+    def _construct_for_list(
+        cls,
+        operator: OperatorType,
+        type_: TypeEngine[_T],
+        *clauses: ColumnElement[Any],
+        group: bool = True,
+    ) -> ExpressionClauseList[_T]:
+        self = cls.__new__(cls)
+        self.group = group
+        self.clauses = clauses
+        self.operator = operator
+        self.type = type_
+        return self
+
+    def _negate(self) -> Any:
+        grouped = self.self_group(against=operators.inv)
+        assert isinstance(grouped, ColumnElement)
+        return UnaryExpression(
+            grouped, operator=operators.inv, wraps_column_expression=True
+        )
+
+
+class BooleanClauseList(ExpressionClauseList[bool]):
+    __visit_name__ = "expression_clauselist"
     inherit_cache = True
 
     def __init__(self, *arg, **kw):
@@ -2668,7 +2838,15 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
         if lcc > 1:
             # multiple elements.  Return regular BooleanClauseList
             # which will link elements against the operator.
-            return cls._construct_raw(operator, convert_clauses)  # type: ignore # noqa: E501
+
+            flattened_clauses = itertools.chain.from_iterable(
+                (c for c in to_flat._flattened_operator_clauses)
+                if getattr(to_flat, "operator", None) is operator
+                else (to_flat,)
+                for to_flat in convert_clauses
+            )
+
+            return cls._construct_raw(operator, flattened_clauses)  # type: ignore # noqa: E501
         elif lcc == 1:
             # just one element.  return it as a single boolean element,
             # not a list and discard the operator.
@@ -2726,10 +2904,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
         clauses: Optional[Sequence[ColumnElement[Any]]] = None,
     ) -> BooleanClauseList:
         self = cls.__new__(cls)
-        self.clauses = list(clauses) if clauses else []
+        self.clauses = tuple(clauses) if clauses else ()
         self.group = True
         self.operator = operator
-        self.group_contents = True
         self.type = type_api.BOOLEANTYPE
         self._is_implicitly_boolean = True
         return self
@@ -2768,9 +2945,6 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]):
         else:
             return super(BooleanClauseList, self).self_group(against=against)
 
-    def _negate(self):
-        return ClauseList._negate(self)
-
 
 and_ = BooleanClauseList.and_
 or_ = BooleanClauseList.or_
@@ -3357,7 +3531,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]):
             return AsBoolean(self.element, self.negate, self.operator)
 
 
-class BinaryExpression(ColumnElement[_T]):
+class BinaryExpression(OperatorExpression[_T]):
     """Represent an expression that is ``LEFT <operator> RIGHT``.
 
     A :class:`.BinaryExpression` is generated automatically
@@ -3394,12 +3568,12 @@ class BinaryExpression(ColumnElement[_T]):
     modifiers: Optional[Mapping[str, Any]]
 
     left: ColumnElement[Any]
-    right: Union[ColumnElement[Any], ClauseList]
+    right: ColumnElement[Any]
 
     def __init__(
         self,
         left: ColumnElement[Any],
-        right: Union[ColumnElement[Any], ClauseList],
+        right: ColumnElement[Any],
         operator: OperatorType,
         type_: Optional[_TypeEngineArgument[_T]] = None,
         negate: Optional[OperatorType] = None,
@@ -3427,6 +3601,12 @@ class BinaryExpression(ColumnElement[_T]):
         else:
             self.modifiers = modifiers
 
+    @property
+    def _flattened_operator_clauses(
+        self,
+    ) -> typing_Tuple[ColumnElement[Any], ...]:
+        return (self.left, self.right)
+
     def __bool__(self):
         """Implement Python-side "bool" for BinaryExpression as a
         simple "identity" check for the left and right attributes,
@@ -3465,8 +3645,6 @@ class BinaryExpression(ColumnElement[_T]):
         else:
             raise TypeError("Boolean value of this clause is not defined")
 
-    __nonzero__ = __bool__
-
     if typing.TYPE_CHECKING:
 
         def __invert__(
@@ -3474,21 +3652,10 @@ class BinaryExpression(ColumnElement[_T]):
         ) -> "BinaryExpression[_T]":
             ...
 
-    @property
-    def is_comparison(self):
-        return operators.is_comparison(self.operator)
-
     @util.ro_non_memoized_property
     def _from_objects(self) -> List[FromClause]:
         return self.left._from_objects + self.right._from_objects
 
-    def self_group(self, against=None):
-
-        if operators.is_precedent(self.operator, against):
-            return Grouping(self)
-        else:
-            return self
-
     def _negate(self):
         if self.negate is not None:
             return BinaryExpression(
index 455e74f7b0e3ea13773dd8d57a0b6e9200364d67..d08bbf4eb3fe2585e498a0229feaa94e1ffc44de 100644 (file)
@@ -81,6 +81,7 @@ from .elements import ClauseList as ClauseList
 from .elements import CollectionAggregate as CollectionAggregate
 from .elements import ColumnClause as ColumnClause
 from .elements import ColumnElement as ColumnElement
+from .elements import ExpressionClauseList as ExpressionClauseList
 from .elements import Extract as Extract
 from .elements import False_ as False_
 from .elements import FunctionFilter as FunctionFilter
index 803e85654894cdc3f0810c74fc26e9f5482fbfda..8d98f893fbe3a4f90663ea31de4b36c12031e984 100644 (file)
@@ -84,6 +84,7 @@ class HasExpressionLookup(TypeEngineMixin):
         raise NotImplementedError()
 
     class Comparator(TypeEngine.Comparator[_CT]):
+        __slots__ = ()
 
         _blank_dict = util.EMPTY_DICT
 
@@ -114,6 +115,8 @@ class Concatenable(TypeEngineMixin):
     typically strings."""
 
     class Comparator(TypeEngine.Comparator[_T]):
+        __slots__ = ()
+
         def _adapt_expression(
             self,
             op: OperatorType,
@@ -143,6 +146,8 @@ class Indexable(TypeEngineMixin):
     """
 
     class Comparator(TypeEngine.Comparator[_T]):
+        __slots__ = ()
+
         def _setup_getitem(self, index):
             raise NotImplementedError()
 
@@ -174,12 +179,9 @@ class String(Concatenable, TypeEngine[str]):
     __visit_name__ = "string"
 
     def __init__(
-        # note pylance appears to require the "self" type in a constructor
-        # for the _T type to be correctly recognized when we send the
-        # class as the argument, e.g. `column("somecol", String)`
         self,
-        length=None,
-        collation=None,
+        length: Optional[int] = None,
+        collation: Optional[str] = None,
     ):
         """
         Create a string-holding type.
@@ -1508,6 +1510,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]):
                 ) from err
 
     class Comparator(String.Comparator[str]):
+        __slots__ = ()
+
         type: String
 
         def _adapt_expression(
@@ -1963,7 +1967,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]):
         TypeDecorator.Comparator[_CT],
         _AbstractInterval.Comparator[_CT],
     ):
-        pass
+        __slots__ = ()
 
     comparator_factory = Comparator
 
@@ -2385,6 +2389,8 @@ class JSON(Indexable, TypeEngine[Any]):
     class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]):
         """Define comparison operations for :class:`_types.JSON`."""
 
+        __slots__ = ()
+
         def _setup_getitem(self, index):
             if not isinstance(index, str) and isinstance(
                 index, collections_abc.Sequence
@@ -2710,6 +2716,8 @@ class ARRAY(
 
         """
 
+        __slots__ = ()
+
         def _setup_getitem(self, index):
 
             arr_type = cast(ARRAY, self.type)
@@ -3221,6 +3229,8 @@ class NullType(TypeEngine[None]):
         return process
 
     class Comparator(TypeEngine.Comparator[_T]):
+        __slots__ = ()
+
         def _adapt_expression(
             self,
             op: OperatorType,
index cbc4e9e707b202f4f7a92968c9a11fa4c4030354..c23cd04dd4b2e1dfeef9532c384ea3e8e7f29fa9 100644 (file)
@@ -924,7 +924,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
                 ):
                     return COMPARE_FAILED
 
-    def compare_clauselist(self, left, right, **kw):
+    def compare_expression_clauselist(self, left, right, **kw):
         if left.operator is right.operator:
             if operators.is_associative(left.operator):
                 if self._compare_unordered_sequences(
@@ -938,6 +938,9 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
         else:
             return COMPARE_FAILED
 
+    def compare_clauselist(self, left, right, **kw):
+        return self.compare_expression_clauselist(left, right, **kw)
+
     def compare_binary(self, left, right, **kw):
         if left.operator == right.operator:
             if operators.is_commutative(left.operator):
index cc14dd9c4ffc8bf30995f9ec2a0802f2136dfc0d..25fe844c38f3e947c19bebfcc5be0dae18c72b5d 100644 (file)
@@ -320,6 +320,31 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase):
         data = r"backslash one \ backslash two \\ end"
         literal_round_trip(String(40), [data], [data])
 
+    def test_concatenate_binary(self, connection):
+        """dialects with special string concatenation operators should
+        implement visit_concat_op_binary() and visit_concat_op_clauselist()
+        in their compiler.
+
+        .. versionchanged:: 2.0  visit_concat_op_clauselist() is also needed
+           for dialects to override the string concatenation operator.
+
+        """
+        eq_(connection.scalar(select(literal("a") + "b")), "ab")
+
+    def test_concatenate_clauselist(self, connection):
+        """dialects with special string concatenation operators should
+        implement visit_concat_op_binary() and visit_concat_op_clauselist()
+        in their compiler.
+
+        .. versionchanged:: 2.0  visit_concat_op_clauselist() is also needed
+           for dialects to override the string concatenation operator.
+
+        """
+        eq_(
+            connection.scalar(select(literal("a") + "b" + "c" + "d" + "e")),
+            "abcde",
+        )
+
 
 class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase):
     compare = None
index fa8f37c1f83ea6a8f4db7f434b613113c14585a6..3fb52416ec9a051b9a8d3fdc73b5266418792434 100644 (file)
@@ -1121,7 +1121,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL):
         expected_sql = (
             "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON "
             "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), "
-            "baz = (concat(concat(VALUES(baz), %s), VALUES(bar)))"
+            "baz = (concat(VALUES(baz), %s, VALUES(bar)))"
         )
         self.assert_compile(
             stmt,
index e5176713aba5e79ccbe49483546b16220e7ceb6b..88d1ea0530f060eba4069344348fe483539aac96 100644 (file)
@@ -165,6 +165,21 @@ class DefaultColumnComparatorTest(
         loop = LoopOperate()
         is_(operator(loop, *arg), operator)
 
+    @testing.combinations(
+        operators.add,
+        operators.and_,
+        operators.or_,
+        operators.mul,
+        argnames="op",
+    )
+    def test_nonsensical_negations(self, op):
+
+        opstring = compiler.OPERATORS[op]
+        self.assert_compile(
+            select(~op(column("x"), column("q"))),
+            f"SELECT NOT (x{opstring}q) AS anon_1",
+        )
+
     def test_null_true_false_is_sanity_checks(self):
 
         d = default.DefaultDialect()
@@ -328,6 +343,176 @@ class DefaultColumnComparatorTest(
             )
 
 
+class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL):
+    __dialect__ = "default"
+
+    @testing.combinations(True, False, argnames="reverse")
+    @testing.combinations(True, False, argnames="negate")
+    def test_associatives_mismatched_type(self, reverse, negate):
+        """test we get two separate exprs if the types dont match, operator
+        is not lost.
+
+        the expressions here don't generally make sense from a SQL
+        perspective, we are checking just that the operators / parenthesis /
+        negation works out in the SQL string to reasonably correspond to
+        what the Python structures look like.
+
+        """
+
+        expr1 = column("i1", Integer) + column("i2", Integer)
+
+        expr2 = column("d1", String) + column("d2", String)
+
+        if reverse:
+            expr = expr2 + expr1
+
+            self.assert_compile(
+                select(expr), "SELECT (d1 || d2) + i1 + i2 AS anon_1"
+            )
+        else:
+            expr = expr1 + expr2
+
+            self.assert_compile(
+                select(expr), "SELECT i1 + i2 + d1 || d2 AS anon_1"
+            )
+
+    @testing.combinations(
+        operators.add,
+        operators.and_,
+        operators.or_,
+        operators.mul,
+        argnames="op",
+    )
+    @testing.combinations(True, False, argnames="reverse")
+    @testing.combinations(True, False, argnames="negate")
+    def test_associatives(self, op, reverse, negate):
+        t1 = table("t", column("q"), column("p"))
+
+        num = 500
+
+        expr = op(t1.c.q, t1.c.p)
+
+        if reverse:
+            for i in range(num - 1, -1, -1):
+                expr = op(column(f"d{i}"), expr)
+        else:
+            for i in range(num):
+                expr = op(expr, column(f"d{i}"))
+
+        opstring = compiler.OPERATORS[op]
+        exprs = opstring.join(f"d{i}" for i in range(num))
+
+        if negate:
+            self.assert_compile(
+                select(~expr),
+                f"SELECT NOT (t.q{opstring}t.p{opstring}{exprs}) "
+                "AS anon_1 FROM t"
+                if not reverse
+                else f"SELECT NOT ({exprs}{opstring}t.q{opstring}t.p) "
+                "AS anon_1 FROM t",
+            )
+        else:
+            self.assert_compile(
+                select(expr),
+                f"SELECT t.q{opstring}t.p{opstring}{exprs} AS anon_1 FROM t"
+                if not reverse
+                else f"SELECT {exprs}{opstring}t.q{opstring}t.p "
+                f"AS anon_1 FROM t",
+            )
+
+    @testing.combinations(
+        operators.gt,
+        operators.eq,
+        operators.le,
+        operators.sub,
+        argnames="op",
+    )
+    @testing.combinations(True, False, argnames="reverse")
+    @testing.combinations(True, False, argnames="negate")
+    def test_non_associatives(self, op, reverse, negate):
+        """similar tests as test_associatives but for non-assoc
+        operators.
+
+        the expressions here don't generally make sense from a SQL
+        perspective, we are checking just that the operators / parenthesis /
+        negation works out in the SQL string to reasonably correspond to
+        what the Python structures look like.
+
+        """
+        t1 = table("t", column("q"), column("p"))
+
+        num = 5
+
+        expr = op(t1.c.q, t1.c.p)
+
+        if reverse:
+            for i in range(num - 1, -1, -1):
+                expr = op(column(f"d{i}"), expr)
+        else:
+            for i in range(num):
+                expr = op(expr, column(f"d{i}"))
+
+        opstring = compiler.OPERATORS[op]
+        exprs = opstring.join(f"d{i}" for i in range(num))
+
+        if negate:
+            negate_op = {
+                operators.gt: operators.le,
+                operators.eq: operators.ne,
+                operators.le: operators.gt,
+            }.get(op, None)
+
+            if negate_op:
+                negate_opstring = compiler.OPERATORS[negate_op]
+                if reverse:
+                    str_expr = (
+                        f"d0{negate_opstring}(d1{opstring}(d2{opstring}"
+                        f"(d3{opstring}(d4{opstring}(t.q{opstring}t.p)))))"
+                    )
+                else:
+                    str_expr = (
+                        f"(((((t.q{opstring}t.p){opstring}d0){opstring}d1)"
+                        f"{opstring}d2){opstring}d3){negate_opstring}d4"
+                    )
+            else:
+                if reverse:
+                    str_expr = (
+                        f"NOT (d0{opstring}(d1{opstring}(d2{opstring}"
+                        f"(d3{opstring}(d4{opstring}(t.q{opstring}t.p))))))"
+                    )
+                else:
+                    str_expr = (
+                        f"NOT ((((((t.q{opstring}t.p){opstring}d0)"
+                        f"{opstring}d1){opstring}d2){opstring}d3){opstring}d4)"
+                    )
+
+            self.assert_compile(
+                select(~expr),
+                f"SELECT {str_expr} AS anon_1 FROM t"
+                if not reverse
+                else f"SELECT {str_expr} AS anon_1 FROM t",
+            )
+        else:
+
+            if reverse:
+                str_expr = (
+                    f"d0{opstring}(d1{opstring}(d2{opstring}"
+                    f"(d3{opstring}(d4{opstring}(t.q{opstring}t.p)))))"
+                )
+            else:
+                str_expr = (
+                    f"(((((t.q{opstring}t.p){opstring}d0)"
+                    f"{opstring}d1){opstring}d2){opstring}d3){opstring}d4"
+                )
+
+            self.assert_compile(
+                select(expr),
+                f"SELECT {str_expr} AS anon_1 FROM t"
+                if not reverse
+                else f"SELECT {str_expr} AS anon_1 FROM t",
+            )
+
+
 class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     __dialect__ = "default"
 
@@ -2954,7 +3139,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     def test_contains_concat(self):
         self.assert_compile(
             column("x").contains("y"),
-            "x LIKE concat(concat('%%', %s), '%%')",
+            "x LIKE concat('%%', %s, '%%')",
             checkparams={"x_1": "y"},
             dialect=mysql.dialect(),
         )
@@ -2962,7 +3147,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     def test_not_contains_concat(self):
         self.assert_compile(
             ~column("x").contains("y"),
-            "x NOT LIKE concat(concat('%%', %s), '%%')",
+            "x NOT LIKE concat('%%', %s, '%%')",
             checkparams={"x_1": "y"},
             dialect=mysql.dialect(),
         )
@@ -2970,7 +3155,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     def test_contains_literal_concat(self):
         self.assert_compile(
             column("x").contains(literal_column("y")),
-            "x LIKE concat(concat('%%', y), '%%')",
+            "x LIKE concat('%%', y, '%%')",
             checkparams={},
             dialect=mysql.dialect(),
         )
@@ -2978,7 +3163,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL):
     def test_contains_text_concat(self):
         self.assert_compile(
             column("x").contains(text("y")),
-            "x LIKE concat(concat('%%', y), '%%')",
+            "x LIKE concat('%%', y, '%%')",
             checkparams={},
             dialect=mysql.dialect(),
         )