From: Mike Bayer Date: Tue, 12 Apr 2022 17:52:31 +0000 (-0400) Subject: implement multi-element expression constructs X-Git-Tag: rel_2_0_0b1~351^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=428262a2d5374613f4a4cf925bbd9e94e0e34acc;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git implement multi-element expression constructs Improved the construction of SQL binary expressions to allow for very long expressions against the same associative operator without special steps needed in order to avoid high memory use and excess recursion depth. A particular binary operation ``A op B`` can now be joined against another element ``op C`` and the resulting structure will be "flattened" so that the representation as well as SQL compilation does not require recursion. To implement this more cleanly, the biggest change here is that column-oriented lists of things are broken away from ClauseList in a new class ExpressionClauseList, that also forms the basis of BooleanClauseList. ClauseList is still used for the generic "comma-separated list" of things such as Tuple and things like ORDER BY, as well as in some API endpoints. Also adds __slots__ to the TypeEngine-bound Comparator classes. Still can't really do __slots__ on ClauseElement. Fixes: #7744 Change-Id: I81a8ceb6f8f3bb0fe52d58f3cb42e4b6c2bc9018 --- diff --git a/doc/build/changelog/unreleased_20/7744.rst b/doc/build/changelog/unreleased_20/7744.rst new file mode 100644 index 0000000000..b69ee7aff8 --- /dev/null +++ b/doc/build/changelog/unreleased_20/7744.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: bug, sql + :tickets: 7744 + + Improved the construction of SQL binary expressions to allow for very long + expressions against the same associative operator without special steps + needed in order to avoid high memory use and excess recursion depth. 
A + particular binary operation ``A op B`` can now be joined against another + element ``op C`` and the resulting structure will be "flattened" so that + the representation as well as SQL compilation does not require recursion. + + One effect of this change is that string concatenation expressions which + use SQL functions come out as "flat", e.g. MySQL will now render + ``concat('x', 'y', 'z', ...)`` rather than nesting together two-element + functions like ``concat(concat('x', 'y'), 'z')``. Third-party dialects + which override the string concatenation operator will need to implement + a new method ``def visit_concat_op_expression_clauselist()`` to + accompany the existing ``def visit_concat_op_binary()`` method. diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 35428b659a..2bacaaf333 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -1833,6 +1833,11 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_char_length_func(self, fn, **kw): return "LEN%s" % self.function_argspec(fn, **kw) + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return " + ".join(self.process(elem, **kw) for elem in clauselist) + def visit_concat_op_binary(self, binary, operator, **kw): return "%s + %s" % ( self.process(binary.left, **kw), diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 25f4c6945a..b53e55abf2 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1322,6 +1322,13 @@ class MySQLCompiler(compiler.SQLCompiler): return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) + def visit_concat_op_expression_clauselist( + self, clauselist, operator, **kw + ): + return "concat(%s)" % ( + ", ".join(self.process(elem, **kw) for elem in clauselist.clauses) + ) + def visit_concat_op_binary(self, binary, operator, **kw): return "concat(%s, %s)" % ( 
self.process(binary.left, **kw), diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 74643c4d92..7eec7b86fb 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -5,14 +5,19 @@ # This module is part of SQLAlchemy and is released under # the MIT License: https://www.opensource.org/licenses/mit-license.php +from __future__ import annotations + import re +from typing import Any +from typing import TypeVar from ... import types as sqltypes from ... import util -from ...sql import coercions from ...sql import expression from ...sql import operators -from ...sql import roles + + +_T = TypeVar("_T", bound=Any) def Any(other, arrexpr, operator=operators.eq): @@ -33,7 +38,7 @@ def All(other, arrexpr, operator=operators.eq): return arrexpr.all(other, operator) -class array(expression.ClauseList, expression.ColumnElement): +class array(expression.ExpressionClauseList[_T]): """A PostgreSQL ARRAY literal. 
@@ -90,16 +95,19 @@ class array(expression.ClauseList, expression.ColumnElement): inherit_cache = True def __init__(self, clauses, **kw): - clauses = [ - coercions.expect(roles.ExpressionElementRole, c) for c in clauses - ] - - self._type_tuple = [arg.type for arg in clauses] - main_type = kw.pop( - "type_", - self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE, + + type_arg = kw.pop("type_", None) + super(array, self).__init__(operators.comma_op, *clauses, **kw) + + self._type_tuple = [arg.type for arg in self.clauses] + + main_type = ( + type_arg + if type_arg is not None + else self._type_tuple[0] + if self._type_tuple + else sqltypes.NULLTYPE ) - super(array, self).__init__(*clauses, **kw) if isinstance(main_type, ARRAY): self.type = ARRAY( diff --git a/lib/sqlalchemy/orm/evaluator.py b/lib/sqlalchemy/orm/evaluator.py index 453fc8903c..1b3340dc5d 100644 --- a/lib/sqlalchemy/orm/evaluator.py +++ b/lib/sqlalchemy/orm/evaluator.py @@ -94,6 +94,9 @@ class EvaluatorCompiler: def visit_tuple(self, clause): return self.visit_clauselist(clause) + def visit_expression_clauselist(self, clause): + return self.visit_clauselist(clause) + def visit_clauselist(self, clause): evaluators = [self.process(clause) for clause in clause.clauses] diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 93b49ab254..7298d3630e 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -884,12 +884,12 @@ def _emit_update_statements( clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: - clauses.clauses.append( + clauses._append_inplace( col == sql.bindparam(col._label, type_=col.type) ) if needs_version_id: - clauses.clauses.append( + clauses._append_inplace( mapper.version_id_col == sql.bindparam( mapper.version_id_col._label, @@ -1316,12 +1316,12 @@ def _emit_post_update_statements( clauses = BooleanClauseList._construct_raw(operators.and_) for col in 
mapper._pks_by_table[table]: - clauses.clauses.append( + clauses._append_inplace( col == sql.bindparam(col._label, type_=col.type) ) if needs_version_id: - clauses.clauses.append( + clauses._append_inplace( mapper.version_id_col == sql.bindparam( mapper.version_id_col._label, @@ -1437,12 +1437,12 @@ def _emit_delete_statements( clauses = BooleanClauseList._construct_raw(operators.and_) for col in mapper._pks_by_table[table]: - clauses.clauses.append( + clauses._append_inplace( col == sql.bindparam(col.key, type_=col.type) ) if need_version_id: - clauses.clauses.append( + clauses._append_inplace( mapper.version_id_col == sql.bindparam( mapper.version_id_col.key, type_=mapper.version_id_col.type diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 522a0bd4a0..9c074db33e 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2013,6 +2013,24 @@ class SQLCompiler(Compiled): return self._generate_delimited_list(clauselist.clauses, sep, **kw) + def visit_expression_clauselist(self, clauselist, **kw): + operator_ = clauselist.operator + + disp = self._get_operator_dispatch( + operator_, "expression_clauselist", None + ) + if disp: + return disp(clauselist, operator_, **kw) + + try: + opstring = OPERATORS[operator_] + except KeyError as err: + raise exc.UnsupportedCompilationError(self, operator_) from err + else: + return self._generate_delimited_list( + clauselist.clauses, opstring, **kw + ) + def visit_case(self, clause, **kwargs): x = "CASE " if clause.value is not None: diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py index 944a0a5ce6..512fca8d09 100644 --- a/lib/sqlalchemy/sql/default_comparator.py +++ b/lib/sqlalchemy/sql/default_comparator.py @@ -27,11 +27,12 @@ from . 
import type_api from .elements import and_ from .elements import BinaryExpression from .elements import ClauseElement -from .elements import ClauseList from .elements import CollationClause from .elements import CollectionAggregate +from .elements import ExpressionClauseList from .elements import False_ from .elements import Null +from .elements import OperatorExpression from .elements import or_ from .elements import True_ from .elements import UnaryExpression @@ -56,11 +57,9 @@ def _boolean_compare( reverse: bool = False, _python_is_types: Tuple[Type[Any], ...] = (type(None), bool), _any_all_expr: bool = False, - result_type: Optional[ - Union[Type[TypeEngine[bool]], TypeEngine[bool]] - ] = None, + result_type: Optional[TypeEngine[bool]] = None, **kwargs: Any, -) -> BinaryExpression[bool]: +) -> OperatorExpression[bool]: if result_type is None: result_type = type_api.BOOLEANTYPE @@ -71,7 +70,7 @@ def _boolean_compare( if op in (operators.eq, operators.ne) and isinstance( obj, (bool, True_, False_) ): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), op, @@ -83,7 +82,7 @@ def _boolean_compare( operators.is_distinct_from, operators.is_not_distinct_from, ): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), op, @@ -98,7 +97,7 @@ def _boolean_compare( else: # all other None uses IS, IS NOT if op in (operators.eq, operators.is_): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), operators.is_, @@ -106,7 +105,7 @@ def _boolean_compare( type_=result_type, ) elif op in (operators.ne, operators.is_not): - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, coercions.expect(roles.ConstExprRole, obj), operators.is_not, @@ -125,7 +124,7 @@ def _boolean_compare( ) if reverse: - return BinaryExpression( + return 
OperatorExpression._construct_for_op( obj, expr, op, @@ -134,7 +133,7 @@ def _boolean_compare( modifiers=kwargs, ) else: - return BinaryExpression( + return OperatorExpression._construct_for_op( expr, obj, op, @@ -169,11 +168,9 @@ def _binary_operate( obj: roles.BinaryElementRole[Any], *, reverse: bool = False, - result_type: Optional[ - Union[Type["TypeEngine[_T]"], "TypeEngine[_T]"] - ] = None, + result_type: Optional[TypeEngine[_T]] = None, **kw: Any, -) -> BinaryExpression[_T]: +) -> OperatorExpression[_T]: coerced_obj = coercions.expect( roles.BinaryElementRole, obj, expr=expr, operator=op @@ -189,7 +186,9 @@ def _binary_operate( op, right.comparator ) - return BinaryExpression(left, right, op, type_=result_type, modifiers=kw) + return OperatorExpression._construct_for_op( + left, right, op, type_=result_type, modifiers=kw + ) def _conjunction_operate( @@ -311,7 +310,9 @@ def _between_impl( """See :meth:`.ColumnOperators.between`.""" return BinaryExpression( expr, - ClauseList( + ExpressionClauseList._construct_for_list( + operators.and_, + type_api.NULLTYPE, coercions.expect( roles.BinaryElementRole, cleft, @@ -324,9 +325,7 @@ def _between_impl( expr=expr, operator=operators.and_, ), - operator=operators.and_, group=False, - group_contents=False, ), op, negate=operators.not_between_op diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 8057582835..d47d138f7c 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -1323,7 +1323,11 @@ class ColumnElement( if self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity: return AsBoolean(self, operators.is_false, operators.is_true) else: - return cast("UnaryExpression[_T]", super()._negate()) + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) type: TypeEngine[_T] @@ -2501,6 +2505,8 @@ class ClauseList( __visit_name__ 
= "clauselist" + # this is used only by the ORM in a legacy use case for + # composite attributes _is_clause_list = True _traverse_internals: _TraverseInternalsType = [ @@ -2516,18 +2522,14 @@ class ClauseList( operator: OperatorType = operators.comma_op, group: bool = True, group_contents: bool = True, - _flatten_sub_clauses: bool = False, _literal_as_text_role: Type[roles.SQLRole] = roles.WhereHavingRole, ): self.operator = operator self.group = group self.group_contents = group_contents clauses_iterator: Iterable[_ColumnExpressionArgument[Any]] = clauses - if _flatten_sub_clauses: - clauses_iterator = util.flatten_iterator(clauses_iterator) - - self._text_converter_role: Type[roles.SQLRole] = _literal_as_text_role text_converter_role: Type[roles.SQLRole] = _literal_as_text_role + self._text_converter_role = text_converter_role if self.group_contents: self.clauses = [ @@ -2594,8 +2596,176 @@ class ClauseList( return self -class BooleanClauseList(ClauseList, ColumnElement[bool]): - __visit_name__ = "clauselist" +class OperatorExpression(ColumnElement[_T]): + """base for expressions that contain an operator and operands + + .. 
versionadded:: 2.0 + + """ + + operator: OperatorType + type: TypeEngine[_T] + + group: bool = True + + @property + def is_comparison(self): + return operators.is_comparison(self.operator) + + def self_group(self, against=None): + if ( + self.group + and operators.is_precedent(self.operator, against) + or ( + # a negate against a non-boolean operator + # doesn't make too much sense but we should + # group for that + against is operators.inv + and not operators.is_boolean(self.operator) + ) + ): + return Grouping(self) + else: + return self + + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + raise NotImplementedError() + + @classmethod + def _construct_for_op( + cls, + left: ColumnElement[Any], + right: ColumnElement[Any], + op: OperatorType, + *, + type_: TypeEngine[_T], + negate: Optional[OperatorType] = None, + modifiers: Optional[Mapping[str, Any]] = None, + ) -> OperatorExpression[_T]: + + if operators.is_associative(op): + assert ( + negate is None + ), f"negate not supported for associative operator {op}" + + multi = False + if getattr( + left, "operator", None + ) is op and type_._compare_type_affinity(left.type): + multi = True + left_flattened = left._flattened_operator_clauses + else: + left_flattened = (left,) + + if getattr( + right, "operator", None + ) is op and type_._compare_type_affinity(right.type): + multi = True + right_flattened = right._flattened_operator_clauses + else: + right_flattened = (right,) + + if multi: + return ExpressionClauseList._construct_for_list( + op, type_, *(left_flattened + right_flattened) + ) + + return BinaryExpression( + left, right, op, type_=type_, negate=negate, modifiers=modifiers + ) + + +class ExpressionClauseList(OperatorExpression[_T]): + """Describe a list of clauses, separated by an operator, + in a column expression context. 
+ + :class:`.ExpressionClauseList` differs from :class:`.ClauseList` in that + it represents a column-oriented DQL expression only, not an open ended + list of anything comma separated. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "expression_clauselist" + + _traverse_internals: _TraverseInternalsType = [ + ("clauses", InternalTraversal.dp_clauseelement_tuple), + ("operator", InternalTraversal.dp_operator), + ] + + clauses: typing_Tuple[ColumnElement[Any], ...] + + group: bool + + def __init__( + self, + operator: OperatorType, + *clauses: _ColumnExpressionArgument[Any], + type_: Optional[_TypeEngineArgument[_T]] = None, + ): + self.operator = operator + + self.clauses = tuple( + coercions.expect( + roles.ExpressionElementRole, clause, apply_propagate_attrs=self + ) + for clause in clauses + ) + self._is_implicitly_boolean = operators.is_boolean(self.operator) + self.type = type_api.to_instance(type_) # type: ignore + + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + return self.clauses + + def __iter__(self) -> Iterator[ColumnElement[Any]]: + return iter(self.clauses) + + def __len__(self) -> int: + return len(self.clauses) + + @property + def _select_iterable(self) -> _SelectIterable: + return (self,) + + @util.ro_non_memoized_property + def _from_objects(self) -> List[FromClause]: + return list(itertools.chain(*[c._from_objects for c in self.clauses])) + + def _append_inplace(self, clause: ColumnElement[Any]) -> None: + self.clauses += (clause,) + + @classmethod + def _construct_for_list( + cls, + operator: OperatorType, + type_: TypeEngine[_T], + *clauses: ColumnElement[Any], + group: bool = True, + ) -> ExpressionClauseList[_T]: + self = cls.__new__(cls) + self.group = group + self.clauses = clauses + self.operator = operator + self.type = type_ + return self + + def _negate(self) -> Any: + grouped = self.self_group(against=operators.inv) + assert isinstance(grouped, ColumnElement) + return 
UnaryExpression( + grouped, operator=operators.inv, wraps_column_expression=True + ) + + +class BooleanClauseList(ExpressionClauseList[bool]): + __visit_name__ = "expression_clauselist" inherit_cache = True def __init__(self, *arg, **kw): @@ -2668,7 +2838,15 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): if lcc > 1: # multiple elements. Return regular BooleanClauseList # which will link elements against the operator. - return cls._construct_raw(operator, convert_clauses) # type: ignore # noqa: E501 + + flattened_clauses = itertools.chain.from_iterable( + (c for c in to_flat._flattened_operator_clauses) + if getattr(to_flat, "operator", None) is operator + else (to_flat,) + for to_flat in convert_clauses + ) + + return cls._construct_raw(operator, flattened_clauses) # type: ignore # noqa: E501 elif lcc == 1: # just one element. return it as a single boolean element, # not a list and discard the operator. @@ -2726,10 +2904,9 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): clauses: Optional[Sequence[ColumnElement[Any]]] = None, ) -> BooleanClauseList: self = cls.__new__(cls) - self.clauses = list(clauses) if clauses else [] + self.clauses = tuple(clauses) if clauses else () self.group = True self.operator = operator - self.group_contents = True self.type = type_api.BOOLEANTYPE self._is_implicitly_boolean = True return self @@ -2768,9 +2945,6 @@ class BooleanClauseList(ClauseList, ColumnElement[bool]): else: return super(BooleanClauseList, self).self_group(against=against) - def _negate(self): - return ClauseList._negate(self) - and_ = BooleanClauseList.and_ or_ = BooleanClauseList.or_ @@ -3357,7 +3531,7 @@ class AsBoolean(WrapsColumnExpression[bool], UnaryExpression[bool]): return AsBoolean(self.element, self.negate, self.operator) -class BinaryExpression(ColumnElement[_T]): +class BinaryExpression(OperatorExpression[_T]): """Represent an expression that is ``LEFT RIGHT``. 
A :class:`.BinaryExpression` is generated automatically @@ -3394,12 +3568,12 @@ class BinaryExpression(ColumnElement[_T]): modifiers: Optional[Mapping[str, Any]] left: ColumnElement[Any] - right: Union[ColumnElement[Any], ClauseList] + right: ColumnElement[Any] def __init__( self, left: ColumnElement[Any], - right: Union[ColumnElement[Any], ClauseList], + right: ColumnElement[Any], operator: OperatorType, type_: Optional[_TypeEngineArgument[_T]] = None, negate: Optional[OperatorType] = None, @@ -3427,6 +3601,12 @@ class BinaryExpression(ColumnElement[_T]): else: self.modifiers = modifiers + @property + def _flattened_operator_clauses( + self, + ) -> typing_Tuple[ColumnElement[Any], ...]: + return (self.left, self.right) + def __bool__(self): """Implement Python-side "bool" for BinaryExpression as a simple "identity" check for the left and right attributes, @@ -3465,8 +3645,6 @@ class BinaryExpression(ColumnElement[_T]): else: raise TypeError("Boolean value of this clause is not defined") - __nonzero__ = __bool__ - if typing.TYPE_CHECKING: def __invert__( @@ -3474,21 +3652,10 @@ class BinaryExpression(ColumnElement[_T]): ) -> "BinaryExpression[_T]": ... 
- @property - def is_comparison(self): - return operators.is_comparison(self.operator) - @util.ro_non_memoized_property def _from_objects(self) -> List[FromClause]: return self.left._from_objects + self.right._from_objects - def self_group(self, against=None): - - if operators.is_precedent(self.operator, against): - return Grouping(self) - else: - return self - def _negate(self): if self.negate is not None: return BinaryExpression( diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 455e74f7b0..d08bbf4eb3 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -81,6 +81,7 @@ from .elements import ClauseList as ClauseList from .elements import CollectionAggregate as CollectionAggregate from .elements import ColumnClause as ColumnClause from .elements import ColumnElement as ColumnElement +from .elements import ExpressionClauseList as ExpressionClauseList from .elements import Extract as Extract from .elements import False_ as False_ from .elements import FunctionFilter as FunctionFilter diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 803e856548..8d98f893fb 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -84,6 +84,7 @@ class HasExpressionLookup(TypeEngineMixin): raise NotImplementedError() class Comparator(TypeEngine.Comparator[_CT]): + __slots__ = () _blank_dict = util.EMPTY_DICT @@ -114,6 +115,8 @@ class Concatenable(TypeEngineMixin): typically strings.""" class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _adapt_expression( self, op: OperatorType, @@ -143,6 +146,8 @@ class Indexable(TypeEngineMixin): """ class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _setup_getitem(self, index): raise NotImplementedError() @@ -174,12 +179,9 @@ class String(Concatenable, TypeEngine[str]): __visit_name__ = "string" def __init__( - # note pylance appears to require the "self" type in a constructor - # for the 
_T type to be correctly recognized when we send the - # class as the argument, e.g. `column("somecol", String)` self, - length=None, - collation=None, + length: Optional[int] = None, + collation: Optional[str] = None, ): """ Create a string-holding type. @@ -1508,6 +1510,8 @@ class Enum(String, SchemaType, Emulated, TypeEngine[Union[str, enum.Enum]]): ) from err class Comparator(String.Comparator[str]): + __slots__ = () + type: String def _adapt_expression( @@ -1963,7 +1967,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator[dt.timedelta]): TypeDecorator.Comparator[_CT], _AbstractInterval.Comparator[_CT], ): - pass + __slots__ = () comparator_factory = Comparator @@ -2385,6 +2389,8 @@ class JSON(Indexable, TypeEngine[Any]): class Comparator(Indexable.Comparator[_T], Concatenable.Comparator[_T]): """Define comparison operations for :class:`_types.JSON`.""" + __slots__ = () + def _setup_getitem(self, index): if not isinstance(index, str) and isinstance( index, collections_abc.Sequence @@ -2710,6 +2716,8 @@ class ARRAY( """ + __slots__ = () + def _setup_getitem(self, index): arr_type = cast(ARRAY, self.type) @@ -3221,6 +3229,8 @@ class NullType(TypeEngine[None]): return process class Comparator(TypeEngine.Comparator[_T]): + __slots__ = () + def _adapt_expression( self, op: OperatorType, diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index cbc4e9e707..c23cd04dd4 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -924,7 +924,7 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): ): return COMPARE_FAILED - def compare_clauselist(self, left, right, **kw): + def compare_expression_clauselist(self, left, right, **kw): if left.operator is right.operator: if operators.is_associative(left.operator): if self._compare_unordered_sequences( @@ -938,6 +938,9 @@ class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots): else: return COMPARE_FAILED + def 
compare_clauselist(self, left, right, **kw): + return self.compare_expression_clauselist(left, right, **kw) + def compare_binary(self, left, right, **kw): if left.operator == right.operator: if operators.is_commutative(left.operator): diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index cc14dd9c4f..25fe844c38 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -320,6 +320,31 @@ class StringTest(_LiteralRoundTripFixture, fixtures.TestBase): data = r"backslash one \ backslash two \\ end" literal_round_trip(String(40), [data], [data]) + def test_concatenate_binary(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and + visit_concat_op_expression_clauselist() in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_expression_clauselist() is also + needed for dialects to override the string concatenation operator. + + """ + eq_(connection.scalar(select(literal("a") + "b")), "ab") + + def test_concatenate_clauselist(self, connection): + """dialects with special string concatenation operators should + implement visit_concat_op_binary() and + visit_concat_op_expression_clauselist() in their compiler. + + .. versionchanged:: 2.0 visit_concat_op_expression_clauselist() is also + needed for dialects to override the string concatenation operator. 
+ + """ + eq_( + connection.scalar(select(literal("a") + "b" + "c" + "d" + "e")), + "abcde", + ) + class _DateFixture(_LiteralRoundTripFixture, fixtures.TestBase): compare = None diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py index fa8f37c1f8..3fb52416ec 100644 --- a/test/dialect/mysql/test_compiler.py +++ b/test/dialect/mysql/test_compiler.py @@ -1121,7 +1121,7 @@ class InsertOnDuplicateTest(fixtures.TestBase, AssertsCompiledSQL): expected_sql = ( "INSERT INTO foos (id, bar) VALUES (%s, %s), (%s, %s) ON " "DUPLICATE KEY UPDATE bar = coalesce(VALUES(bar)), " - "baz = (concat(concat(VALUES(baz), %s), VALUES(bar)))" + "baz = (concat(VALUES(baz), %s, VALUES(bar)))" ) self.assert_compile( stmt, diff --git a/test/sql/test_operators.py b/test/sql/test_operators.py index e5176713ab..88d1ea0530 100644 --- a/test/sql/test_operators.py +++ b/test/sql/test_operators.py @@ -165,6 +165,21 @@ class DefaultColumnComparatorTest( loop = LoopOperate() is_(operator(loop, *arg), operator) + @testing.combinations( + operators.add, + operators.and_, + operators.or_, + operators.mul, + argnames="op", + ) + def test_nonsensical_negations(self, op): + + opstring = compiler.OPERATORS[op] + self.assert_compile( + select(~op(column("x"), column("q"))), + f"SELECT NOT (x{opstring}q) AS anon_1", + ) + def test_null_true_false_is_sanity_checks(self): d = default.DefaultDialect() @@ -328,6 +343,176 @@ class DefaultColumnComparatorTest( ) +class MultiElementExprTest(fixtures.TestBase, testing.AssertsCompiledSQL): + __dialect__ = "default" + + @testing.combinations(True, False, argnames="reverse") + @testing.combinations(True, False, argnames="negate") + def test_associatives_mismatched_type(self, reverse, negate): + """test we get two separate exprs if the types dont match, operator + is not lost. 
+ + the expressions here don't generally make sense from a SQL + perspective, we are checking just that the operators / parenthesis / + negation works out in the SQL string to reasonably correspond to + what the Python structures look like. + + """ + + expr1 = column("i1", Integer) + column("i2", Integer) + + expr2 = column("d1", String) + column("d2", String) + + if reverse: + expr = expr2 + expr1 + + self.assert_compile( + select(expr), "SELECT (d1 || d2) + i1 + i2 AS anon_1" + ) + else: + expr = expr1 + expr2 + + self.assert_compile( + select(expr), "SELECT i1 + i2 + d1 || d2 AS anon_1" + ) + + @testing.combinations( + operators.add, + operators.and_, + operators.or_, + operators.mul, + argnames="op", + ) + @testing.combinations(True, False, argnames="reverse") + @testing.combinations(True, False, argnames="negate") + def test_associatives(self, op, reverse, negate): + t1 = table("t", column("q"), column("p")) + + num = 500 + + expr = op(t1.c.q, t1.c.p) + + if reverse: + for i in range(num - 1, -1, -1): + expr = op(column(f"d{i}"), expr) + else: + for i in range(num): + expr = op(expr, column(f"d{i}")) + + opstring = compiler.OPERATORS[op] + exprs = opstring.join(f"d{i}" for i in range(num)) + + if negate: + self.assert_compile( + select(~expr), + f"SELECT NOT (t.q{opstring}t.p{opstring}{exprs}) " + "AS anon_1 FROM t" + if not reverse + else f"SELECT NOT ({exprs}{opstring}t.q{opstring}t.p) " + "AS anon_1 FROM t", + ) + else: + self.assert_compile( + select(expr), + f"SELECT t.q{opstring}t.p{opstring}{exprs} AS anon_1 FROM t" + if not reverse + else f"SELECT {exprs}{opstring}t.q{opstring}t.p " + f"AS anon_1 FROM t", + ) + + @testing.combinations( + operators.gt, + operators.eq, + operators.le, + operators.sub, + argnames="op", + ) + @testing.combinations(True, False, argnames="reverse") + @testing.combinations(True, False, argnames="negate") + def test_non_associatives(self, op, reverse, negate): + """similar tests as test_associatives but for non-assoc + 
operators. + + the expressions here don't generally make sense from a SQL + perspective, we are checking just that the operators / parenthesis / + negation works out in the SQL string to reasonably correspond to + what the Python structures look like. + + """ + t1 = table("t", column("q"), column("p")) + + num = 5 + + expr = op(t1.c.q, t1.c.p) + + if reverse: + for i in range(num - 1, -1, -1): + expr = op(column(f"d{i}"), expr) + else: + for i in range(num): + expr = op(expr, column(f"d{i}")) + + opstring = compiler.OPERATORS[op] + exprs = opstring.join(f"d{i}" for i in range(num)) + + if negate: + negate_op = { + operators.gt: operators.le, + operators.eq: operators.ne, + operators.le: operators.gt, + }.get(op, None) + + if negate_op: + negate_opstring = compiler.OPERATORS[negate_op] + if reverse: + str_expr = ( + f"d0{negate_opstring}(d1{opstring}(d2{opstring}" + f"(d3{opstring}(d4{opstring}(t.q{opstring}t.p)))))" + ) + else: + str_expr = ( + f"(((((t.q{opstring}t.p){opstring}d0){opstring}d1)" + f"{opstring}d2){opstring}d3){negate_opstring}d4" + ) + else: + if reverse: + str_expr = ( + f"NOT (d0{opstring}(d1{opstring}(d2{opstring}" + f"(d3{opstring}(d4{opstring}(t.q{opstring}t.p))))))" + ) + else: + str_expr = ( + f"NOT ((((((t.q{opstring}t.p){opstring}d0)" + f"{opstring}d1){opstring}d2){opstring}d3){opstring}d4)" + ) + + self.assert_compile( + select(~expr), + f"SELECT {str_expr} AS anon_1 FROM t" + if not reverse + else f"SELECT {str_expr} AS anon_1 FROM t", + ) + else: + + if reverse: + str_expr = ( + f"d0{opstring}(d1{opstring}(d2{opstring}" + f"(d3{opstring}(d4{opstring}(t.q{opstring}t.p)))))" + ) + else: + str_expr = ( + f"(((((t.q{opstring}t.p){opstring}d0)" + f"{opstring}d1){opstring}d2){opstring}d3){opstring}d4" + ) + + self.assert_compile( + select(expr), + f"SELECT {str_expr} AS anon_1 FROM t" + if not reverse + else f"SELECT {str_expr} AS anon_1 FROM t", + ) + + class CustomUnaryOperatorTest(fixtures.TestBase, testing.AssertsCompiledSQL): __dialect__ 
= "default" @@ -2954,7 +3139,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_contains_concat(self): self.assert_compile( column("x").contains("y"), - "x LIKE concat(concat('%%', %s), '%%')", + "x LIKE concat('%%', %s, '%%')", checkparams={"x_1": "y"}, dialect=mysql.dialect(), ) @@ -2962,7 +3147,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_not_contains_concat(self): self.assert_compile( ~column("x").contains("y"), - "x NOT LIKE concat(concat('%%', %s), '%%')", + "x NOT LIKE concat('%%', %s, '%%')", checkparams={"x_1": "y"}, dialect=mysql.dialect(), ) @@ -2970,7 +3155,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_contains_literal_concat(self): self.assert_compile( column("x").contains(literal_column("y")), - "x LIKE concat(concat('%%', y), '%%')", + "x LIKE concat('%%', y, '%%')", checkparams={}, dialect=mysql.dialect(), ) @@ -2978,7 +3163,7 @@ class ComposedLikeOperatorsTest(fixtures.TestBase, testing.AssertsCompiledSQL): def test_contains_text_concat(self): self.assert_compile( column("x").contains(text("y")), - "x LIKE concat(concat('%%', y), '%%')", + "x LIKE concat('%%', y, '%%')", checkparams={}, dialect=mysql.dialect(), )