From: Mike Bayer Date: Wed, 3 Jun 2020 21:38:35 +0000 (-0400) Subject: Convert bulk update/delete to new execution model X-Git-Tag: rel_1_4_0b1~277 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=3ab2364e78641c4f0e4b6456afc2cbed39b0d0e6;p=thirdparty%2Fsqlalchemy%2Fsqlalchemy.git Convert bulk update/delete to new execution model This reorganizes the BulkUD model in sqlalchemy.orm.persistence to be based on the CompileState concept and to allow plain update() / delete() to be passed to session.execute() where the ORM synchronize session logic will take place. Also gets "synchronize_session='fetch'" working with horizontal sharding. Adding a few more result.scalar_one() types of methods as scalar_one() seems like what is normally desired. Fixes: #5160 Change-Id: I8001ebdad089da34119eb459709731ba6c0ba975 --- diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index 1d832e4afa..d03d79df72 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -1630,6 +1630,15 @@ class CursorResult(BaseCursorResult, Result): def _raw_row_iterator(self): return self._fetchiter_impl() + def merge(self, *others): + merged_result = super(CursorResult, self).merge(*others) + setup_rowcounts = not self._metadata.returns_rows + if setup_rowcounts: + merged_result.rowcount = sum( + result.rowcount for result in (self,) + others + ) + return merged_result + def close(self): """Close this :class:`_engine.CursorResult`. diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index 600229037b..b29bc22d44 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -951,7 +951,7 @@ class Result(InPlaceGenerative): """ return self._allrows() - def _only_one_row(self, raise_for_second_row, raise_for_none): + def _only_one_row(self, raise_for_second_row, raise_for_none, scalar): onerow = self._fetchone_impl row = onerow(hard_close=True) @@ -1010,27 +1010,43 @@ class Result(InPlaceGenerative): # if we checked for second row then that would have # closed us :) self._soft_close(hard=True) - post_creational_filter = self._post_creational_filter - if post_creational_filter: - row = post_creational_filter(row) - return row + if not scalar: + post_creational_filter = self._post_creational_filter + if post_creational_filter: + row = post_creational_filter(row) + + if scalar and row: + return row[0] + else: + return row def first(self): """Fetch the first row or None if no row is present. Closes the result set and discards remaining rows. + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the :meth:`.Result.scalar` method, + or combine :meth:`.Result.scalars` and :meth:`.Result.first`. + .. comment: A warning is emitted if additional rows remain. :return: a :class:`.Row` object if no filters are applied, or None if no rows remain. When filters are applied, such as :meth:`_engine.Result.mappings` - or :meth:`._engine.Result.scalar`, different kinds of objects + or :meth:`._engine.Result.scalars`, different kinds of objects may be returned. + .. seealso:: + + :meth:`_result.Result.scalar` + + :meth:`_result.Result.one` + """ - return self._only_one_row(False, False) + return self._only_one_row(False, False, False) def one_or_none(self): """Return at most one result or raise an exception. @@ -1055,15 +1071,50 @@ class Result(InPlaceGenerative): :meth:`_result.Result.one` """ - return self._only_one_row(True, False) + return self._only_one_row(True, False, False) + + def scalar_one(self): + """Return exactly one scalar result or raise an exception. + + This is equvalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one`. + + .. seealso:: + + :meth:`.Result.one` + + :meth:`.Result.scalars` + + """ + return self._only_one_row(True, True, True) + + def scalar_one_or_none(self): + """Return exactly one or no scalar result. + + This is equvalent to calling :meth:`.Result.scalars` and then + :meth:`.Result.one_or_none`. + + .. seealso:: + + :meth:`.Result.one_or_none` + + :meth:`.Result.scalars` + + """ + return self._only_one_row(True, False, True) def one(self): - """Return exactly one result or raise an exception. + """Return exactly one row or raise an exception. Raises :class:`.NoResultFound` if the result returns no rows, or :class:`.MultipleResultsFound` if multiple rows would be returned. + .. note:: This method returns one **row**, e.g. tuple, by default. + To return exactly one single scalar value, that is, the first + column of the first row, use the :meth:`.Result.scalar_one` method, + or combine :meth:`.Result.scalars` and :meth:`.Result.one`. + .. versionadded:: 1.4 :return: The first :class:`.Row`. @@ -1079,24 +1130,26 @@ class Result(InPlaceGenerative): :meth:`_result.Result.one_or_none` + :meth:`_result.Result.scalar_one` + """ - return self._only_one_row(True, True) + return self._only_one_row(True, True, False) def scalar(self): """Fetch the first column of the first row, and close the result set. + Returns None if there are no rows to fetch. + + No validation is performed to test if additional rows remain. + After calling this method, the object is fully closed, e.g. the :meth:`_engine.CursorResult.close` method will have been called. - :return: a Python scalar value , or None if no rows remain + :return: a Python scalar value , or None if no rows remain. """ - row = self.first() - if row is not None: - return row[0] - else: - return None + return self._only_one_row(False, False, True) class FrozenResult(object): diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index c3ac71c103..0983807cb9 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -50,58 +50,6 @@ class ShardedQuery(Query): """ return self.execution_options(_sa_shard_id=shard_id) - def _execute_crud(self, stmt, mapper): - def exec_for_shard(shard_id): - conn = self.session.connection( - mapper=mapper, - shard_id=shard_id, - clause=stmt, - close_with_result=True, - ) - result = conn._execute_20( - stmt, self.load_options._params, self._execution_options - ) - return result - - if self._shard_id is not None: - return exec_for_shard(self._shard_id) - else: - rowcount = 0 - results = [] - # TODO: this will have to be the new object - for shard_id in self.execute_chooser(self): - result = exec_for_shard(shard_id) - rowcount += result.rowcount - results.append(result) - - return ShardedResult(results, rowcount) - - -class ShardedResult(object): - """A value object that represents multiple :class:`_engine.CursorResult` - objects. - - This is used by the :meth:`.ShardedQuery._execute_crud` hook to return - an object that takes the place of the single :class:`_engine.CursorResult`. - - Attribute include ``result_proxies``, which is a sequence of the - actual :class:`_engine.CursorResult` objects, - as well as ``aggregate_rowcount`` - or ``rowcount``, which is the sum of all the individual rowcount values. - - .. versionadded:: 1.3 - """ - - __slots__ = ("result_proxies", "aggregate_rowcount") - - def __init__(self, result_proxies, aggregate_rowcount): - self.result_proxies = result_proxies - self.aggregate_rowcount = aggregate_rowcount - - @property - def rowcount(self): - return self.aggregate_rowcount - class ShardedSession(Session): def __init__( @@ -259,37 +207,40 @@ class ShardedSession(Session): def execute_and_instances(orm_context): - if orm_context.bind_arguments.get("_horizontal_shard", False): - return None - params = orm_context.parameters - load_options = orm_context.load_options + if orm_context.is_select: + load_options = active_options = orm_context.load_options + update_options = None + if params is None: + params = active_options._params + + else: + load_options = None + update_options = active_options = orm_context.update_delete_options + session = orm_context.session # orm_query = orm_context.orm_query - if params is None: - params = load_options._params - - def iter_for_shard(shard_id, load_options): + def iter_for_shard(shard_id, load_options, update_options): execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) - bind_arguments["_horizontal_shard"] = True bind_arguments["shard_id"] = shard_id - load_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_load_options"] = load_options + if orm_context.is_select: + load_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_load_options"] = load_options + else: + update_options += {"_refresh_identity_token": shard_id} + execution_options["_sa_orm_update_options"] = update_options - return session.execute( - orm_context.statement, - orm_context.parameters, - execution_options, - bind_arguments, + return orm_context.invoke_statement( + bind_arguments=bind_arguments, execution_options=execution_options ) - if load_options._refresh_identity_token is not None: - shard_id = load_options._refresh_identity_token + if active_options._refresh_identity_token is not None: + shard_id = active_options._refresh_identity_token elif "_sa_shard_id" in orm_context.merged_execution_options: shard_id = orm_context.merged_execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: @@ -298,11 +249,11 @@ def execute_and_instances(orm_context): shard_id = None if shard_id is not None: - return iter_for_shard(shard_id, load_options) + return iter_for_shard(shard_id, load_options, update_options) else: partial = [] for shard_id in session.execute_chooser(orm_context): - result_ = iter_for_shard(shard_id, load_options) + result_ = iter_for_shard(shard_id, load_options, update_options) partial.append(result_) return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py index 9f73b5d31b..efd8d7d6b2 100644 --- a/lib/sqlalchemy/ext/hybrid.py +++ b/lib/sqlalchemy/ext/hybrid.py @@ -777,7 +777,7 @@ things it can be used for. from .. import util from ..orm import attributes from ..orm import interfaces - +from ..sql import elements HYBRID_METHOD = util.symbol("HYBRID_METHOD") """Symbol indicating an :class:`InspectionAttr` that's @@ -1144,6 +1144,9 @@ class ExprComparator(Comparator): return self.hybrid.info def _bulk_update_tuples(self, value): + if isinstance(value, elements.BindParameter): + value = value.value + if isinstance(self.expression, attributes.QueryableAttribute): return self.expression._bulk_update_tuples(value) elif self.hybrid.update_expr is not None: diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index bd4074ea11..a16db66f6d 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -189,7 +189,7 @@ class ORMCompileState(CompileState): @classmethod def orm_pre_session_exec( - cls, session, statement, execution_options, bind_arguments + cls, session, statement, params, execution_options, bind_arguments ): load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options @@ -216,6 +216,8 @@ class ORMCompileState(CompileState): if load_options._autoflush: session._autoflush() + return execution_options + @classmethod def orm_setup_cursor_result( cls, session, statement, execution_options, bind_arguments, result diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py index 6be4f0dff8..027f2521b1 100644 --- a/lib/sqlalchemy/orm/descriptor_props.py +++ b/lib/sqlalchemy/orm/descriptor_props.py @@ -420,6 +420,9 @@ class CompositeProperty(DescriptorProperty): return CompositeProperty.CompositeBundle(self.prop, clauses) def _bulk_update_tuples(self, value): + if isinstance(value, sql.elements.BindParameter): + value = value.value + if value is None: values = [None for key in self.prop._attribute_keys] elif isinstance(value, self.prop.composite_class): diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py index be7aa272ea..217aa76c75 100644 --- a/lib/sqlalchemy/orm/events.py +++ b/lib/sqlalchemy/orm/events.py @@ -1764,7 +1764,7 @@ class SessionEvents(event.Events): lambda update_context: ( update_context.session, update_context.query, - update_context.context, + None, update_context.result, ), ) @@ -1782,12 +1782,13 @@ class SessionEvents(event.Events): was called upon. * ``values`` The "values" dictionary that was passed to :meth:`_query.Query.update`. - * ``context`` The :class:`.QueryContext` object, corresponding - to the invocation of an ORM query. * ``result`` the :class:`_engine.CursorResult` returned as a result of the bulk UPDATE operation. + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + .. seealso:: :meth:`.QueryEvents.before_compile_update` @@ -1802,7 +1803,7 @@ class SessionEvents(event.Events): lambda delete_context: ( delete_context.session, delete_context.query, - delete_context.context, + None, delete_context.result, ), ) @@ -1818,12 +1819,13 @@ class SessionEvents(event.Events): * ``query`` -the :class:`_query.Query` object that this update operation was called upon. - * ``context`` The :class:`.QueryContext` object, corresponding - to the invocation of an ORM query. * ``result`` the :class:`_engine.CursorResult` returned as a result of the bulk DELETE operation. + .. versionchanged:: 1.4 the update_context no longer has a + ``QueryContext`` object associated with it. + .. seealso:: :meth:`.QueryEvents.before_compile_delete` diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 4166e6d2a9..c4cb89c038 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -2235,14 +2235,28 @@ class Mapper( @HasMemoized.memoized_instancemethod def __clause_element__(self): - return self.selectable._annotate( - { - "entity_namespace": self, - "parententity": self, - "parentmapper": self, - "compile_state_plugin": "orm", - } - )._set_propagate_attrs( + + annotations = { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "compile_state_plugin": "orm", + } + if self.persist_selectable is not self.local_table: + # joined table inheritance, with polymorphic selectable, + # etc. + annotations["dml_table"] = self.local_table._annotate( + { + "entity_namespace": self, + "parententity": self, + "parentmapper": self, + "compile_state_plugin": "orm", + } + )._set_propagate_attrs( + {"compile_state_plugin": "orm", "plugin_subject": self} + ) + + return self.selectable._annotate(annotations)._set_propagate_attrs( {"compile_state_plugin": "orm", "plugin_subject": self} ) diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py index 163ebf22a5..19d43d354d 100644 --- a/lib/sqlalchemy/orm/persistence.py +++ b/lib/sqlalchemy/orm/persistence.py @@ -28,11 +28,15 @@ from .. import exc as sa_exc from .. import future from .. import sql from .. import util +from ..future import select as future_select from ..sql import coercions from ..sql import expression from ..sql import operators from ..sql import roles -from ..sql.base import _from_objects +from ..sql.base import CompileState +from ..sql.base import Options +from ..sql.dml import DeleteDMLState +from ..sql.dml import UpdateDMLState from ..sql.elements import BooleanClauseList @@ -1650,243 +1654,193 @@ def _sort_states(mapper, states): ) -class BulkUD(object): - """Handle bulk update and deletes via a :class:`_query.Query`.""" +_EMPTY_DICT = util.immutabledict() - def __init__(self, query): - self.query = query.enable_eagerloads(False) - self._validate_query_state() - def _validate_query_state(self): - for attr, methname, notset, op in ( - ("_limit_clause", "limit()", None, operator.is_), - ("_offset_clause", "offset()", None, operator.is_), - ("_order_by_clauses", "order_by()", (), operator.eq), - ("_group_by_clauses", "group_by()", (), operator.eq), - ("_distinct", "distinct()", False, operator.is_), - ( - "_from_obj", - "join(), outerjoin(), select_from(), or from_self()", - (), - operator.eq, - ), - ( - "_legacy_setup_joins", - "join(), outerjoin(), select_from(), or from_self()", - (), - operator.eq, - ), - ): - if not op(getattr(self.query, attr), notset): - raise sa_exc.InvalidRequestError( - "Can't call Query.update() or Query.delete() " - "when %s has been called" % (methname,) - ) - - @property - def session(self): - return self.query.session +class BulkUDCompileState(CompileState): + class default_update_options(Options): + _synchronize_session = "evaluate" + _autoflush = True + _subject_mapper = None + _resolved_values = _EMPTY_DICT + _resolved_keys_as_propnames = _EMPTY_DICT + _value_evaluators = _EMPTY_DICT + _matched_objects = None + _matched_rows = None + _refresh_identity_token = None @classmethod - def _factory(cls, lookup, synchronize_session, *arg): - try: - klass = lookup[synchronize_session] - except KeyError as err: - util.raise_( - sa_exc.ArgumentError( - "Valid strategies for session synchronization " - "are %s" % (", ".join(sorted(repr(x) for x in lookup))) - ), - replace_context=err, + def orm_pre_session_exec( + cls, session, statement, params, execution_options, bind_arguments + ): + sync = execution_options.get("synchronize_session", None) + if sync is None: + sync = statement._execution_options.get( + "synchronize_session", None ) - else: - return klass(*arg) - - def exec_(self): - self._do_before_compile() - self._do_pre() - self._do_pre_synchronize() - self._do_exec() - self._do_post_synchronize() - self._do_post() - - def _execute_stmt(self, stmt): - self.result = self.query._execute_crud(stmt, self.mapper) - self.rowcount = self.result.rowcount - - def _do_before_compile(self): - raise NotImplementedError() - @util.preload_module("sqlalchemy.orm.context") - def _do_pre(self): - query_context = util.preloaded.orm_context - query = self.query - - self.compile_state = ( - self.context - ) = compile_state = query._compile_state() - - self.mapper = compile_state._entity_zero() - - if isinstance( - compile_state._entities[0], query_context._RawColumnEntity, - ): - # check for special case of query(table) - tables = set() - for ent in compile_state._entities: - if not isinstance(ent, query_context._RawColumnEntity,): - tables.clear() - break - else: - tables.update(_from_objects(ent.column)) + update_options = execution_options.get( + "_sa_orm_update_options", + BulkUDCompileState.default_update_options, + ) - if len(tables) != 1: - raise sa_exc.InvalidRequestError( - "This operation requires only one Table or " - "entity be specified as the target." + if sync is not None: + if sync not in ("evaluate", "fetch", False): + raise sa_exc.ArgumentError( + "Valid strategies for session synchronization " + "are 'evaluate', 'fetch', False" ) - else: - self.primary_table = tables.pop() + update_options += {"_synchronize_session": sync} + bind_arguments["clause"] = statement + try: + plugin_subject = statement._propagate_attrs["plugin_subject"] + except KeyError: + assert False, "statement had 'orm' plugin but no plugin_subject" else: - self.primary_table = compile_state._only_entity_zero( - "This operation requires only one Table or " - "entity be specified as the target." - ).mapper.local_table + bind_arguments["mapper"] = plugin_subject.mapper - session = query.session + update_options += {"_subject_mapper": plugin_subject.mapper} - if query.load_options._autoflush: + if update_options._autoflush: session._autoflush() - def _do_pre_synchronize(self): - pass + if update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) - def _do_post_synchronize(self): - pass + return util.immutabledict(execution_options).union( + dict(_sa_orm_update_options=update_options) + ) + @classmethod + def orm_setup_cursor_result( + cls, session, statement, execution_options, bind_arguments, result + ): + update_options = execution_options["_sa_orm_update_options"] + if update_options._synchronize_session == "evaluate": + cls._do_post_synchronize_evaluate(session, update_options) + elif update_options._synchronize_session == "fetch": + cls._do_post_synchronize_fetch(session, update_options) -class BulkEvaluate(BulkUD): - """BulkUD which does the 'evaluate' method of session state resolution.""" + return result - def _additional_evaluators(self, evaluator_compiler): - pass + @classmethod + def _do_pre_synchronize_evaluate( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + mapper = update_options._subject_mapper + target_cls = mapper.class_ - def _do_pre_synchronize(self): - query = self.query - target_cls = self.compile_state._mapper_zero().class_ + value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT try: evaluator_compiler = evaluator.EvaluatorCompiler(target_cls) - if query._where_criteria: + if statement._where_criteria: eval_condition = evaluator_compiler.process( - *query._where_criteria + *statement._where_criteria ) else: def eval_condition(obj): return True - self._additional_evaluators(evaluator_compiler) + # TODO: something more robust for this conditional + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + value_evaluators = {} + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + for key, value in resolved_keys_as_propnames: + value_evaluators[key] = evaluator_compiler.process( + coercions.expect(roles.ExpressionElementRole, value) + ) except evaluator.UnevaluatableError as err: util.raise_( sa_exc.InvalidRequestError( 'Could not evaluate current criteria in Python: "%s". ' "Specify 'fetch' or False for the " - "synchronize_session parameter." % err + "synchronize_session execution option." % err ), from_=err, ) # TODO: detect when the where clause is a trivial primary key match - self.matched_objects = [ + matched_objects = [ obj - for ( - cls, - pk, - identity_token, - ), obj in query.session.identity_map.items() - if issubclass(cls, target_cls) and eval_condition(obj) + for (cls, pk, identity_token,), obj in session.identity_map.items() + if issubclass(cls, target_cls) + and eval_condition(obj) + and identity_token == update_options._refresh_identity_token ] - - -class BulkFetch(BulkUD): - """BulkUD which does the 'fetch' method of session state resolution.""" - - def _do_pre_synchronize(self): - query = self.query - session = query.session - select_stmt = self.compile_state.statement.with_only_columns( - self.primary_table.primary_key - ) - self.matched_rows = session.execute( - select_stmt, mapper=self.mapper, params=query.load_options._params - ).fetchall() - - -class BulkUpdate(BulkUD): - """BulkUD which handles UPDATEs.""" - - def __init__(self, query, values, update_kwargs): - super(BulkUpdate, self).__init__(query) - self.values = values - self.update_kwargs = update_kwargs + return update_options + { + "_matched_objects": matched_objects, + "_value_evaluators": value_evaluators, + "_resolved_keys_as_propnames": resolved_keys_as_propnames, + } @classmethod - def factory(cls, query, synchronize_session, values, update_kwargs): - return BulkUD._factory( - { - "evaluate": BulkUpdateEvaluate, - "fetch": BulkUpdateFetch, - False: BulkUpdate, - }, - synchronize_session, - query, - values, - update_kwargs, - ) - - def _do_before_compile(self): - if self.query.dispatch.before_compile_update: - for fn in self.query.dispatch.before_compile_update: - new_query = fn(self.query, self) - if new_query is not None: - self.query = new_query + def _get_resolved_values(cls, mapper, statement): + if statement._multi_values: + return [] + elif statement._ordered_values: + iterator = statement._ordered_values + elif statement._values: + iterator = statement._values.items() + else: + return [] - @property - def _resolved_values(self): values = [] - for k, v in ( - self.values.items() - if hasattr(self.values, "items") - else self.values - ): - if self.mapper: - if isinstance(k, util.string_types): - desc = sql.util._entity_namespace_key(self.mapper, k) - values.extend(desc._bulk_update_tuples(v)) - elif isinstance(k, attributes.QueryableAttribute): - values.extend(k._bulk_update_tuples(v)) + if iterator: + for k, v in iterator: + if mapper: + if isinstance(k, util.string_types): + desc = sql.util._entity_namespace_key(mapper, k) + values.extend(desc._bulk_update_tuples(v)) + elif isinstance(k, attributes.QueryableAttribute): + values.extend(k._bulk_update_tuples(v)) + else: + values.append((k, v)) else: values.append((k, v)) - else: - values.append((k, v)) return values - @property - def _resolved_values_keys_as_propnames(self): + @classmethod + def _resolved_keys_as_propnames(cls, mapper, resolved_values): values = [] - for k, v in self._resolved_values: + for k, v in resolved_values: if isinstance(k, attributes.QueryableAttribute): values.append((k.key, v)) continue elif hasattr(k, "__clause_element__"): k = k.__clause_element__() - if self.mapper and isinstance(k, expression.ColumnElement): + if mapper and isinstance(k, expression.ColumnElement): try: - attr = self.mapper._columntoproperty[k] + attr = mapper._columntoproperty[k] except orm_exc.UnmappedColumnError: pass else: @@ -1897,87 +1851,99 @@ class BulkUpdate(BulkUD): ) return values - def _do_exec(self): - values = self._resolved_values + @classmethod + def _do_pre_synchronize_fetch( + cls, + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ): + mapper = update_options._subject_mapper - if not self.update_kwargs.get("preserve_parameter_order", False): - values = dict(values) + if mapper: + primary_table = mapper.local_table + else: + primary_table = statement._raw_columns[0] - update_stmt = sql.update( - self.primary_table, **self.update_kwargs - ).values(values) + # note this creates a Select() *without* the ORM plugin. + # we don't want that here. + select_stmt = future_select(*primary_table.primary_key) + select_stmt._where_criteria = statement._where_criteria - update_stmt._where_criteria = self.compile_state._where_criteria + matched_rows = session.execute( + select_stmt, params, execution_options, bind_arguments + ).fetchall() - self._execute_stmt(update_stmt) + if statement.__visit_name__ == "update": + resolved_values = cls._get_resolved_values(mapper, statement) + resolved_keys_as_propnames = cls._resolved_keys_as_propnames( + mapper, resolved_values + ) + else: + resolved_keys_as_propnames = _EMPTY_DICT - def _do_post(self): - session = self.query.session - session.dispatch.after_bulk_update(self) + return update_options + { + "_matched_rows": matched_rows, + "_resolved_keys_as_propnames": resolved_keys_as_propnames, + } -class BulkDelete(BulkUD): - """BulkUD which handles DELETEs.""" +@CompileState.plugin_for("orm", "update") +class BulkORMUpdate(UpdateDMLState, BulkUDCompileState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): - def __init__(self, query): - super(BulkDelete, self).__init__(query) + self = cls.__new__(cls) - @classmethod - def factory(cls, query, synchronize_session): - return BulkUD._factory( - { - "evaluate": BulkDeleteEvaluate, - "fetch": BulkDeleteFetch, - False: BulkDelete, - }, - synchronize_session, - query, + self.mapper = mapper = statement.table._annotations.get( + "parentmapper", None ) - def _do_before_compile(self): - if self.query.dispatch.before_compile_delete: - for fn in self.query.dispatch.before_compile_delete: - new_query = fn(self.query, self) - if new_query is not None: - self.query = new_query + self._resolved_values = cls._get_resolved_values(mapper, statement) - def _do_exec(self): - delete_stmt = sql.delete(self.primary_table,) - delete_stmt._where_criteria = self.compile_state._where_criteria + if not statement._preserve_parameter_order and statement._values: + self._resolved_values = dict(self._resolved_values) - self._execute_stmt(delete_stmt) + new_stmt = sql.Update.__new__(sql.Update) + new_stmt.__dict__.update(statement.__dict__) + new_stmt.table = mapper.local_table - def _do_post(self): - session = self.query.session - session.dispatch.after_bulk_delete(self) + # note if the statement has _multi_values, these + # are passed through to the new statement, which will then raise + # InvalidRequestError because UPDATE doesn't support multi_values + # right now. + if statement._ordered_values: + new_stmt._ordered_values = self._resolved_values + elif statement._values: + new_stmt._values = self._resolved_values + UpdateDMLState.__init__(self, new_stmt, compiler, **kw) -class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): - """BulkUD which handles UPDATEs using the "evaluate" - method of session resolution.""" + return self - def _additional_evaluators(self, evaluator_compiler): - self.value_evaluators = {} - values = self._resolved_values_keys_as_propnames - for key, value in values: - self.value_evaluators[key] = evaluator_compiler.process( - coercions.expect(roles.ExpressionElementRole, value) - ) + @classmethod + def _do_post_synchronize_evaluate(cls, session, update_options): - def _do_post_synchronize(self): - session = self.query.session states = set() - evaluated_keys = list(self.value_evaluators.keys()) - for obj in self.matched_objects: + evaluated_keys = list(update_options._value_evaluators.keys()) + for obj in update_options._matched_objects: + state, dict_ = ( attributes.instance_state(obj), attributes.instance_dict(obj), ) + assert ( + state.identity_token == update_options._refresh_identity_token + ) + # only evaluate unmodified attributes to_evaluate = state.unmodified.intersection(evaluated_keys) for key in to_evaluate: - dict_[key] = self.value_evaluators[key](obj) + dict_[key] = update_options._value_evaluators[key](obj) state.manager.dispatch.refresh(state, None, to_evaluate) @@ -1991,39 +1957,25 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate): states.add(state) session._register_altered(states) - -class BulkDeleteEvaluate(BulkEvaluate, BulkDelete): - """BulkUD which handles DELETEs using the "evaluate" - method of session resolution.""" - - def _do_post_synchronize(self): - self.query.session._remove_newly_deleted( - [attributes.instance_state(obj) for obj in self.matched_objects] - ) - - -class BulkUpdateFetch(BulkFetch, BulkUpdate): - """BulkUD which handles UPDATEs using the "fetch" - method of session resolution.""" - - def _do_post_synchronize(self): - session = self.query.session - target_mapper = self.compile_state._mapper_zero() + @classmethod + def _do_post_synchronize_fetch(cls, session, update_options): + target_mapper = update_options._subject_mapper states = set( [ attributes.instance_state(session.identity_map[identity_key]) for identity_key in [ target_mapper.identity_key_from_primary_key( - list(primary_key) + list(primary_key), + identity_token=update_options._refresh_identity_token, ) - for primary_key in self.matched_rows + for primary_key in update_options._matched_rows ] if identity_key in session.identity_map ] ) - values = self._resolved_values_keys_as_propnames + values = update_options._resolved_keys_as_propnames attrib = set(k for k, v in values) for state in states: to_expire = attrib.intersection(state.dict) @@ -2032,18 +1984,38 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate): session._register_altered(states) -class BulkDeleteFetch(BulkFetch, BulkDelete): - """BulkUD which handles DELETEs using the "fetch" - method of session resolution.""" +@CompileState.plugin_for("orm", "delete") +class BulkORMDelete(DeleteDMLState, BulkUDCompileState): + @classmethod + def create_for_statement(cls, statement, compiler, **kw): + self = cls.__new__(cls) + + self.mapper = statement.table._annotations.get("parentmapper", None) + + DeleteDMLState.__init__(self, statement, compiler, **kw) + + return self + + @classmethod + def _do_post_synchronize_evaluate(cls, session, update_options): + + session._remove_newly_deleted( + [ + attributes.instance_state(obj) + for obj in update_options._matched_objects + ] + ) + + @classmethod + def _do_post_synchronize_fetch(cls, session, update_options): + target_mapper = update_options._subject_mapper - def _do_post_synchronize(self): - session = self.query.session - target_mapper = self.compile_state._mapper_zero() - for primary_key in self.matched_rows: + for primary_key in update_options._matched_rows: # TODO: inline this and call remove_newly_deleted # once identity_key = target_mapper.identity_key_from_primary_key( - list(primary_key) + list(primary_key), + identity_token=update_options._refresh_identity_token, ) if identity_key in session.identity_map: session._remove_newly_deleted( diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 5137f9b1d4..284ea9d72b 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -19,12 +19,12 @@ database to return iterable result sets. """ import itertools +import operator from . import attributes from . import exc as orm_exc from . import interfaces from . import loading -from . import persistence from .base import _assertions from .context import _column_descriptions from .context import _legacy_determine_last_joined_entity @@ -2825,15 +2825,6 @@ class Query( return result - def _execute_crud(self, stmt, mapper): - conn = self.session.connection( - mapper=mapper, clause=stmt, close_with_result=True - ) - - return conn._execute_20( - stmt, self.load_options._params, self._execution_options - ) - def __str__(self): statement = self._statement_20() @@ -3178,9 +3169,27 @@ class Query( """ - delete_op = persistence.BulkDelete.factory(self, synchronize_session) - delete_op.exec_() - return delete_op.rowcount + bulk_del = BulkDelete(self,) + if self.dispatch.before_compile_delete: + for fn in self.dispatch.before_compile_delete: + new_query = fn(bulk_del.query, bulk_del) + if new_query is not None: + bulk_del.query = new_query + + self = bulk_del.query + + delete_ = sql.delete(*self._raw_columns) + delete_._where_criteria = self._where_criteria + result = self.session.execute( + delete_, + self.load_options._params, + execution_options={"synchronize_session": synchronize_session}, + ) + bulk_del.result = result + self.session.dispatch.after_bulk_delete(bulk_del) + result.close() + + return result.rowcount def update(self, values, synchronize_session="evaluate", update_args=None): r"""Perform a bulk update query. @@ -3313,11 +3322,27 @@ class Query( """ update_args = update_args or {} - update_op = persistence.BulkUpdate.factory( - self, synchronize_session, values, update_args + + bulk_ud = BulkUpdate(self, values, update_args) + + if self.dispatch.before_compile_update: + for fn in self.dispatch.before_compile_update: + new_query = fn(bulk_ud.query, bulk_ud) + if new_query is not None: + bulk_ud.query = new_query + self = bulk_ud.query + + upd = sql.update(*self._raw_columns, **update_args).values(values) + upd._where_criteria = self._where_criteria + result = self.session.execute( + upd, + self.load_options._params, + execution_options={"synchronize_session": synchronize_session}, ) - update_op.exec_() - return update_op.rowcount + bulk_ud.result = result + self.session.dispatch.after_bulk_update(bulk_ud) + result.close() + return result.rowcount def _compile_state(self, for_statement=False, **kw): """Create an out-of-compiler ORMCompileState object. @@ -3427,3 +3452,59 @@ class AliasOption(interfaces.LoaderOption): def process_compile_state(self, compile_state): pass + + +class BulkUD(object): + """State used for the orm.Query version of update() / delete(). + + This object is now specific to Query only. + + """ + + def __init__(self, query): + self.query = query.enable_eagerloads(False) + self._validate_query_state() + self.mapper = self.query._entity_from_pre_ent_zero() + + def _validate_query_state(self): + for attr, methname, notset, op in ( + ("_limit_clause", "limit()", None, operator.is_), + ("_offset_clause", "offset()", None, operator.is_), + ("_order_by_clauses", "order_by()", (), operator.eq), + ("_group_by_clauses", "group_by()", (), operator.eq), + ("_distinct", "distinct()", False, operator.is_), + ( + "_from_obj", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ( + "_legacy_setup_joins", + "join(), outerjoin(), select_from(), or from_self()", + (), + operator.eq, + ), + ): + if not op(getattr(self.query, attr), notset): + raise sa_exc.InvalidRequestError( + "Can't call Query.update() or Query.delete() " + "when %s has been called" % (methname,) + ) + + @property + def session(self): + return self.query.session + + +class BulkUpdate(BulkUD): + """BulkUD which handles UPDATEs.""" + + def __init__(self, query, values, update_kwargs): + super(BulkUpdate, self).__init__(query) + self.values = values + self.update_kwargs = update_kwargs + + +class BulkDelete(BulkUD): + """BulkUD which handles DELETEs.""" diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index ee42419a26..5ad8bcf2f2 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -33,7 +33,9 @@ from .. import future from .. import util from ..inspection import inspect from ..sql import coercions +from ..sql import dml from ..sql import roles +from ..sql import selectable from ..sql import visitors from ..sql.base import CompileState @@ -113,16 +115,24 @@ class ORMExecuteState(util.MemoizedSlots): "_execution_options", "_merged_execution_options", "bind_arguments", + "_compile_state_cls", ) def __init__( - self, session, statement, parameters, execution_options, bind_arguments + self, + session, + statement, + parameters, + execution_options, + bind_arguments, + compile_state_cls, ): self.session = session self.statement = statement self.parameters = parameters self._execution_options = execution_options self.bind_arguments = bind_arguments + self._compile_state_cls = compile_state_cls def invoke_statement( self, @@ -193,6 +203,38 @@ class ORMExecuteState(util.MemoizedSlots): statement, _params, _execution_options, _bind_arguments ) + @property + def is_orm_statement(self): + """return True if the operation is an ORM statement. + + This indictes that the select(), update(), or delete() being + invoked contains ORM entities as subjects. For a statement + that does not have ORM entities and instead refers only to + :class:`.Table` metadata, it is invoked as a Core SQL statement + and no ORM-level automation takes place. + + """ + return self._compile_state_cls is not None + + @property + def is_select(self): + """return True if this is a SELECT operation.""" + return isinstance(self.statement, selectable.Select) + + @property + def is_update(self): + """return True if this is an UPDATE operation.""" + return isinstance(self.statement, dml.Update) + + @property + def is_delete(self): + """return True if this is a DELETE operation.""" + return isinstance(self.statement, dml.Delete) + + @property + def _is_crud(self): + return isinstance(self.statement, (dml.Update, dml.Delete)) + @property def execution_options(self): """Placeholder for execution options. @@ -270,10 +312,30 @@ class ORMExecuteState(util.MemoizedSlots): def load_options(self): """Return the load_options that will be used for this execution.""" + if not self.is_select: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against a SELECT statement " + "so there are no load options." + ) return self._execution_options.get( "_sa_orm_load_options", context.QueryContext.default_load_options ) + @property + def update_delete_options(self): + """Return the update_delete_options that will be used for this + execution.""" + + if not self._is_crud: + raise sa_exc.InvalidRequestError( + "This ORM execution is not against an UPDATE or DELETE " + "statement so there are no update options." + ) + return self._execution_options.get( + "_sa_orm_update_options", + persistence.BulkUDCompileState.default_update_options, + ) + @property def user_defined_options(self): """The sequence of :class:`.UserDefinedOptions` that have been @@ -1455,35 +1517,37 @@ class Session(_SessionClassMethods): compile_state_cls = CompileState._get_plugin_class_for_plugin( statement, "orm" ) + else: + compile_state_cls = None - compile_state_cls.orm_pre_session_exec( - self, statement, execution_options, bind_arguments + if compile_state_cls is not None: + execution_options = compile_state_cls.orm_pre_session_exec( + self, statement, params, execution_options, bind_arguments ) - - if self.dispatch.do_orm_execute: - skip_events = bind_arguments.pop("_sa_skip_events", False) - - if not skip_events: - orm_exec_state = ORMExecuteState( - self, - statement, - params, - execution_options, - bind_arguments, - ) - for fn in self.dispatch.do_orm_execute: - result = fn(orm_exec_state) - if result: - return result - else: - compile_state_cls = None bind_arguments.setdefault("clause", statement) if statement._is_future: execution_options = util.immutabledict().merge_with( execution_options, {"future_result": True} ) + if self.dispatch.do_orm_execute: + # run this event whether or not we are in ORM mode + skip_events = bind_arguments.get("_sa_skip_events", False) + if not skip_events: + orm_exec_state = ORMExecuteState( + self, + statement, + params, + execution_options, + bind_arguments, + compile_state_cls, + ) + for fn in self.dispatch.do_orm_execute: + result = fn(orm_exec_state) + if result: + return result + bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind, close_with_result=True) @@ -1601,8 +1665,8 @@ class Session(_SessionClassMethods): self.__binds[insp] = bind elif insp.is_mapper: self.__binds[insp.class_] = bind - for selectable in insp._all_tables: - self.__binds[selectable] = bind + for _selectable in insp._all_tables: + self.__binds[_selectable] = bind else: raise sa_exc.ArgumentError( "Not an acceptable bind target: %s" % key @@ -1664,7 +1728,9 @@ class Session(_SessionClassMethods): """ self._add_bind(table, bind) - def get_bind(self, mapper=None, clause=None, bind=None): + def get_bind( + self, mapper=None, clause=None, bind=None, _sa_skip_events=None + ): """Return a "bind" to which this :class:`.Session` is bound. The "bind" is usually an instance of :class:`_engine.Engine`, diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index f143190890..5dd3b519ac 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -446,10 +446,14 @@ class CompileState(object): plugin_name = statement._propagate_attrs.get( "compile_state_plugin", "default" ) - else: - plugin_name = "default" + klass = cls.plugins.get( + (plugin_name, statement.__visit_name__), None + ) + if klass is None: + klass = cls.plugins[("default", statement.__visit_name__)] - klass = cls.plugins[(plugin_name, statement.__visit_name__)] + else: + klass = cls.plugins[("default", statement.__visit_name__)] if klass is cls: return cls(statement, compiler, **kw) diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index db43e42a63..4c6a0317a4 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -755,6 +755,16 @@ class AnonymizedFromClauseImpl(StrictFromClauseImpl): return element.alias(name=name, flat=flat) +class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl): + __slots__ = () + + def _post_coercion(self, element, **kw): + if "dml_table" in element._annotations: + return element._annotations["dml_table"] + else: + return element + + class DMLSelectImpl(_NoTextCoercion, RoleImpl): __slots__ = () diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index f4160b5520..2519438d1b 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -3215,6 +3215,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isupdate = True + if not self.compile_state: + self.compile_state = compile_state extra_froms = compile_state._extra_froms is_multitable = bool(extra_froms) @@ -3342,6 +3344,8 @@ class SQLCompiler(Compiled): toplevel = not self.stack if toplevel: self.isdelete = True + if not self.compile_state: + self.compile_state = compile_state extra_froms = compile_state._extra_froms diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py index 467a764d62..a82641d77c 100644 --- a/lib/sqlalchemy/sql/dml.py +++ b/lib/sqlalchemy/sql/dml.py @@ -19,6 +19,7 @@ from .base import CompileState from .base import DialectKWArgs from .base import Executable from .base import HasCompileState +from .elements import BooleanClauseList from .elements import ClauseElement from .elements import Null from .selectable import HasCTE @@ -150,7 +151,6 @@ class UpdateDMLState(DMLState): def __init__(self, statement, compiler, **kw): self.statement = statement - self.isupdate = True self._preserve_parameter_order = statement._preserve_parameter_order if statement._ordered_values is not None: @@ -447,7 +447,9 @@ class ValuesBase(UpdateBase): _returning = () def __init__(self, table, values, prefixes): - self.table = coercions.expect(roles.FromClauseRole, table) + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) if values is not None: self.values.non_generative(self, values) if prefixes: @@ -949,6 +951,28 @@ class DMLWhereBase(object): coercions.expect(roles.WhereHavingRole, whereclause), ) + def filter(self, *criteria): + """A synonym for the :meth:`_dml.DMLWhereBase.where` method.""" + + return self.where(*criteria) + + @property + def whereclause(self): + """Return the completed WHERE clause for this :class:`.DMLWhereBase` + statement. + + This assembles the current collection of WHERE criteria + into a single :class:`_expression.BooleanClauseList` construct. + + + .. versionadded:: 1.4 + + """ + + return BooleanClauseList._construct_for_whereclause( + self._where_criteria + ) + class Update(DMLWhereBase, ValuesBase): """Represent an Update construct. @@ -1266,7 +1290,9 @@ class Delete(DMLWhereBase, UpdateBase): """ self._bind = bind - self.table = coercions.expect(roles.FromClauseRole, table) + self.table = coercions.expect( + roles.DMLTableRole, table, apply_propagate_attrs=self + ) self._returning = returning if prefixes: diff --git a/lib/sqlalchemy/sql/roles.py b/lib/sqlalchemy/sql/roles.py index 5a55fe5f28..3d94ec9ff5 100644 --- a/lib/sqlalchemy/sql/roles.py +++ b/lib/sqlalchemy/sql/roles.py @@ -184,10 +184,15 @@ class CompoundElementRole(SQLRole): ) +# TODO: are we using this? class DMLRole(StatementRole): pass +class DMLTableRole(FromClauseRole): + _role_name = "subject table for an INSERT, UPDATE or DELETE" + + class DMLColumnRole(SQLRole): _role_name = "SET/VALUES column expression or string key" diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d6845e05f7..a95fc561ae 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -789,7 +789,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable): self._reset_column_collection() -class Join(FromClause): +class Join(roles.DMLTableRole, FromClause): """represent a ``JOIN`` construct between two :class:`_expression.FromClause` elements. @@ -1406,7 +1406,7 @@ class AliasedReturnsRows(NoInit, FromClause): return self.element.bind -class Alias(AliasedReturnsRows): +class Alias(roles.DMLTableRole, AliasedReturnsRows): """Represents an table or selectable alias (AS). Represents an alias, as typically applied to any table or @@ -1987,7 +1987,7 @@ class FromGrouping(GroupedElement, FromClause): self.element = state["element"] -class TableClause(Immutable, FromClause): +class TableClause(roles.DMLTableRole, Immutable, FromClause): """Represents a minimal "table" construct. This is a lightweight table object that has only a name, a diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index 388097e45a..68281f33d1 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -10,6 +10,7 @@ from .. import util from ..inspection import inspect from ..util import collections_abc from ..util import HasMemoized +from ..util import py37 SKIP_TRAVERSE = util.symbol("skip_traverse") COMPARE_FAILED = False @@ -562,23 +563,38 @@ class _CacheKey(ExtendedInternalTraversal): ) def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams): + if py37: + # in py37 we can assume two dictionaries created in the same + # insert ordering will retain that sorting + return ( + attrname, + tuple( + ( + k._gen_cache_key(anon_map, bindparams) + if hasattr(k, "__clause_element__") + else k, + obj[k]._gen_cache_key(anon_map, bindparams), + ) + for k in obj + ), + ) + else: + expr_values = {k for k in obj if hasattr(k, "__clause_element__")} + if expr_values: + # expr values can't be sorted deterministically right now, + # so no cache + anon_map[NO_CACHE] = True + return () - expr_values = {k for k in obj if hasattr(k, "__clause_element__")} - if expr_values: - # expr values can't be sorted deterministically right now, - # so no cache - anon_map[NO_CACHE] = True - return () - - str_values = expr_values.symmetric_difference(obj) + str_values = expr_values.symmetric_difference(obj) - return ( - attrname, - tuple( - (k, obj[k]._gen_cache_key(anon_map, bindparams)) - for k in sorted(str_values) - ), - ) + return ( + attrname, + tuple( + (k, obj[k]._gen_cache_key(anon_map, bindparams)) + for k in sorted(str_values) + ), + ) def visit_dml_multi_values( self, attrname, obj, parent, anon_map, bindparams @@ -1130,6 +1146,18 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots): for lv, rv in zip(left, right): if not self._compare_dml_values_or_ce(lv, rv, **kw): return COMPARE_FAILED + elif isinstance(right, collections_abc.Sequence): + return COMPARE_FAILED + elif py37: + # dictionaries guaranteed to support insert ordering in + # py37 so that we can compare the keys in order. without + # this, we can't compare SQL expression keys because we don't + # know which key is which + for (lk, lv), (rk, rv) in zip(left.items(), right.items()): + if not self._compare_dml_values_or_ce(lk, rk, **kw): + return COMPARE_FAILED + if not self._compare_dml_values_or_ce(lv, rv, **kw): + return COMPARE_FAILED else: for lk in left: lv = left[lk] diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 0ea9f067e5..54da06a3da 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -403,10 +403,6 @@ class AssertsCompiledSQL(object): LABEL_STYLE_TABLENAME_PLUS_COL ) clause = compile_state.statement - elif isinstance(clause, orm.persistence.BulkUD): - with mock.patch.object(clause, "_execute_stmt") as stmt_mock: - clause.exec_() - clause = stmt_mock.mock_calls[0][1][0] if compile_kwargs: kw["compile_kwargs"] = compile_kwargs diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index 55a6cdcf90..273570357b 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -65,6 +65,7 @@ from .compat import pickle # noqa from .compat import print_ # noqa from .compat import py2k # noqa from .compat import py36 # noqa +from .compat import py37 # noqa from .compat import py3k # noqa from .compat import quote_plus # noqa from .compat import raise_ # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 247dbc13c3..5c46395f9a 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -15,6 +15,7 @@ import platform import sys +py37 = sys.version_info >= (3, 7) py36 = sys.version_info >= (3, 6) py3k = sys.version_info >= (3, 0) py2k = sys.version_info < (3, 0) diff --git a/test/aaa_profiling/test_orm.py b/test/aaa_profiling/test_orm.py index 1b31c96e9c..188e8e929a 100644 --- a/test/aaa_profiling/test_orm.py +++ b/test/aaa_profiling/test_orm.py @@ -859,6 +859,7 @@ class JoinedEagerLoadTest(fixtures.MappedTest): ORMCompileState.orm_pre_session_exec( sess, compile_state.select_statement, + {}, exec_opts, bind_arguments, ) diff --git a/test/base/test_result.py b/test/base/test_result.py index ca65111da2..ce0e7b9452 100644 --- a/test/base/test_result.py +++ b/test/base/test_result.py @@ -466,10 +466,16 @@ class ResultTest(fixtures.TestBase): def test_scalar_one(self): result = self._fixture(num_rows=1) + row = result.scalar_one() + eq_(row, 1) + + def test_scalars_plus_one(self): + result = self._fixture(num_rows=1) + row = result.scalars().one() eq_(row, 1) - def test_scalar_one_none(self): + def test_scalars_plus_one_none(self): result = self._fixture(num_rows=0) result = result.scalars() @@ -488,6 +494,21 @@ class ResultTest(fixtures.TestBase): result.one, ) + def test_one_or_none(self): + result = self._fixture(num_rows=1) + + eq_(result.one_or_none(), (1, 1, 1)) + + def test_scalar_one_or_none(self): + result = self._fixture(num_rows=1) + + eq_(result.scalar_one_or_none(), 1) + + def test_scalar_one_or_none_none(self): + result = self._fixture(num_rows=0) + + eq_(result.scalar_one_or_none(), None) + def test_one_or_none_none(self): result = self._fixture(num_rows=0) diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index bb10337051..c0029fbb63 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -3,6 +3,7 @@ import os from sqlalchemy import Column from sqlalchemy import DateTime +from sqlalchemy import delete from sqlalchemy import event from sqlalchemy import Float from sqlalchemy import ForeignKey @@ -13,6 +14,7 @@ from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing +from sqlalchemy import update from sqlalchemy import util from sqlalchemy.ext.horizontal_shard import ShardedSession from sqlalchemy.future import select as future_select @@ -444,7 +446,7 @@ class ShardTest(object): t = get_tokyo(sess2) eq_(t.city, tokyo.city) - def test_bulk_update(self): + def test_bulk_update_synchronize_evaluate(self): sess = self._fixture_data() eq_( @@ -456,7 +458,8 @@ class ShardTest(object): eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) sess.query(Report).filter(Report.temperature >= 80).update( - {"temperature": Report.temperature + 6} + {"temperature": Report.temperature + 6}, + synchronize_session="evaluate", ) eq_( @@ -467,13 +470,58 @@ class ShardTest(object): # test synchronize session as well eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) - def test_bulk_delete(self): + def test_bulk_update_synchronize_fetch(self): + sess = self._fixture_data() + + eq_( + set(row.temperature for row in sess.query(Report.temperature)), + {80.0, 75.0, 85.0}, + ) + + temps = sess.query(Report).all() + eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + + sess.query(Report).filter(Report.temperature >= 80).update( + {"temperature": Report.temperature + 6}, + synchronize_session="fetch", + ) + + eq_( + set(row.temperature for row in sess.query(Report.temperature)), + {86.0, 75.0, 91.0}, + ) + + # test synchronize session as well + eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) + + def test_bulk_delete_synchronize_evaluate(self): + sess = self._fixture_data() + + temps = sess.query(Report).all() + eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + + sess.query(Report).filter(Report.temperature >= 80).delete( + synchronize_session="evaluate" + ) + + eq_( + set(row.temperature for row in sess.query(Report.temperature)), + {75.0}, + ) + + # test synchronize session as well + for t in temps: + assert inspect(t).deleted is (t.temperature >= 80) + + def test_bulk_delete_synchronize_fetch(self): sess = self._fixture_data() temps = sess.query(Report).all() eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) - sess.query(Report).filter(Report.temperature >= 80).delete() + sess.query(Report).filter(Report.temperature >= 80).delete( + synchronize_session="fetch" + ) eq_( set(row.temperature for row in sess.query(Report.temperature)), @@ -484,6 +532,118 @@ class ShardTest(object): for t in temps: assert inspect(t).deleted is (t.temperature >= 80) + def test_bulk_update_future_synchronize_evaluate(self): + sess = self._fixture_data() + + eq_( + set( + row.temperature + for row in sess.execute(future_select(Report.temperature)) + ), + {80.0, 75.0, 85.0}, + ) + + temps = sess.execute(future_select(Report)).scalars().all() + eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + + sess.execute( + update(Report) + .filter(Report.temperature >= 80) + .values({"temperature": Report.temperature + 6},) + .execution_options(synchronize_session="evaluate") + ) + + eq_( + set( + row.temperature + for row in sess.execute(future_select(Report.temperature)) + ), + {86.0, 75.0, 91.0}, + ) + + # test synchronize session as well + eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) + + def test_bulk_update_future_synchronize_fetch(self): + sess = self._fixture_data() + + eq_( + set( + row.temperature + for row in sess.execute(future_select(Report.temperature)) + ), + {80.0, 75.0, 85.0}, + ) + + temps = sess.execute(future_select(Report)).scalars().all() + eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + + sess.execute( + update(Report) + .filter(Report.temperature >= 80) + .values({"temperature": Report.temperature + 6},) + .execution_options(synchronize_session="fetch") + ) + + eq_( + set( + row.temperature + for row in sess.execute(future_select(Report.temperature)) + ), + {86.0, 75.0, 91.0}, + ) + + # test synchronize session as well + eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0}) + + def test_bulk_delete_future_synchronize_evaluate(self): + sess = self._fixture_data() + + temps = sess.execute(future_select(Report)).scalars().all() + eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + + sess.execute( + delete(Report) + .filter(Report.temperature >= 80) + .execution_options(synchronize_session="evaluate") + ) + + eq_( + set( + row.temperature + for row in sess.execute(future_select(Report.temperature)) + ), + {75.0}, + ) + + # test synchronize session as well + for t in temps: + assert inspect(t).deleted is (t.temperature >= 80) + + def test_bulk_delete_future_synchronize_fetch(self): + sess = self._fixture_data() + + temps = sess.execute(future_select(Report)).scalars().all() + eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0}) + + sess.execute( + delete(Report) + .filter(Report.temperature >= 80) + .execution_options(synchronize_session="fetch") + ) + + eq_( + set( + row.temperature + for row in sess.execute(future_select(Report.temperature)) + ), + {75.0}, + ) + + # test synchronize session as well + for t in temps: + assert inspect(t).deleted is (t.temperature >= 80) + class DistinctEngineShardTest(ShardTest, fixtures.TestBase): def _init_dbs(self): diff --git a/test/ext/test_hybrid.py b/test/ext/test_hybrid.py index 353c52c5c0..fbac35f7ee 100644 --- a/test/ext/test_hybrid.py +++ b/test/ext/test_hybrid.py @@ -9,9 +9,9 @@ from sqlalchemy import String from sqlalchemy.ext import hybrid from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import aliased -from sqlalchemy.orm import persistence from sqlalchemy.orm import relationship from sqlalchemy.orm import Session +from sqlalchemy.sql import update from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ @@ -588,15 +588,10 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_update_plain(self): Person = self.classes.Person - s = Session() - q = s.query(Person) - - bulk_ud = persistence.BulkUpdate.factory( - q, False, {Person.fname: "Dr."}, {} - ) + statement = update(Person).values({Person.fname: "Dr."}) self.assert_compile( - bulk_ud, + statement, "UPDATE person SET first_name=:first_name", params={"first_name": "Dr."}, ) @@ -604,15 +599,10 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL): def test_update_expr(self): Person = self.classes.Person - s = Session() - q = s.query(Person) - - bulk_ud = persistence.BulkUpdate.factory( - q, False, {Person.name: "Dr. No"}, {} - ) + statement = update(Person).values({Person.name: "Dr. No"}) self.assert_compile( - bulk_ud, + statement, "UPDATE person SET first_name=:first_name, last_name=:last_name", params={"first_name": "Dr.", "last_name": "No"}, ) diff --git a/test/orm/test_composites.py b/test/orm/test_composites.py index 0d679e6db5..da9783dfd1 100644 --- a/test/orm/test_composites.py +++ b/test/orm/test_composites.py @@ -4,12 +4,13 @@ from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import update +from sqlalchemy.future import select as future_select from sqlalchemy.orm import aliased from sqlalchemy.orm import composite from sqlalchemy.orm import CompositeProperty from sqlalchemy.orm import configure_mappers from sqlalchemy.orm import mapper -from sqlalchemy.orm import persistence from sqlalchemy.orm import relationship from sqlalchemy.orm import Session from sqlalchemy.testing import assert_raises_message @@ -231,17 +232,20 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess = self._fixture() - e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one() + e1 = sess.execute( + future_select(Edge).filter(Edge.start == Point(14, 5)) + ).scalar_one() eq_(e1.end, Point(2, 7)) - q = sess.query(Edge).filter(Edge.start == Point(14, 5)) - bulk_ud = persistence.BulkUpdate.factory( - q, False, {Edge.end: Point(16, 10)}, {} + stmt = ( + update(Edge) + .filter(Edge.start == Point(14, 5)) + .values({Edge.end: Point(16, 10)}) ) self.assert_compile( - bulk_ud, + stmt, "UPDATE edges SET x2=:x2, y2=:y2 WHERE edges.x1 = :x1_1 " "AND edges.y1 = :y1_1", params={"x2": 16, "x1_1": 14, "y2": 10, "y1_1": 5}, @@ -253,12 +257,18 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL): sess = self._fixture() - e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one() + e1 = sess.execute( + future_select(Edge).filter(Edge.start == Point(14, 5)) + ).scalar_one() eq_(e1.end, Point(2, 7)) - q = sess.query(Edge).filter(Edge.start == Point(14, 5)) - q.update({Edge.end: Point(16, 10)}) + stmt = ( + update(Edge) + .filter(Edge.start == Point(14, 5)) + .values({Edge.end: Point(16, 10)}) + ) + sess.execute(stmt) eq_(e1.end, Point(16, 10)) diff --git a/test/orm/test_core_compilation.py b/test/orm/test_core_compilation.py index a26d0ae267..a4f0841060 100644 --- a/test/orm/test_core_compilation.py +++ b/test/orm/test_core_compilation.py @@ -1633,6 +1633,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL): checkparams={"id_1": 5, "name": "ed"}, ) + self.assert_compile( + update(User).values({User.name: "ed"}).where(User.id == 5), + "UPDATE users SET name=:name WHERE users.id = :id_1", + checkparams={"id_1": 5, "name": "ed"}, + ) + def test_delete_from_entity(self): from sqlalchemy.sql import delete diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 299aba8099..2e19b94354 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -1684,7 +1684,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_(upd.session, sess) eq_( canary.after_bulk_update_legacy.mock_calls, - [call(sess, upd.query, upd.context, upd.result)], + [call(sess, upd.query, None, upd.result)], ) def test_on_bulk_delete_hook(self): @@ -1714,7 +1714,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): eq_(upd.session, sess) eq_( canary.after_bulk_delete_legacy.mock_calls, - [call(sess, upd.query, upd.context, upd.result)], + [call(sess, upd.query, None, upd.result)], ) diff --git a/test/orm/test_events.py b/test/orm/test_events.py index c1457289aa..b68e0d2e65 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -703,6 +703,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): s.refresh(a1) # joined eager load didn't continue eq_(len(a1.bs), 1) + s.close() @_combinations def test_flag_resolves_existing(self, target, event_name, fn): @@ -715,6 +716,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): s.expire(a1) event.listen(target, event_name, fn, restore_load_context=True) s.query(A).all() + s.close() @_combinations def test_flag_resolves(self, target, event_name, fn): @@ -728,6 +730,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest): s.refresh(a1) # joined eager load continued eq_(len(a1.bs), 3) + s.close() class DeclarativeEventListenTest( @@ -1768,6 +1771,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): upd = canary.after_bulk_update.mock_calls[0][1][0] eq_(upd.session, sess) + eq_(upd.result.rowcount, 0) def test_on_bulk_delete_hook(self): User, users = self.classes.User, self.tables.users @@ -1787,6 +1791,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest): upd = canary.after_bulk_delete.mock_calls[0][1][0] eq_(upd.session, sess) + eq_(upd.result.rowcount, 0) def test_connection_emits_after_begin(self): sess, canary = self._listener_fixture(bind=testing.db) diff --git a/test/orm/test_update_delete.py b/test/orm/test_update_delete.py index 9017ca84ee..5430fbffc5 100644 --- a/test/orm/test_update_delete.py +++ b/test/orm/test_update_delete.py @@ -1,6 +1,7 @@ from sqlalchemy import Boolean from sqlalchemy import case from sqlalchemy import column +from sqlalchemy import event from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import func @@ -10,6 +11,8 @@ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import update +from sqlalchemy.future import select as future_select from sqlalchemy.orm import backref from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper @@ -20,10 +23,8 @@ from sqlalchemy.testing import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import eq_ from sqlalchemy.testing import fixtures -from sqlalchemy.testing import mock from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import Table -from sqlalchemy.util import collections_abc class UpdateDeleteTest(fixtures.MappedTest): @@ -385,6 +386,58 @@ class UpdateDeleteTest(fixtures.MappedTest): list(zip([15, 27, 19, 27])), ) + def test_update_future(self): + User, users = self.classes.User, self.tables.users + + sess = Session() + + john, jack, jill, jane = ( + sess.execute(future_select(User).order_by(User.id)).scalars().all() + ) + + sess.execute( + update(User) + .where(User.age > 29) + .values({"age": User.age - 10}) + .execution_options(synchronize_session="evaluate"), + ) + + eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27]) + eq_( + sess.execute(future_select(User.age).order_by(User.id)).all(), + list(zip([25, 37, 29, 27])), + ) + + sess.execute( + update(User) + .where(User.age > 29) + .values({User.age: User.age - 10}) + .execution_options(synchronize_session="evaluate") + ) + eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 27, 29, 27])), + ) + + sess.query(User).filter(User.age > 27).update( + {users.c.age_int: User.age - 10}, synchronize_session="evaluate" + ) + eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([25, 27, 19, 27])), + ) + + sess.query(User).filter(User.age == 25).update( + {User.age: User.age - 10}, synchronize_session="fetch" + ) + eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27]) + eq_( + sess.query(User.age).order_by(User.id).all(), + list(zip([15, 27, 19, 27])), + ) + def test_update_against_table_col(self): User, users = self.classes.User, self.tables.users @@ -677,41 +730,111 @@ class UpdateDeleteTest(fixtures.MappedTest): # Do an update using unordered dict and check that the parameters used # are ordered in table order + + m1 = testing.mock.Mock() + + @event.listens_for(session, "after_bulk_update") + def do_orm_execute(bulk_ud): + m1(bulk_ud.result.context.compiled.compile_state.statement) + q = session.query(User) - with mock.patch.object(q, "_execute_crud") as exec_: - q.filter(User.id == 15).update({"name": "foob", "id": 123}) - # Confirm that parameters are a dict instead of tuple or list - params = exec_.mock_calls[0][1][0]._values - assert isinstance(params, collections_abc.Mapping) + q.filter(User.id == 15).update({"name": "foob", "age": 123}) + assert m1.mock_calls[0][1][0]._values - def test_update_preserve_parameter_order(self): + def test_update_preserve_parameter_order_query(self): User = self.classes.User session = Session() # Do update using a tuple and check that order is preserved - q = session.query(User) - with mock.patch.object(q, "_execute_crud") as exec_: - q.filter(User.id == 15).update( - (("id", 123), ("name", "foob")), - update_args={"preserve_parameter_order": True}, - ) + + m1 = testing.mock.Mock() + + @event.listens_for(session, "after_bulk_update") + def do_orm_execute(bulk_ud): + cols = [ - c.key for c, v in exec_.mock_calls[0][1][0]._ordered_values + c.key + for c, v in ( + ( + bulk_ud.result.context + ).compiled.compile_state.statement._ordered_values + ) ] - eq_(["id", "name"], cols) + m1(cols) - # Now invert the order and use a list instead, and check that order is - # also preserved q = session.query(User) - with mock.patch.object(q, "_execute_crud") as exec_: - q.filter(User.id == 15).update( - [("name", "foob"), ("id", 123)], - update_args={"preserve_parameter_order": True}, + q.filter(User.id == 15).update( + (("age", 123), ("name", "foob")), + update_args={"preserve_parameter_order": True}, + ) + + eq_(m1.mock_calls[0][1][0], ["age_int", "name"]) + + m1.mock_calls = [] + + q = session.query(User) + q.filter(User.id == 15).update( + [("name", "foob"), ("age", 123)], + update_args={"preserve_parameter_order": True}, + ) + eq_(m1.mock_calls[0][1][0], ["name", "age_int"]) + + def test_update_multi_values_error_future(self): + User = self.classes.User + session = Session() + + # Do update using a tuple and check that order is preserved + + stmt = ( + update(User) + .filter(User.id == 15) + .values([("id", 123), ("name", "foob")]) + ) + + assert_raises_message( + exc.InvalidRequestError, + "UPDATE construct does not support multiple parameter sets.", + session.execute, + stmt, + ) + + def test_update_preserve_parameter_order_future(self): + User = self.classes.User + session = Session() + + # Do update using a tuple and check that order is preserved + + stmt = ( + update(User) + .filter(User.id == 15) + .ordered_values(("age", 123), ("name", "foob")) + ) + result = session.execute(stmt) + cols = [ + c.key + for c, v in ( + ( + result.context + ).compiled.compile_state.statement._ordered_values ) - cols = [ - c.key for c, v in exec_.mock_calls[0][1][0]._ordered_values - ] - eq_(["name", "id"], cols) + ] + eq_(["age_int", "name"], cols) + + # Now invert the order and use a list instead, and check that order is + # also preserved + stmt = ( + update(User) + .filter(User.id == 15) + .ordered_values(("name", "foob"), ("age", 123),) + ) + result = session.execute(stmt) + cols = [ + c.key + for c, v in ( + result.context + ).compiled.compile_state.statement._ordered_values + ] + eq_(["name", "age_int"], cols) class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest): @@ -1103,16 +1226,23 @@ class ExpressionUpdateTest(fixtures.MappedTest): def test_update_args(self): Data = self.classes.Data - session = testing.mock.Mock(wraps=Session()) + session = Session() update_args = {"mysql_limit": 1} + m1 = testing.mock.Mock() + + @event.listens_for(session, "after_bulk_update") + def do_orm_execute(bulk_ud): + update_stmt = ( + bulk_ud.result.context.compiled.compile_state.statement + ) + m1(update_stmt) + q = session.query(Data) - with testing.mock.patch.object(q, "_execute_crud") as exec_: - q.update({Data.cnt: Data.cnt + 1}, update_args=update_args) - eq_(exec_.call_count, 1) - args, kwargs = exec_.mock_calls[0][1:3] - eq_(len(args), 2) - update_stmt = args[0] + q.update({Data.cnt: Data.cnt + 1}, update_args=update_args) + + update_stmt = m1.mock_calls[0][1][0] + eq_(update_stmt.dialect_kwargs, update_args) @@ -1163,18 +1293,22 @@ class InheritTest(fixtures.DeclarativeMappedTest): ) s.commit() - def test_illegal_metadata(self): + @testing.only_on("mysql", "Multi table update") + def test_update_from_join_no_problem(self): person = self.classes.Person.__table__ engineer = self.classes.Engineer.__table__ sess = Session() - assert_raises_message( - exc.InvalidRequestError, - "This operation requires only one Table or entity be " - "specified as the target.", - sess.query(person.join(engineer)).update, - {}, + sess.query(person.join(engineer)).filter(person.c.name == "e2").update( + {person.c.name: "updated", engineer.c.engineer_name: "e2a"}, ) + obj = sess.execute( + future_select(self.classes.Engineer).filter( + self.classes.Engineer.name == "updated" + ) + ).scalar() + eq_(obj.name, "updated") + eq_(obj.engineer_name, "e2a") def test_update_subtable_only(self): Engineer = self.classes.Engineer diff --git a/test/sql/test_compare.py b/test/sql/test_compare.py index 2d84ab6764..3f74bdbccc 100644 --- a/test/sql/test_compare.py +++ b/test/sql/test_compare.py @@ -87,6 +87,16 @@ table_a_2_bs = Table( table_b = Table("b", meta, Column("a", Integer), Column("b", Integer)) +table_b_b = Table( + "b_b", + meta, + Column("a", Integer), + Column("b", Integer), + Column("c", Integer), + Column("d", Integer), + Column("e", Integer), +) + table_c = Table("c", meta, Column("x", Integer), Column("y", Integer)) table_d = Table("d", meta, Column("y", Integer), Column("z", Integer)) @@ -711,6 +721,54 @@ class CoreFixtures(object): fixtures.append(_statements_w_anonymous_col_names) + def _update_dml_w_dicts(): + return ( + table_b_b.update().values( + { + table_b_b.c.a: 5, + table_b_b.c.b: 5, + table_b_b.c.c: 5, + table_b_b.c.d: 5, + } + ), + # equivalent, but testing dictionary insert ordering as cache key + # / compare + table_b_b.update().values( + { + table_b_b.c.a: 5, + table_b_b.c.c: 5, + table_b_b.c.b: 5, + table_b_b.c.d: 5, + } + ), + table_b_b.update().values( + {table_b_b.c.a: 5, table_b_b.c.b: 5, "c": 5, table_b_b.c.d: 5} + ), + table_b_b.update().values( + { + table_b_b.c.a: 5, + table_b_b.c.b: 5, + table_b_b.c.c: 5, + table_b_b.c.d: 5, + table_b_b.c.e: 10, + } + ), + table_b_b.update() + .values( + { + table_b_b.c.a: 5, + table_b_b.c.b: 5, + table_b_b.c.c: 5, + table_b_b.c.d: 5, + table_b_b.c.e: 10, + } + ) + .where(table_b_b.c.c > 10), + ) + + if util.py37: + fixtures.append(_update_dml_w_dicts) + class CacheKeyFixture(object): def _run_cache_key_fixture(self, fixture, compare_values):