Convert bulk update/delete to new execution model
Author:     Mike Bayer <mike_mp@zzzcomputing.com>
AuthorDate: Wed, 3 Jun 2020 21:38:35 +0000 (17:38 -0400)
Commit:     Mike Bayer <mike_mp@zzzcomputing.com>
CommitDate: Sat, 6 Jun 2020 17:31:54 +0000 (13:31 -0400)
This reorganizes the BulkUD model in sqlalchemy.orm.persistence
to be based on the CompileState concept, and allows plain
update() / delete() constructs to be passed to session.execute(),
where the ORM-level synchronize_session logic will take place.
Also gets "synchronize_session='fetch'" working with horizontal
sharding.

Also adds a few more result.scalar_one()-style methods, as
scalar_one() seems to be what is normally desired.

Fixes: #5160
Change-Id: I8001ebdad089da34119eb459709731ba6c0ba975
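
For orientation, a minimal sketch of the usage this commit enables
(the User mapping and the in-memory SQLite URL are illustrative
assumptions, not part of the change):

    from sqlalchemy import Column, Integer, String, create_engine, update
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.future import select
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session = Session(engine)
    session.add(User(id=1, name="sandy"))
    session.commit()

    u1 = session.query(User).get(1)

    # a plain update() construct passed to session.execute(); the ORM
    # applies the "evaluate" synchronize_session strategy, updating
    # matching in-memory objects rather than re-selecting them
    session.execute(
        update(User).where(User.id == 1).values(name="spongebob"),
        execution_options={"synchronize_session": "evaluate"},
    )
    print(u1.name)  # "spongebob", synchronized without a second SELECT

    # scalar_one(): expect exactly one row and return its first column
    print(session.execute(select(User.name)).scalar_one())  # "spongebob"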

31 files changed:
lib/sqlalchemy/engine/cursor.py
lib/sqlalchemy/engine/result.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/ext/hybrid.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/descriptor_props.py
lib/sqlalchemy/orm/events.py
lib/sqlalchemy/orm/mapper.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
lib/sqlalchemy/sql/base.py
lib/sqlalchemy/sql/coercions.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/dml.py
lib/sqlalchemy/sql/roles.py
lib/sqlalchemy/sql/selectable.py
lib/sqlalchemy/sql/traversals.py
lib/sqlalchemy/testing/assertions.py
lib/sqlalchemy/util/__init__.py
lib/sqlalchemy/util/compat.py
test/aaa_profiling/test_orm.py
test/base/test_result.py
test/ext/test_horizontal_shard.py
test/ext/test_hybrid.py
test/orm/test_composites.py
test/orm/test_core_compilation.py
test/orm/test_deprecations.py
test/orm/test_events.py
test/orm/test_update_delete.py
test/sql/test_compare.py

diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py
index 1d832e4afa537691b395acb2926d56e559899f23..d03d79df72bcdb530f230f56791bd7b3a6d5f134 100644
@@ -1630,6 +1630,15 @@ class CursorResult(BaseCursorResult, Result):
     def _raw_row_iterator(self):
         return self._fetchiter_impl()
 
+    def merge(self, *others):
+        merged_result = super(CursorResult, self).merge(*others)
+        setup_rowcounts = not self._metadata.returns_rows
+        if setup_rowcounts:
+            merged_result.rowcount = sum(
+                result.rowcount for result in (self,) + others
+            )
+        return merged_result
+
     def close(self):
         """Close this :class:`_engine.CursorResult`.
 
diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py
index 600229037bd209f74f6874d68300f58d18eb4a24..b29bc22d44fe6a1c264d4884669acfccbbb491de 100644
@@ -951,7 +951,7 @@ class Result(InPlaceGenerative):
         """
         return self._allrows()
 
-    def _only_one_row(self, raise_for_second_row, raise_for_none):
+    def _only_one_row(self, raise_for_second_row, raise_for_none, scalar):
         onerow = self._fetchone_impl
 
         row = onerow(hard_close=True)
@@ -1010,27 +1010,43 @@ class Result(InPlaceGenerative):
             # if we checked for second row then that would have
             # closed us :)
             self._soft_close(hard=True)
-        post_creational_filter = self._post_creational_filter
-        if post_creational_filter:
-            row = post_creational_filter(row)
 
-        return row
+        if not scalar:
+            post_creational_filter = self._post_creational_filter
+            if post_creational_filter:
+                row = post_creational_filter(row)
+
+        if scalar and row:
+            return row[0]
+        else:
+            return row
 
     def first(self):
         """Fetch the first row or None if no row is present.
 
         Closes the result set and discards remaining rows.
 
+        .. note::  This method returns one **row**, e.g. a tuple, by default.
+           To return exactly one single scalar value, that is, the first
+           column of the first row, use the :meth:`.Result.scalar` method,
+           or combine :meth:`.Result.scalars` and :meth:`.Result.first`.
+
         .. comment: A warning is emitted if additional rows remain.
 
         :return: a :class:`.Row` object if no filters are applied, or None
          if no rows remain.
          When filters are applied, such as :meth:`_engine.Result.mappings`
-         or :meth:`._engine.Result.scalar`, different kinds of objects
+         or :meth:`_engine.Result.scalars`, different kinds of objects
          may be returned.
 
+         .. seealso::
+
+            :meth:`_result.Result.scalar`
+
+            :meth:`_result.Result.one`
+
         """
-        return self._only_one_row(False, False)
+        return self._only_one_row(False, False, False)
 
     def one_or_none(self):
         """Return at most one result or raise an exception.
@@ -1055,15 +1071,50 @@ class Result(InPlaceGenerative):
             :meth:`_result.Result.one`
 
         """
-        return self._only_one_row(True, False)
+        return self._only_one_row(True, False, False)
+
+    def scalar_one(self):
+        """Return exactly one scalar result or raise an exception.
+
+        This is equivalent to calling :meth:`.Result.scalars` and then
+        :meth:`.Result.one`.
+
+        .. seealso::
+
+            :meth:`.Result.one`
+
+            :meth:`.Result.scalars`
+
+        """
+        return self._only_one_row(True, True, True)
+
+    def scalar_one_or_none(self):
+        """Return exactly one or no scalar result.
+
+        This is equivalent to calling :meth:`.Result.scalars` and then
+        :meth:`.Result.one_or_none`.
+
+        .. seealso::
+
+            :meth:`.Result.one_or_none`
+
+            :meth:`.Result.scalars`
+
+        """
+        return self._only_one_row(True, False, True)
 
     def one(self):
-        """Return exactly one result or raise an exception.
+        """Return exactly one row or raise an exception.
 
         Raises :class:`.NoResultFound` if the result returns no
         rows, or :class:`.MultipleResultsFound` if multiple rows
         would be returned.
 
+        .. note::  This method returns one **row**, e.g. a tuple, by default.
+           To return exactly one single scalar value, that is, the first
+           column of the first row, use the :meth:`.Result.scalar_one` method,
+           or combine :meth:`.Result.scalars` and :meth:`.Result.one`.
+
         .. versionadded:: 1.4
 
         :return: The first :class:`.Row`.
@@ -1079,24 +1130,26 @@ class Result(InPlaceGenerative):
 
             :meth:`_result.Result.one_or_none`
 
+            :meth:`_result.Result.scalar_one`
+
         """
-        return self._only_one_row(True, True)
+        return self._only_one_row(True, True, False)
 
     def scalar(self):
         """Fetch the first column of the first row, and close the result set.
 
+        Returns None if there are no rows to fetch.
+
+        No validation is performed to test if additional rows remain.
+
         After calling this method, the object is fully closed,
         e.g. the :meth:`_engine.CursorResult.close`
         method will have been called.
 
-        :return: a Python scalar value , or None if no rows remain
+        :return: a Python scalar value, or None if no rows remain.
 
         """
-        row = self.first()
-        if row is not None:
-            return row[0]
-        else:
-            return None
+        return self._only_one_row(False, False, True)
 
 
 class FrozenResult(object):
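
A short illustration of the distinction these docstrings draw, as a
sketch against an assumed in-memory SQLite database:

    from sqlalchemy import create_engine, text

    engine = create_engine("sqlite://")
    with engine.connect() as conn:
        # one() returns exactly one Row, a tuple-like object
        print(conn.execute(text("SELECT 42")).one())         # (42,)

        # scalar_one() returns the first column of that single row,
        # equivalent to combining scalars() and one()
        print(conn.execute(text("SELECT 42")).scalar_one())  # 42
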
diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py
index c3ac71c10366e9973fff15550ea38f8fe6795d24..0983807cb9664172fdc89e09bb81fd487c05f19a 100644
@@ -50,58 +50,6 @@ class ShardedQuery(Query):
         """
         return self.execution_options(_sa_shard_id=shard_id)
 
-    def _execute_crud(self, stmt, mapper):
-        def exec_for_shard(shard_id):
-            conn = self.session.connection(
-                mapper=mapper,
-                shard_id=shard_id,
-                clause=stmt,
-                close_with_result=True,
-            )
-            result = conn._execute_20(
-                stmt, self.load_options._params, self._execution_options
-            )
-            return result
-
-        if self._shard_id is not None:
-            return exec_for_shard(self._shard_id)
-        else:
-            rowcount = 0
-            results = []
-            # TODO: this will have to be the new object
-            for shard_id in self.execute_chooser(self):
-                result = exec_for_shard(shard_id)
-                rowcount += result.rowcount
-                results.append(result)
-
-            return ShardedResult(results, rowcount)
-
-
-class ShardedResult(object):
-    """A value object that represents multiple :class:`_engine.CursorResult`
-    objects.
-
-    This is used by the :meth:`.ShardedQuery._execute_crud` hook to return
-    an object that takes the place of the single :class:`_engine.CursorResult`.
-
-    Attribute include ``result_proxies``, which is a sequence of the
-    actual :class:`_engine.CursorResult` objects,
-    as well as ``aggregate_rowcount``
-    or ``rowcount``, which is the sum of all the individual rowcount values.
-
-    .. versionadded::  1.3
-    """
-
-    __slots__ = ("result_proxies", "aggregate_rowcount")
-
-    def __init__(self, result_proxies, aggregate_rowcount):
-        self.result_proxies = result_proxies
-        self.aggregate_rowcount = aggregate_rowcount
-
-    @property
-    def rowcount(self):
-        return self.aggregate_rowcount
-
 
 class ShardedSession(Session):
     def __init__(
@@ -259,37 +207,40 @@ class ShardedSession(Session):
 
 
 def execute_and_instances(orm_context):
-    if orm_context.bind_arguments.get("_horizontal_shard", False):
-        return None
-
     params = orm_context.parameters
 
-    load_options = orm_context.load_options
+    if orm_context.is_select:
+        load_options = active_options = orm_context.load_options
+        update_options = None
+        if params is None:
+            params = active_options._params
+
+    else:
+        load_options = None
+        update_options = active_options = orm_context.update_delete_options
+
     session = orm_context.session
     # orm_query = orm_context.orm_query
 
-    if params is None:
-        params = load_options._params
-
-    def iter_for_shard(shard_id, load_options):
+    def iter_for_shard(shard_id, load_options, update_options):
         execution_options = dict(orm_context.local_execution_options)
 
         bind_arguments = dict(orm_context.bind_arguments)
-        bind_arguments["_horizontal_shard"] = True
         bind_arguments["shard_id"] = shard_id
 
-        load_options += {"_refresh_identity_token": shard_id}
-        execution_options["_sa_orm_load_options"] = load_options
+        if orm_context.is_select:
+            load_options += {"_refresh_identity_token": shard_id}
+            execution_options["_sa_orm_load_options"] = load_options
+        else:
+            update_options += {"_refresh_identity_token": shard_id}
+            execution_options["_sa_orm_update_options"] = update_options
 
-        return session.execute(
-            orm_context.statement,
-            orm_context.parameters,
-            execution_options,
-            bind_arguments,
+        return orm_context.invoke_statement(
+            bind_arguments=bind_arguments, execution_options=execution_options
         )
 
-    if load_options._refresh_identity_token is not None:
-        shard_id = load_options._refresh_identity_token
+    if active_options._refresh_identity_token is not None:
+        shard_id = active_options._refresh_identity_token
     elif "_sa_shard_id" in orm_context.merged_execution_options:
         shard_id = orm_context.merged_execution_options["_sa_shard_id"]
     elif "shard_id" in orm_context.bind_arguments:
@@ -298,11 +249,11 @@ def execute_and_instances(orm_context):
         shard_id = None
 
     if shard_id is not None:
-        return iter_for_shard(shard_id, load_options)
+        return iter_for_shard(shard_id, load_options, update_options)
     else:
         partial = []
         for shard_id in session.execute_chooser(orm_context):
-            result_ = iter_for_shard(shard_id, load_options)
+            result_ = iter_for_shard(shard_id, load_options, update_options)
             partial.append(result_)
 
         return partial[0].merge(*partial[1:])
diff --git a/lib/sqlalchemy/ext/hybrid.py b/lib/sqlalchemy/ext/hybrid.py
index 9f73b5d31bc369dd0fcf460c653e61e9eaa63650..efd8d7d6b2a2b7642362deac0e2117873e296395 100644
@@ -777,7 +777,7 @@ things it can be used for.
 from .. import util
 from ..orm import attributes
 from ..orm import interfaces
-
+from ..sql import elements
 
 HYBRID_METHOD = util.symbol("HYBRID_METHOD")
 """Symbol indicating an :class:`InspectionAttr` that's
@@ -1144,6 +1144,9 @@ class ExprComparator(Comparator):
         return self.hybrid.info
 
     def _bulk_update_tuples(self, value):
+        if isinstance(value, elements.BindParameter):
+            value = value.value
+
         if isinstance(self.expression, attributes.QueryableAttribute):
             return self.expression._bulk_update_tuples(value)
         elif self.hybrid.update_expr is not None:
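
The BindParameter unwrap above exists because, under the new execution
model, a value passed to update().values() reaches the hybrid's
update-expression hook already coerced to a bound parameter.  A hedged
sketch using the canonical Interval hybrid example (the mapping is an
assumption for illustration):

    from sqlalchemy import Column, Integer
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.ext.hybrid import hybrid_property

    Base = declarative_base()

    class Interval(Base):
        __tablename__ = "interval"
        id = Column(Integer, primary_key=True)
        start = Column(Integer, nullable=False)
        end = Column(Integer, nullable=False)

        @hybrid_property
        def length(self):
            return self.end - self.start

        @length.update_expression
        def length(cls, value):
            # thanks to the unwrap above, "value" arrives here as the
            # plain Python value even when the statement coerced it
            # into a BindParameter
            return [(cls.end, cls.start + value)]

    # e.g. session.query(Interval).update({Interval.length: 10}) now
    # routes through update() and still expands via update_expression
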
diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py
index bd4074ea1115609a662a6c7f7cff657587d6a4f7..a16db66f6d28e871445b2aab05ce39317ec6e75f 100644
@@ -189,7 +189,7 @@ class ORMCompileState(CompileState):
 
     @classmethod
     def orm_pre_session_exec(
-        cls, session, statement, execution_options, bind_arguments
+        cls, session, statement, params, execution_options, bind_arguments
     ):
         load_options = execution_options.get(
             "_sa_orm_load_options", QueryContext.default_load_options
@@ -216,6 +216,8 @@ class ORMCompileState(CompileState):
         if load_options._autoflush:
             session._autoflush()
 
+        return execution_options
+
     @classmethod
     def orm_setup_cursor_result(
         cls, session, statement, execution_options, bind_arguments, result
diff --git a/lib/sqlalchemy/orm/descriptor_props.py b/lib/sqlalchemy/orm/descriptor_props.py
index 6be4f0dff808abc85580dedb1250e9a6ae237bf1..027f2521b164b5fafc56f8b9c1942b831ef06b1b 100644
@@ -420,6 +420,9 @@ class CompositeProperty(DescriptorProperty):
             return CompositeProperty.CompositeBundle(self.prop, clauses)
 
         def _bulk_update_tuples(self, value):
+            if isinstance(value, sql.elements.BindParameter):
+                value = value.value
+
             if value is None:
                 values = [None for key in self.prop._attribute_keys]
             elif isinstance(value, self.prop.composite_class):
diff --git a/lib/sqlalchemy/orm/events.py b/lib/sqlalchemy/orm/events.py
index be7aa272ea46e2797f7fe568ea60de9581401379..217aa76c75f8c700bc5300ef51a0623119234ff8 100644
@@ -1764,7 +1764,7 @@ class SessionEvents(event.Events):
         lambda update_context: (
             update_context.session,
             update_context.query,
-            update_context.context,
+            None,
             update_context.result,
         ),
     )
@@ -1782,12 +1782,13 @@ class SessionEvents(event.Events):
               was called upon.
             * ``values`` The "values" dictionary that was passed to
               :meth:`_query.Query.update`.
-            * ``context`` The :class:`.QueryContext` object, corresponding
-              to the invocation of an ORM query.
             * ``result`` the :class:`_engine.CursorResult`
               returned as a result of the
               bulk UPDATE operation.
 
+        .. versionchanged:: 1.4 the update_context no longer has a
+           ``QueryContext`` object associated with it.
+
         .. seealso::
 
             :meth:`.QueryEvents.before_compile_update`
@@ -1802,7 +1803,7 @@ class SessionEvents(event.Events):
         lambda delete_context: (
             delete_context.session,
             delete_context.query,
-            delete_context.context,
+            None,
             delete_context.result,
         ),
     )
@@ -1818,12 +1819,13 @@ class SessionEvents(event.Events):
            * ``query`` - the :class:`_query.Query`
               object that this update operation
               was called upon.
-            * ``context`` The :class:`.QueryContext` object, corresponding
-              to the invocation of an ORM query.
             * ``result`` the :class:`_engine.CursorResult`
               returned as a result of the
               bulk DELETE operation.
 
+        .. versionchanged:: 1.4 the delete_context no longer has a
+           ``QueryContext`` object associated with it.
+
         .. seealso::
 
             :meth:`.QueryEvents.before_compile_delete`
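
For listeners, the practical effect is that the legacy third "context"
argument is now always None; the modern single-argument form carries the
session, query, and result.  A sketch of a listener (assuming the 1.4
event API):

    from sqlalchemy import event
    from sqlalchemy.orm import Session

    @event.listens_for(Session, "after_bulk_update")
    def receive_after_bulk_update(update_context):
        # update_context no longer carries a QueryContext; use the
        # CursorResult at update_context.result instead
        print("rows matched:", update_context.result.rowcount)
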
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py
index 4166e6d2a97e2880151e4854062ddfe03d0fee1a..c4cb89c0382b8ff5827888e116234aff11cb84dc 100644
@@ -2235,14 +2235,28 @@ class Mapper(
 
     @HasMemoized.memoized_instancemethod
     def __clause_element__(self):
-        return self.selectable._annotate(
-            {
-                "entity_namespace": self,
-                "parententity": self,
-                "parentmapper": self,
-                "compile_state_plugin": "orm",
-            }
-        )._set_propagate_attrs(
+
+        annotations = {
+            "entity_namespace": self,
+            "parententity": self,
+            "parentmapper": self,
+            "compile_state_plugin": "orm",
+        }
+        if self.persist_selectable is not self.local_table:
+            # joined table inheritance, with polymorphic selectable,
+            # etc.
+            annotations["dml_table"] = self.local_table._annotate(
+                {
+                    "entity_namespace": self,
+                    "parententity": self,
+                    "parentmapper": self,
+                    "compile_state_plugin": "orm",
+                }
+            )._set_propagate_attrs(
+                {"compile_state_plugin": "orm", "plugin_subject": self}
+            )
+
+        return self.selectable._annotate(annotations)._set_propagate_attrs(
             {"compile_state_plugin": "orm", "plugin_subject": self}
         )
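
The new "dml_table" annotation matters for joined-table inheritance:
update(Engineer) should target the mapper's local_table rather than the
JOIN it selects from, and the DMLTableImpl coercion later in this diff
extracts it.  A hedged sketch (the Person/Engineer mapping is an
illustrative assumption, and the rendered SQL is approximate):

    from sqlalchemy import Column, ForeignKey, Integer, String, update
    from sqlalchemy.ext.declarative import declarative_base

    Base = declarative_base()

    class Person(Base):
        __tablename__ = "person"
        id = Column(Integer, primary_key=True)
        type = Column(String(20))
        __mapper_args__ = {
            "polymorphic_on": type,
            "polymorphic_identity": "person",
        }

    class Engineer(Person):
        __tablename__ = "engineer"
        id = Column(Integer, ForeignKey("person.id"), primary_key=True)
        engineer_name = Column(String(50))
        __mapper_args__ = {"polymorphic_identity": "engineer"}

    # the DML targets the "engineer" table, not the person JOIN engineer
    # selectable that the Engineer mapper normally selects from
    stmt = update(Engineer).values(engineer_name="e2")
    print(stmt)  # roughly: UPDATE engineer SET engineer_name=:engineer_name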
 
diff --git a/lib/sqlalchemy/orm/persistence.py b/lib/sqlalchemy/orm/persistence.py
index 163ebf22a5925536d969dc22ac0c6e235feb2016..19d43d354dd6db049c92fdb26e89fa337c65eba7 100644
@@ -28,11 +28,15 @@ from .. import exc as sa_exc
 from .. import future
 from .. import sql
 from .. import util
+from ..future import select as future_select
 from ..sql import coercions
 from ..sql import expression
 from ..sql import operators
 from ..sql import roles
-from ..sql.base import _from_objects
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.dml import DeleteDMLState
+from ..sql.dml import UpdateDMLState
 from ..sql.elements import BooleanClauseList
 
 
@@ -1650,243 +1654,193 @@ def _sort_states(mapper, states):
     )
 
 
-class BulkUD(object):
-    """Handle bulk update and deletes via a :class:`_query.Query`."""
+_EMPTY_DICT = util.immutabledict()
 
-    def __init__(self, query):
-        self.query = query.enable_eagerloads(False)
-        self._validate_query_state()
 
-    def _validate_query_state(self):
-        for attr, methname, notset, op in (
-            ("_limit_clause", "limit()", None, operator.is_),
-            ("_offset_clause", "offset()", None, operator.is_),
-            ("_order_by_clauses", "order_by()", (), operator.eq),
-            ("_group_by_clauses", "group_by()", (), operator.eq),
-            ("_distinct", "distinct()", False, operator.is_),
-            (
-                "_from_obj",
-                "join(), outerjoin(), select_from(), or from_self()",
-                (),
-                operator.eq,
-            ),
-            (
-                "_legacy_setup_joins",
-                "join(), outerjoin(), select_from(), or from_self()",
-                (),
-                operator.eq,
-            ),
-        ):
-            if not op(getattr(self.query, attr), notset):
-                raise sa_exc.InvalidRequestError(
-                    "Can't call Query.update() or Query.delete() "
-                    "when %s has been called" % (methname,)
-                )
-
-    @property
-    def session(self):
-        return self.query.session
+class BulkUDCompileState(CompileState):
+    class default_update_options(Options):
+        _synchronize_session = "evaluate"
+        _autoflush = True
+        _subject_mapper = None
+        _resolved_values = _EMPTY_DICT
+        _resolved_keys_as_propnames = _EMPTY_DICT
+        _value_evaluators = _EMPTY_DICT
+        _matched_objects = None
+        _matched_rows = None
+        _refresh_identity_token = None
 
     @classmethod
-    def _factory(cls, lookup, synchronize_session, *arg):
-        try:
-            klass = lookup[synchronize_session]
-        except KeyError as err:
-            util.raise_(
-                sa_exc.ArgumentError(
-                    "Valid strategies for session synchronization "
-                    "are %s" % (", ".join(sorted(repr(x) for x in lookup)))
-                ),
-                replace_context=err,
+    def orm_pre_session_exec(
+        cls, session, statement, params, execution_options, bind_arguments
+    ):
+        sync = execution_options.get("synchronize_session", None)
+        if sync is None:
+            sync = statement._execution_options.get(
+                "synchronize_session", None
             )
-        else:
-            return klass(*arg)
-
-    def exec_(self):
-        self._do_before_compile()
-        self._do_pre()
-        self._do_pre_synchronize()
-        self._do_exec()
-        self._do_post_synchronize()
-        self._do_post()
-
-    def _execute_stmt(self, stmt):
-        self.result = self.query._execute_crud(stmt, self.mapper)
-        self.rowcount = self.result.rowcount
-
-    def _do_before_compile(self):
-        raise NotImplementedError()
 
-    @util.preload_module("sqlalchemy.orm.context")
-    def _do_pre(self):
-        query_context = util.preloaded.orm_context
-        query = self.query
-
-        self.compile_state = (
-            self.context
-        ) = compile_state = query._compile_state()
-
-        self.mapper = compile_state._entity_zero()
-
-        if isinstance(
-            compile_state._entities[0], query_context._RawColumnEntity,
-        ):
-            # check for special case of query(table)
-            tables = set()
-            for ent in compile_state._entities:
-                if not isinstance(ent, query_context._RawColumnEntity,):
-                    tables.clear()
-                    break
-                else:
-                    tables.update(_from_objects(ent.column))
+        update_options = execution_options.get(
+            "_sa_orm_update_options",
+            BulkUDCompileState.default_update_options,
+        )
 
-            if len(tables) != 1:
-                raise sa_exc.InvalidRequestError(
-                    "This operation requires only one Table or "
-                    "entity be specified as the target."
+        if sync is not None:
+            if sync not in ("evaluate", "fetch", False):
+                raise sa_exc.ArgumentError(
+                    "Valid strategies for session synchronization "
+                    "are 'evaluate', 'fetch', False"
                 )
-            else:
-                self.primary_table = tables.pop()
+            update_options += {"_synchronize_session": sync}
 
+        bind_arguments["clause"] = statement
+        try:
+            plugin_subject = statement._propagate_attrs["plugin_subject"]
+        except KeyError:
+            assert False, "statement had 'orm' plugin but no plugin_subject"
         else:
-            self.primary_table = compile_state._only_entity_zero(
-                "This operation requires only one Table or "
-                "entity be specified as the target."
-            ).mapper.local_table
+            bind_arguments["mapper"] = plugin_subject.mapper
 
-        session = query.session
+        update_options += {"_subject_mapper": plugin_subject.mapper}
 
-        if query.load_options._autoflush:
+        if update_options._autoflush:
             session._autoflush()
 
-    def _do_pre_synchronize(self):
-        pass
+        if update_options._synchronize_session == "evaluate":
+            update_options = cls._do_pre_synchronize_evaluate(
+                session,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                update_options,
+            )
+        elif update_options._synchronize_session == "fetch":
+            update_options = cls._do_pre_synchronize_fetch(
+                session,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                update_options,
+            )
 
-    def _do_post_synchronize(self):
-        pass
+        return util.immutabledict(execution_options).union(
+            dict(_sa_orm_update_options=update_options)
+        )
 
+    @classmethod
+    def orm_setup_cursor_result(
+        cls, session, statement, execution_options, bind_arguments, result
+    ):
+        update_options = execution_options["_sa_orm_update_options"]
+        if update_options._synchronize_session == "evaluate":
+            cls._do_post_synchronize_evaluate(session, update_options)
+        elif update_options._synchronize_session == "fetch":
+            cls._do_post_synchronize_fetch(session, update_options)
 
-class BulkEvaluate(BulkUD):
-    """BulkUD which does the 'evaluate' method of session state resolution."""
+        return result
 
-    def _additional_evaluators(self, evaluator_compiler):
-        pass
+    @classmethod
+    def _do_pre_synchronize_evaluate(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        update_options,
+    ):
+        mapper = update_options._subject_mapper
+        target_cls = mapper.class_
 
-    def _do_pre_synchronize(self):
-        query = self.query
-        target_cls = self.compile_state._mapper_zero().class_
+        value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
 
         try:
             evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
-            if query._where_criteria:
+            if statement._where_criteria:
                 eval_condition = evaluator_compiler.process(
-                    *query._where_criteria
+                    *statement._where_criteria
                 )
             else:
 
                 def eval_condition(obj):
                     return True
 
-            self._additional_evaluators(evaluator_compiler)
+            # TODO: something more robust for this conditional
+            if statement.__visit_name__ == "update":
+                resolved_values = cls._get_resolved_values(mapper, statement)
+                value_evaluators = {}
+                resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+                    mapper, resolved_values
+                )
+                for key, value in resolved_keys_as_propnames:
+                    value_evaluators[key] = evaluator_compiler.process(
+                        coercions.expect(roles.ExpressionElementRole, value)
+                    )
 
         except evaluator.UnevaluatableError as err:
             util.raise_(
                 sa_exc.InvalidRequestError(
                     'Could not evaluate current criteria in Python: "%s". '
                     "Specify 'fetch' or False for the "
-                    "synchronize_session parameter." % err
+                    "synchronize_session execution option." % err
                 ),
                 from_=err,
             )
 
         # TODO: detect when the where clause is a trivial primary key match
-        self.matched_objects = [
+        matched_objects = [
             obj
-            for (
-                cls,
-                pk,
-                identity_token,
-            ), obj in query.session.identity_map.items()
-            if issubclass(cls, target_cls) and eval_condition(obj)
+            for (cls, pk, identity_token,), obj in session.identity_map.items()
+            if issubclass(cls, target_cls)
+            and eval_condition(obj)
+            and identity_token == update_options._refresh_identity_token
         ]
-
-
-class BulkFetch(BulkUD):
-    """BulkUD which does the 'fetch' method of session state resolution."""
-
-    def _do_pre_synchronize(self):
-        query = self.query
-        session = query.session
-        select_stmt = self.compile_state.statement.with_only_columns(
-            self.primary_table.primary_key
-        )
-        self.matched_rows = session.execute(
-            select_stmt, mapper=self.mapper, params=query.load_options._params
-        ).fetchall()
-
-
-class BulkUpdate(BulkUD):
-    """BulkUD which handles UPDATEs."""
-
-    def __init__(self, query, values, update_kwargs):
-        super(BulkUpdate, self).__init__(query)
-        self.values = values
-        self.update_kwargs = update_kwargs
+        return update_options + {
+            "_matched_objects": matched_objects,
+            "_value_evaluators": value_evaluators,
+            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+        }
 
     @classmethod
-    def factory(cls, query, synchronize_session, values, update_kwargs):
-        return BulkUD._factory(
-            {
-                "evaluate": BulkUpdateEvaluate,
-                "fetch": BulkUpdateFetch,
-                False: BulkUpdate,
-            },
-            synchronize_session,
-            query,
-            values,
-            update_kwargs,
-        )
-
-    def _do_before_compile(self):
-        if self.query.dispatch.before_compile_update:
-            for fn in self.query.dispatch.before_compile_update:
-                new_query = fn(self.query, self)
-                if new_query is not None:
-                    self.query = new_query
+    def _get_resolved_values(cls, mapper, statement):
+        if statement._multi_values:
+            return []
+        elif statement._ordered_values:
+            iterator = statement._ordered_values
+        elif statement._values:
+            iterator = statement._values.items()
+        else:
+            return []
 
-    @property
-    def _resolved_values(self):
         values = []
-        for k, v in (
-            self.values.items()
-            if hasattr(self.values, "items")
-            else self.values
-        ):
-            if self.mapper:
-                if isinstance(k, util.string_types):
-                    desc = sql.util._entity_namespace_key(self.mapper, k)
-                    values.extend(desc._bulk_update_tuples(v))
-                elif isinstance(k, attributes.QueryableAttribute):
-                    values.extend(k._bulk_update_tuples(v))
+        if iterator:
+            for k, v in iterator:
+                if mapper:
+                    if isinstance(k, util.string_types):
+                        desc = sql.util._entity_namespace_key(mapper, k)
+                        values.extend(desc._bulk_update_tuples(v))
+                    elif isinstance(k, attributes.QueryableAttribute):
+                        values.extend(k._bulk_update_tuples(v))
+                    else:
+                        values.append((k, v))
                 else:
                     values.append((k, v))
-            else:
-                values.append((k, v))
         return values
 
-    @property
-    def _resolved_values_keys_as_propnames(self):
+    @classmethod
+    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
         values = []
-        for k, v in self._resolved_values:
+        for k, v in resolved_values:
             if isinstance(k, attributes.QueryableAttribute):
                 values.append((k.key, v))
                 continue
             elif hasattr(k, "__clause_element__"):
                 k = k.__clause_element__()
 
-            if self.mapper and isinstance(k, expression.ColumnElement):
+            if mapper and isinstance(k, expression.ColumnElement):
                 try:
-                    attr = self.mapper._columntoproperty[k]
+                    attr = mapper._columntoproperty[k]
                 except orm_exc.UnmappedColumnError:
                     pass
                 else:
@@ -1897,87 +1851,99 @@ class BulkUpdate(BulkUD):
                 )
         return values
 
-    def _do_exec(self):
-        values = self._resolved_values
+    @classmethod
+    def _do_pre_synchronize_fetch(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        update_options,
+    ):
+        mapper = update_options._subject_mapper
 
-        if not self.update_kwargs.get("preserve_parameter_order", False):
-            values = dict(values)
+        if mapper:
+            primary_table = mapper.local_table
+        else:
+            primary_table = statement._raw_columns[0]
 
-        update_stmt = sql.update(
-            self.primary_table, **self.update_kwargs
-        ).values(values)
+        # note this creates a Select() *without* the ORM plugin.
+        # we don't want that here.
+        select_stmt = future_select(*primary_table.primary_key)
+        select_stmt._where_criteria = statement._where_criteria
 
-        update_stmt._where_criteria = self.compile_state._where_criteria
+        matched_rows = session.execute(
+            select_stmt, params, execution_options, bind_arguments
+        ).fetchall()
 
-        self._execute_stmt(update_stmt)
+        if statement.__visit_name__ == "update":
+            resolved_values = cls._get_resolved_values(mapper, statement)
+            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+                mapper, resolved_values
+            )
+        else:
+            resolved_keys_as_propnames = _EMPTY_DICT
 
-    def _do_post(self):
-        session = self.query.session
-        session.dispatch.after_bulk_update(self)
+        return update_options + {
+            "_matched_rows": matched_rows,
+            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+        }
 
 
-class BulkDelete(BulkUD):
-    """BulkUD which handles DELETEs."""
+@CompileState.plugin_for("orm", "update")
+class BulkORMUpdate(UpdateDMLState, BulkUDCompileState):
+    @classmethod
+    def create_for_statement(cls, statement, compiler, **kw):
 
-    def __init__(self, query):
-        super(BulkDelete, self).__init__(query)
+        self = cls.__new__(cls)
 
-    @classmethod
-    def factory(cls, query, synchronize_session):
-        return BulkUD._factory(
-            {
-                "evaluate": BulkDeleteEvaluate,
-                "fetch": BulkDeleteFetch,
-                False: BulkDelete,
-            },
-            synchronize_session,
-            query,
+        self.mapper = mapper = statement.table._annotations.get(
+            "parentmapper", None
         )
 
-    def _do_before_compile(self):
-        if self.query.dispatch.before_compile_delete:
-            for fn in self.query.dispatch.before_compile_delete:
-                new_query = fn(self.query, self)
-                if new_query is not None:
-                    self.query = new_query
+        self._resolved_values = cls._get_resolved_values(mapper, statement)
 
-    def _do_exec(self):
-        delete_stmt = sql.delete(self.primary_table,)
-        delete_stmt._where_criteria = self.compile_state._where_criteria
+        if not statement._preserve_parameter_order and statement._values:
+            self._resolved_values = dict(self._resolved_values)
 
-        self._execute_stmt(delete_stmt)
+        new_stmt = sql.Update.__new__(sql.Update)
+        new_stmt.__dict__.update(statement.__dict__)
+        new_stmt.table = mapper.local_table
 
-    def _do_post(self):
-        session = self.query.session
-        session.dispatch.after_bulk_delete(self)
+        # note if the statement has _multi_values, these
+        # are passed through to the new statement, which will then raise
+        # InvalidRequestError because UPDATE doesn't support multi_values
+        # right now.
+        if statement._ordered_values:
+            new_stmt._ordered_values = self._resolved_values
+        elif statement._values:
+            new_stmt._values = self._resolved_values
 
+        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
 
-class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
-    """BulkUD which handles UPDATEs using the "evaluate"
-    method of session resolution."""
+        return self
 
-    def _additional_evaluators(self, evaluator_compiler):
-        self.value_evaluators = {}
-        values = self._resolved_values_keys_as_propnames
-        for key, value in values:
-            self.value_evaluators[key] = evaluator_compiler.process(
-                coercions.expect(roles.ExpressionElementRole, value)
-            )
+    @classmethod
+    def _do_post_synchronize_evaluate(cls, session, update_options):
 
-    def _do_post_synchronize(self):
-        session = self.query.session
         states = set()
-        evaluated_keys = list(self.value_evaluators.keys())
-        for obj in self.matched_objects:
+        evaluated_keys = list(update_options._value_evaluators.keys())
+        for obj in update_options._matched_objects:
+
             state, dict_ = (
                 attributes.instance_state(obj),
                 attributes.instance_dict(obj),
             )
 
+            assert (
+                state.identity_token == update_options._refresh_identity_token
+            )
+
             # only evaluate unmodified attributes
             to_evaluate = state.unmodified.intersection(evaluated_keys)
             for key in to_evaluate:
-                dict_[key] = self.value_evaluators[key](obj)
+                dict_[key] = update_options._value_evaluators[key](obj)
 
             state.manager.dispatch.refresh(state, None, to_evaluate)
 
@@ -1991,39 +1957,25 @@ class BulkUpdateEvaluate(BulkEvaluate, BulkUpdate):
             states.add(state)
         session._register_altered(states)
 
-
-class BulkDeleteEvaluate(BulkEvaluate, BulkDelete):
-    """BulkUD which handles DELETEs using the "evaluate"
-    method of session resolution."""
-
-    def _do_post_synchronize(self):
-        self.query.session._remove_newly_deleted(
-            [attributes.instance_state(obj) for obj in self.matched_objects]
-        )
-
-
-class BulkUpdateFetch(BulkFetch, BulkUpdate):
-    """BulkUD which handles UPDATEs using the "fetch"
-    method of session resolution."""
-
-    def _do_post_synchronize(self):
-        session = self.query.session
-        target_mapper = self.compile_state._mapper_zero()
+    @classmethod
+    def _do_post_synchronize_fetch(cls, session, update_options):
+        target_mapper = update_options._subject_mapper
 
         states = set(
             [
                 attributes.instance_state(session.identity_map[identity_key])
                 for identity_key in [
                     target_mapper.identity_key_from_primary_key(
-                        list(primary_key)
+                        list(primary_key),
+                        identity_token=update_options._refresh_identity_token,
                     )
-                    for primary_key in self.matched_rows
+                    for primary_key in update_options._matched_rows
                 ]
                 if identity_key in session.identity_map
             ]
         )
 
-        values = self._resolved_values_keys_as_propnames
+        values = update_options._resolved_keys_as_propnames
         attrib = set(k for k, v in values)
         for state in states:
             to_expire = attrib.intersection(state.dict)
@@ -2032,18 +1984,38 @@ class BulkUpdateFetch(BulkFetch, BulkUpdate):
         session._register_altered(states)
 
 
-class BulkDeleteFetch(BulkFetch, BulkDelete):
-    """BulkUD which handles DELETEs using the "fetch"
-    method of session resolution."""
+@CompileState.plugin_for("orm", "delete")
+class BulkORMDelete(DeleteDMLState, BulkUDCompileState):
+    @classmethod
+    def create_for_statement(cls, statement, compiler, **kw):
+        self = cls.__new__(cls)
+
+        self.mapper = statement.table._annotations.get("parentmapper", None)
+
+        DeleteDMLState.__init__(self, statement, compiler, **kw)
+
+        return self
+
+    @classmethod
+    def _do_post_synchronize_evaluate(cls, session, update_options):
+
+        session._remove_newly_deleted(
+            [
+                attributes.instance_state(obj)
+                for obj in update_options._matched_objects
+            ]
+        )
+
+    @classmethod
+    def _do_post_synchronize_fetch(cls, session, update_options):
+        target_mapper = update_options._subject_mapper
 
-    def _do_post_synchronize(self):
-        session = self.query.session
-        target_mapper = self.compile_state._mapper_zero()
-        for primary_key in self.matched_rows:
+        for primary_key in update_options._matched_rows:
             # TODO: inline this and call remove_newly_deleted
             # once
             identity_key = target_mapper.identity_key_from_primary_key(
-                list(primary_key)
+                list(primary_key),
+                identity_token=update_options._refresh_identity_token,
             )
             if identity_key in session.identity_map:
                 session._remove_newly_deleted(
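
The "fetch" strategy is likewise requested as an execution option; as
implemented above, it first SELECTs the matching primary keys, then
emits the DML, then synchronizes the matched in-session objects.  A
hedged sketch (the User mapping is an illustrative assumption):

    from sqlalchemy import Column, Integer, String, create_engine, delete
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)
    session.add_all([User(id=1, name="a"), User(id=2, name="b")])
    session.commit()

    u2 = session.query(User).get(2)

    session.execute(
        delete(User).where(User.id == 2),
        execution_options={"synchronize_session": "fetch"},
    )
    print(u2 in session)  # False; the matched object was discarded
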
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 5137f9b1d48a2dd502720476d9edce9dbd123926..284ea9d72bc29db57636e490ddf7901d362d5386 100644
@@ -19,12 +19,12 @@ database to return iterable result sets.
 
 """
 import itertools
+import operator
 
 from . import attributes
 from . import exc as orm_exc
 from . import interfaces
 from . import loading
-from . import persistence
 from .base import _assertions
 from .context import _column_descriptions
 from .context import _legacy_determine_last_joined_entity
@@ -2825,15 +2825,6 @@ class Query(
 
         return result
 
-    def _execute_crud(self, stmt, mapper):
-        conn = self.session.connection(
-            mapper=mapper, clause=stmt, close_with_result=True
-        )
-
-        return conn._execute_20(
-            stmt, self.load_options._params, self._execution_options
-        )
-
     def __str__(self):
         statement = self._statement_20()
 
@@ -3178,9 +3169,27 @@ class Query(
 
         """
 
-        delete_op = persistence.BulkDelete.factory(self, synchronize_session)
-        delete_op.exec_()
-        return delete_op.rowcount
+        bulk_del = BulkDelete(self,)
+        if self.dispatch.before_compile_delete:
+            for fn in self.dispatch.before_compile_delete:
+                new_query = fn(bulk_del.query, bulk_del)
+                if new_query is not None:
+                    bulk_del.query = new_query
+
+                self = bulk_del.query
+
+        delete_ = sql.delete(*self._raw_columns)
+        delete_._where_criteria = self._where_criteria
+        result = self.session.execute(
+            delete_,
+            self.load_options._params,
+            execution_options={"synchronize_session": synchronize_session},
+        )
+        bulk_del.result = result
+        self.session.dispatch.after_bulk_delete(bulk_del)
+        result.close()
+
+        return result.rowcount
 
     def update(self, values, synchronize_session="evaluate", update_args=None):
         r"""Perform a bulk update query.
@@ -3313,11 +3322,27 @@ class Query(
         """
 
         update_args = update_args or {}
-        update_op = persistence.BulkUpdate.factory(
-            self, synchronize_session, values, update_args
+
+        bulk_ud = BulkUpdate(self, values, update_args)
+
+        if self.dispatch.before_compile_update:
+            for fn in self.dispatch.before_compile_update:
+                new_query = fn(bulk_ud.query, bulk_ud)
+                if new_query is not None:
+                    bulk_ud.query = new_query
+            self = bulk_ud.query
+
+        upd = sql.update(*self._raw_columns, **update_args).values(values)
+        upd._where_criteria = self._where_criteria
+        result = self.session.execute(
+            upd,
+            self.load_options._params,
+            execution_options={"synchronize_session": synchronize_session},
         )
-        update_op.exec_()
-        return update_op.rowcount
+        bulk_ud.result = result
+        self.session.dispatch.after_bulk_update(bulk_ud)
+        result.close()
+        return result.rowcount
 
     def _compile_state(self, for_statement=False, **kw):
         """Create an out-of-compiler ORMCompileState object.
@@ -3427,3 +3452,59 @@ class AliasOption(interfaces.LoaderOption):
 
     def process_compile_state(self, compile_state):
         pass
+
+
+class BulkUD(object):
+    """State used for the orm.Query version of update() / delete().
+
+    This object is now specific to Query only.
+
+    """
+
+    def __init__(self, query):
+        self.query = query.enable_eagerloads(False)
+        self._validate_query_state()
+        self.mapper = self.query._entity_from_pre_ent_zero()
+
+    def _validate_query_state(self):
+        for attr, methname, notset, op in (
+            ("_limit_clause", "limit()", None, operator.is_),
+            ("_offset_clause", "offset()", None, operator.is_),
+            ("_order_by_clauses", "order_by()", (), operator.eq),
+            ("_group_by_clauses", "group_by()", (), operator.eq),
+            ("_distinct", "distinct()", False, operator.is_),
+            (
+                "_from_obj",
+                "join(), outerjoin(), select_from(), or from_self()",
+                (),
+                operator.eq,
+            ),
+            (
+                "_legacy_setup_joins",
+                "join(), outerjoin(), select_from(), or from_self()",
+                (),
+                operator.eq,
+            ),
+        ):
+            if not op(getattr(self.query, attr), notset):
+                raise sa_exc.InvalidRequestError(
+                    "Can't call Query.update() or Query.delete() "
+                    "when %s has been called" % (methname,)
+                )
+
+    @property
+    def session(self):
+        return self.query.session
+
+
+class BulkUpdate(BulkUD):
+    """BulkUD which handles UPDATEs."""
+
+    def __init__(self, query, values, update_kwargs):
+        super(BulkUpdate, self).__init__(query)
+        self.values = values
+        self.update_kwargs = update_kwargs
+
+
+class BulkDelete(BulkUD):
+    """BulkUD which handles DELETEs."""
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index ee42419a261d0759214527dc9ed27899af2cfa02..5ad8bcf2f2ffa78f20828ecdf39c31dcf86f0398 100644
@@ -33,7 +33,9 @@ from .. import future
 from .. import util
 from ..inspection import inspect
 from ..sql import coercions
+from ..sql import dml
 from ..sql import roles
+from ..sql import selectable
 from ..sql import visitors
 from ..sql.base import CompileState
 
@@ -113,16 +115,24 @@ class ORMExecuteState(util.MemoizedSlots):
         "_execution_options",
         "_merged_execution_options",
         "bind_arguments",
+        "_compile_state_cls",
     )
 
     def __init__(
-        self, session, statement, parameters, execution_options, bind_arguments
+        self,
+        session,
+        statement,
+        parameters,
+        execution_options,
+        bind_arguments,
+        compile_state_cls,
     ):
         self.session = session
         self.statement = statement
         self.parameters = parameters
         self._execution_options = execution_options
         self.bind_arguments = bind_arguments
+        self._compile_state_cls = compile_state_cls
 
     def invoke_statement(
         self,
@@ -193,6 +203,38 @@ class ORMExecuteState(util.MemoizedSlots):
             statement, _params, _execution_options, _bind_arguments
         )
 
+    @property
+    def is_orm_statement(self):
+        """return True if the operation is an ORM statement.
+
+        This indicates that the select(), update(), or delete() being
+        invoked contains ORM entities as subjects.   For a statement
+        that does not have ORM entities and instead refers only to
+        :class:`.Table` metadata, it is invoked as a Core SQL statement
+        and no ORM-level automation takes place.
+
+        """
+        return self._compile_state_cls is not None
+
+    @property
+    def is_select(self):
+        """return True if this is a SELECT operation."""
+        return isinstance(self.statement, selectable.Select)
+
+    @property
+    def is_update(self):
+        """return True if this is an UPDATE operation."""
+        return isinstance(self.statement, dml.Update)
+
+    @property
+    def is_delete(self):
+        """return True if this is a DELETE operation."""
+        return isinstance(self.statement, dml.Delete)
+
+    @property
+    def _is_crud(self):
+        return isinstance(self.statement, (dml.Update, dml.Delete))
+
     @property
     def execution_options(self):
         """Placeholder for execution options.
@@ -270,10 +312,30 @@ class ORMExecuteState(util.MemoizedSlots):
     def load_options(self):
         """Return the load_options that will be used for this execution."""
 
+        if not self.is_select:
+            raise sa_exc.InvalidRequestError(
+                "This ORM execution is not against a SELECT statement "
+                "so there are no load options."
+            )
         return self._execution_options.get(
             "_sa_orm_load_options", context.QueryContext.default_load_options
         )
 
+    @property
+    def update_delete_options(self):
+        """Return the update_delete_options that will be used for this
+        execution."""
+
+        if not self._is_crud:
+            raise sa_exc.InvalidRequestError(
+                "This ORM execution is not against an UPDATE or DELETE "
+                "statement so there are no update options."
+            )
+        return self._execution_options.get(
+            "_sa_orm_update_options",
+            persistence.BulkUDCompileState.default_update_options,
+        )
+
     @property
     def user_defined_options(self):
         """The sequence of :class:`.UserDefinedOptions` that have been
@@ -1455,35 +1517,37 @@ class Session(_SessionClassMethods):
             compile_state_cls = CompileState._get_plugin_class_for_plugin(
                 statement, "orm"
             )
+        else:
+            compile_state_cls = None
 
-            compile_state_cls.orm_pre_session_exec(
-                self, statement, execution_options, bind_arguments
+        if compile_state_cls is not None:
+            execution_options = compile_state_cls.orm_pre_session_exec(
+                self, statement, params, execution_options, bind_arguments
             )
-
-            if self.dispatch.do_orm_execute:
-                skip_events = bind_arguments.pop("_sa_skip_events", False)
-
-                if not skip_events:
-                    orm_exec_state = ORMExecuteState(
-                        self,
-                        statement,
-                        params,
-                        execution_options,
-                        bind_arguments,
-                    )
-                    for fn in self.dispatch.do_orm_execute:
-                        result = fn(orm_exec_state)
-                        if result:
-                            return result
-
         else:
-            compile_state_cls = None
             bind_arguments.setdefault("clause", statement)
             if statement._is_future:
                 execution_options = util.immutabledict().merge_with(
                     execution_options, {"future_result": True}
                 )
 
+        if self.dispatch.do_orm_execute:
+            # run this event whether or not we are in ORM mode
+            skip_events = bind_arguments.get("_sa_skip_events", False)
+            if not skip_events:
+                orm_exec_state = ORMExecuteState(
+                    self,
+                    statement,
+                    params,
+                    execution_options,
+                    bind_arguments,
+                    compile_state_cls,
+                )
+                for fn in self.dispatch.do_orm_execute:
+                    result = fn(orm_exec_state)
+                    if result:
+                        return result
+
         bind = self.get_bind(**bind_arguments)
 
         conn = self._connection_for_bind(bind, close_with_result=True)
@@ -1601,8 +1665,8 @@ class Session(_SessionClassMethods):
                 self.__binds[insp] = bind
             elif insp.is_mapper:
                 self.__binds[insp.class_] = bind
-                for selectable in insp._all_tables:
-                    self.__binds[selectable] = bind
+                for _selectable in insp._all_tables:
+                    self.__binds[_selectable] = bind
             else:
                 raise sa_exc.ArgumentError(
                     "Not an acceptable bind target: %s" % key
@@ -1664,7 +1728,9 @@ class Session(_SessionClassMethods):
         """
         self._add_bind(table, bind)
 
-    def get_bind(self, mapper=None, clause=None, bind=None):
+    def get_bind(
+        self, mapper=None, clause=None, bind=None, _sa_skip_events=None
+    ):
         """Return a "bind" to which this :class:`.Session` is bound.
 
         The "bind" is usually an instance of :class:`_engine.Engine`,
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index f1431908902d4bc339e57a910f67898cc290b6a8..5dd3b519ac6bed1cdb60707e060e95d76a0fb20c 100644
@@ -446,10 +446,14 @@ class CompileState(object):
             plugin_name = statement._propagate_attrs.get(
                 "compile_state_plugin", "default"
             )
-        else:
-            plugin_name = "default"
+            klass = cls.plugins.get(
+                (plugin_name, statement.__visit_name__), None
+            )
+            if klass is None:
+                klass = cls.plugins[("default", statement.__visit_name__)]
 
-        klass = cls.plugins[(plugin_name, statement.__visit_name__)]
+        else:
+            klass = cls.plugins[("default", statement.__visit_name__)]
 
         if klass is cls:
             return cls(statement, compiler, **kw)
diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py
index db43e42a63ec073eaa94757eccaa71cd686a92fe..4c6a0317a4541d6b93105806f24948a5e6537ced 100644
@@ -755,6 +755,16 @@ class AnonymizedFromClauseImpl(StrictFromClauseImpl):
         return element.alias(name=name, flat=flat)
 
 
+class DMLTableImpl(_SelectIsNotFrom, _NoTextCoercion, RoleImpl):
+    __slots__ = ()
+
+    def _post_coercion(self, element, **kw):
+        if "dml_table" in element._annotations:
+            return element._annotations["dml_table"]
+        else:
+            return element
+
+
 class DMLSelectImpl(_NoTextCoercion, RoleImpl):
     __slots__ = ()
 
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index f4160b5520cb06c031d483f1f702bbf5dbd0941d..2519438d1b757b6182a57195a31a0b1bc07c508f 100644
@@ -3215,6 +3215,8 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         if toplevel:
             self.isupdate = True
+            if not self.compile_state:
+                self.compile_state = compile_state
 
         extra_froms = compile_state._extra_froms
         is_multitable = bool(extra_froms)
@@ -3342,6 +3344,8 @@ class SQLCompiler(Compiled):
         toplevel = not self.stack
         if toplevel:
             self.isdelete = True
+            if not self.compile_state:
+                self.compile_state = compile_state
 
         extra_froms = compile_state._extra_froms
 
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index 467a764d62555a9a994e3107ccd1032c4390882a..a82641d77c9b8c63491af6c890a22560fc4b358e 100644
@@ -19,6 +19,7 @@ from .base import CompileState
 from .base import DialectKWArgs
 from .base import Executable
 from .base import HasCompileState
+from .elements import BooleanClauseList
 from .elements import ClauseElement
 from .elements import Null
 from .selectable import HasCTE
@@ -150,7 +151,6 @@ class UpdateDMLState(DMLState):
 
     def __init__(self, statement, compiler, **kw):
         self.statement = statement
-
         self.isupdate = True
         self._preserve_parameter_order = statement._preserve_parameter_order
         if statement._ordered_values is not None:
@@ -447,7 +447,9 @@ class ValuesBase(UpdateBase):
     _returning = ()
 
     def __init__(self, table, values, prefixes):
-        self.table = coercions.expect(roles.FromClauseRole, table)
+        self.table = coercions.expect(
+            roles.DMLTableRole, table, apply_propagate_attrs=self
+        )
         if values is not None:
             self.values.non_generative(self, values)
         if prefixes:
@@ -949,6 +951,28 @@ class DMLWhereBase(object):
             coercions.expect(roles.WhereHavingRole, whereclause),
         )
 
+    def filter(self, *criteria):
+        """A synonym for the :meth:`_dml.DMLWhereBase.where` method."""
+
+        return self.where(*criteria)
+
+    @property
+    def whereclause(self):
+        """Return the completed WHERE clause for this :class:`.DMLWhereBase`
+        statement.
+
+        This assembles the current collection of WHERE criteria
+        into a single :class:`_expression.BooleanClauseList` construct.
+
+        .. versionadded:: 1.4
+
+        """
+
+        return BooleanClauseList._construct_for_whereclause(
+            self._where_criteria
+        )
+
 
 class Update(DMLWhereBase, ValuesBase):
     """Represent an Update construct.
@@ -1266,7 +1290,9 @@ class Delete(DMLWhereBase, UpdateBase):
 
         """
         self._bind = bind
-        self.table = coercions.expect(roles.FromClauseRole, table)
+        self.table = coercions.expect(
+            roles.DMLTableRole, table, apply_propagate_attrs=self
+        )
         self._returning = returning
 
         if prefixes:
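The .filter() synonym and the new .whereclause accessor added above work on any Core UPDATE or DELETE; a brief usage sketch against the 1.4-style API:

    from sqlalchemy import Column, Integer, MetaData, String, Table, update

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )

    # .filter() is a synonym for .where(); criteria accumulate with AND
    stmt = update(users).filter(users.c.id == 5).filter(users.c.name != "ed")

    # .whereclause assembles the criteria into one BooleanClauseList,
    # rendering here as "users.id = :id_1 AND users.name != :name_1"
    print(stmt.whereclause)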
index 5a55fe5f28b04afc9aa3dbc491d8d5b466e4dbe9..3d94ec9ff593d5a83acc9a82b2241ce5e403f4ae 100644 (file)
@@ -184,10 +184,15 @@ class CompoundElementRole(SQLRole):
     )
 
 
+# TODO: are we using this?
 class DMLRole(StatementRole):
     pass
 
 
+class DMLTableRole(FromClauseRole):
+    _role_name = "subject table for an INSERT, UPDATE or DELETE"
+
+
 class DMLColumnRole(SQLRole):
     _role_name = "SET/VALUES column expression or string key"
 
index d6845e05f72184b5cb39b209a628316da47128fd..a95fc561ae1bbac34afab4baafbe5093c41181cd 100644 (file)
@@ -789,7 +789,7 @@ class FromClause(roles.AnonymizedFromClauseRole, Selectable):
         self._reset_column_collection()
 
 
-class Join(FromClause):
+class Join(roles.DMLTableRole, FromClause):
     """represent a ``JOIN`` construct between two
     :class:`_expression.FromClause`
     elements.
@@ -1406,7 +1406,7 @@ class AliasedReturnsRows(NoInit, FromClause):
         return self.element.bind
 
 
-class Alias(AliasedReturnsRows):
+class Alias(roles.DMLTableRole, AliasedReturnsRows):
     """Represents an table or selectable alias (AS).
 
     Represents an alias, as typically applied to any table or
@@ -1987,7 +1987,7 @@ class FromGrouping(GroupedElement, FromClause):
         self.element = state["element"]
 
 
-class TableClause(Immutable, FromClause):
+class TableClause(roles.DMLTableRole, Immutable, FromClause):
     """Represents a minimal "table" construct.
 
     This is a lightweight table object that has only a name, a
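Adding DMLTableRole to Join, Alias and TableClause is what lets a join be the subject of an UPDATE, as the InheritTest change further down exercises on MySQL. A sketch of the Core-level usage, compiled against the MySQL dialect since multi-table UPDATE is backend-specific:

    from sqlalchemy import (
        Column,
        ForeignKey,
        Integer,
        MetaData,
        String,
        Table,
        update,
    )
    from sqlalchemy.dialects import mysql

    metadata = MetaData()
    person = Table(
        "person",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
    )
    engineer = Table(
        "engineer",
        metadata,
        Column("id", Integer, ForeignKey("person.id"), primary_key=True),
        Column("engineer_name", String(50)),
    )

    # a Join is now an acceptable UPDATE target
    stmt = (
        update(person.join(engineer))
        .where(person.c.name == "e2")
        .values({person.c.name: "updated", engineer.c.engineer_name: "e2a"})
    )

    # renders MySQL's multi-table "UPDATE person INNER JOIN engineer ..."
    print(stmt.compile(dialect=mysql.dialect()))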
index 388097e45a263183467ff1f4429b7b683548b5aa..68281f33d1048915dcdc85256c65ef2aa55cc830 100644 (file)
@@ -10,6 +10,7 @@ from .. import util
 from ..inspection import inspect
 from ..util import collections_abc
 from ..util import HasMemoized
+from ..util import py37
 
 SKIP_TRAVERSE = util.symbol("skip_traverse")
 COMPARE_FAILED = False
@@ -562,23 +563,38 @@ class _CacheKey(ExtendedInternalTraversal):
         )
 
     def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
+        if py37:
+            # in py37, dictionaries are guaranteed to preserve insertion
+            # order, so two dicts built in the same order iterate alike
+            return (
+                attrname,
+                tuple(
+                    (
+                        k._gen_cache_key(anon_map, bindparams)
+                        if hasattr(k, "__clause_element__")
+                        else k,
+                        obj[k]._gen_cache_key(anon_map, bindparams),
+                    )
+                    for k in obj
+                ),
+            )
+        else:
+            expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
+            if expr_values:
+                # expr values can't be sorted deterministically right now,
+                # so no cache
+                anon_map[NO_CACHE] = True
+                return ()
 
-        expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
-        if expr_values:
-            # expr values can't be sorted deterministically right now,
-            # so no cache
-            anon_map[NO_CACHE] = True
-            return ()
-
-        str_values = expr_values.symmetric_difference(obj)
+            str_values = expr_values.symmetric_difference(obj)
 
-        return (
-            attrname,
-            tuple(
-                (k, obj[k]._gen_cache_key(anon_map, bindparams))
-                for k in sorted(str_values)
-            ),
-        )
+            return (
+                attrname,
+                tuple(
+                    (k, obj[k]._gen_cache_key(anon_map, bindparams))
+                    for k in sorted(str_values)
+                ),
+            )
 
     def visit_dml_multi_values(
         self, attrname, obj, parent, anon_map, bindparams
@@ -1130,6 +1146,18 @@ class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
             for lv, rv in zip(left, right):
                 if not self._compare_dml_values_or_ce(lv, rv, **kw):
                     return COMPARE_FAILED
+        elif isinstance(right, collections_abc.Sequence):
+            return COMPARE_FAILED
+        elif py37:
+            # dictionaries are guaranteed to preserve insertion order in
+            # py37, so we can compare the keys in order.  Without this,
+            # we can't compare SQL expression keys because we don't know
+            # which key is which
+            for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
+                if not self._compare_dml_values_or_ce(lk, rk, **kw):
+                    return COMPARE_FAILED
+                if not self._compare_dml_values_or_ce(lv, rv, **kw):
+                    return COMPARE_FAILED
         else:
             for lk in left:
                 lv = left[lk]
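Both branches above rest on the Python 3.7 language guarantee that dicts preserve insertion order, which makes the ordered tuple of keys a stable cache-key component. A tiny standalone illustration:

    import sys

    d1 = {"a": 1, "b": 2, "c": 3}
    d2 = {"a": 1, "c": 3, "b": 2}  # same contents, different insert order

    if sys.version_info >= (3, 7):
        # equal as dicts, but iteration order differs, so a cache key
        # built from tuple(d) distinguishes the two VALUES orderings
        assert d1 == d2
        assert tuple(d1) != tuple(d2)  # ('a', 'b', 'c') vs ('a', 'c', 'b')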
index 0ea9f067e523fb202c8888fb38ebc493b3a439bf..54da06a3daadaea05f4400817342ed55786cf6c4 100644 (file)
@@ -403,10 +403,6 @@ class AssertsCompiledSQL(object):
                 LABEL_STYLE_TABLENAME_PLUS_COL
             )
             clause = compile_state.statement
-        elif isinstance(clause, orm.persistence.BulkUD):
-            with mock.patch.object(clause, "_execute_stmt") as stmt_mock:
-                clause.exec_()
-                clause = stmt_mock.mock_calls[0][1][0]
 
         if compile_kwargs:
             kw["compile_kwargs"] = compile_kwargs
index 55a6cdcf90b324365dee247f4369ab75400ef9d4..273570357b09f600df0913bd840eed8f0a4f6efe 100644 (file)
@@ -65,6 +65,7 @@ from .compat import pickle  # noqa
 from .compat import print_  # noqa
 from .compat import py2k  # noqa
 from .compat import py36  # noqa
+from .compat import py37  # noqa
 from .compat import py3k  # noqa
 from .compat import quote_plus  # noqa
 from .compat import raise_  # noqa
index 247dbc13c3e3dbbf03a6f9738710dc9e5c770d78..5c46395f9a1e0d780a133ed16186308b77ec4868 100644 (file)
@@ -15,6 +15,7 @@ import platform
 import sys
 
 
+py37 = sys.version_info >= (3, 7)
 py36 = sys.version_info >= (3, 6)
 py3k = sys.version_info >= (3, 0)
 py2k = sys.version_info < (3, 0)
index 1b31c96e9c8917f40c58c4386465973e03c8231c..188e8e929aa4370dfdd4e3f02c00a7b36c3de41a 100644 (file)
@@ -859,6 +859,7 @@ class JoinedEagerLoadTest(fixtures.MappedTest):
                 ORMCompileState.orm_pre_session_exec(
                     sess,
                     compile_state.select_statement,
+                    {},
                     exec_opts,
                     bind_arguments,
                 )
index ca65111da211228724421bdc4b73b2bfe41f76a6..ce0e7b945274c2a265bae3c49ed546d2af959441 100644 (file)
@@ -466,10 +466,16 @@ class ResultTest(fixtures.TestBase):
     def test_scalar_one(self):
         result = self._fixture(num_rows=1)
 
+        row = result.scalar_one()
+        eq_(row, 1)
+
+    def test_scalars_plus_one(self):
+        result = self._fixture(num_rows=1)
+
         row = result.scalars().one()
         eq_(row, 1)
 
-    def test_scalar_one_none(self):
+    def test_scalars_plus_one_none(self):
         result = self._fixture(num_rows=0)
 
         result = result.scalars()
@@ -488,6 +494,21 @@ class ResultTest(fixtures.TestBase):
             result.one,
         )
 
+    def test_one_or_none(self):
+        result = self._fixture(num_rows=1)
+
+        eq_(result.one_or_none(), (1, 1, 1))
+
+    def test_scalar_one_or_none(self):
+        result = self._fixture(num_rows=1)
+
+        eq_(result.scalar_one_or_none(), 1)
+
+    def test_scalar_one_or_none_none(self):
+        result = self._fixture(num_rows=0)
+
+        eq_(result.scalar_one_or_none(), None)
+
     def test_one_or_none_none(self):
         result = self._fixture(num_rows=0)
 
index bb103370519dc0ce3335b0e66fbf2e52aa2680d3..c0029fbb6389b0b693fdef5c678ca5e97208c9fa 100644 (file)
@@ -3,6 +3,7 @@ import os
 
 from sqlalchemy import Column
 from sqlalchemy import DateTime
+from sqlalchemy import delete
 from sqlalchemy import event
 from sqlalchemy import Float
 from sqlalchemy import ForeignKey
@@ -13,6 +14,7 @@ from sqlalchemy import sql
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import testing
+from sqlalchemy import update
 from sqlalchemy import util
 from sqlalchemy.ext.horizontal_shard import ShardedSession
 from sqlalchemy.future import select as future_select
@@ -444,7 +446,7 @@ class ShardTest(object):
         t = get_tokyo(sess2)
         eq_(t.city, tokyo.city)
 
-    def test_bulk_update(self):
+    def test_bulk_update_synchronize_evaluate(self):
         sess = self._fixture_data()
 
         eq_(
@@ -456,7 +458,8 @@ class ShardTest(object):
         eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
 
         sess.query(Report).filter(Report.temperature >= 80).update(
-            {"temperature": Report.temperature + 6}
+            {"temperature": Report.temperature + 6},
+            synchronize_session="evaluate",
         )
 
         eq_(
@@ -467,13 +470,58 @@ class ShardTest(object):
         # test synchronize session as well
         eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
 
-    def test_bulk_delete(self):
+    def test_bulk_update_synchronize_fetch(self):
+        sess = self._fixture_data()
+
+        eq_(
+            set(row.temperature for row in sess.query(Report.temperature)),
+            {80.0, 75.0, 85.0},
+        )
+
+        temps = sess.query(Report).all()
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+        sess.query(Report).filter(Report.temperature >= 80).update(
+            {"temperature": Report.temperature + 6},
+            synchronize_session="fetch",
+        )
+
+        eq_(
+            set(row.temperature for row in sess.query(Report.temperature)),
+            {86.0, 75.0, 91.0},
+        )
+
+        # test synchronize session as well
+        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
+
+    def test_bulk_delete_synchronize_evaluate(self):
+        sess = self._fixture_data()
+
+        temps = sess.query(Report).all()
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+        sess.query(Report).filter(Report.temperature >= 80).delete(
+            synchronize_session="evaluate"
+        )
+
+        eq_(
+            set(row.temperature for row in sess.query(Report.temperature)),
+            {75.0},
+        )
+
+        # test synchronize session as well
+        for t in temps:
+            assert inspect(t).deleted is (t.temperature >= 80)
+
+    def test_bulk_delete_synchronize_fetch(self):
         sess = self._fixture_data()
 
         temps = sess.query(Report).all()
         eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
 
-        sess.query(Report).filter(Report.temperature >= 80).delete()
+        sess.query(Report).filter(Report.temperature >= 80).delete(
+            synchronize_session="fetch"
+        )
 
         eq_(
             set(row.temperature for row in sess.query(Report.temperature)),
@@ -484,6 +532,118 @@ class ShardTest(object):
         for t in temps:
             assert inspect(t).deleted is (t.temperature >= 80)
 
+    def test_bulk_update_future_synchronize_evaluate(self):
+        sess = self._fixture_data()
+
+        eq_(
+            set(
+                row.temperature
+                for row in sess.execute(future_select(Report.temperature))
+            ),
+            {80.0, 75.0, 85.0},
+        )
+
+        temps = sess.execute(future_select(Report)).scalars().all()
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+        sess.execute(
+            update(Report)
+            .filter(Report.temperature >= 80)
+            .values({"temperature": Report.temperature + 6},)
+            .execution_options(synchronize_session="evaluate")
+        )
+
+        eq_(
+            set(
+                row.temperature
+                for row in sess.execute(future_select(Report.temperature))
+            ),
+            {86.0, 75.0, 91.0},
+        )
+
+        # test synchronize session as well
+        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
+
+    def test_bulk_update_future_synchronize_fetch(self):
+        sess = self._fixture_data()
+
+        eq_(
+            set(
+                row.temperature
+                for row in sess.execute(future_select(Report.temperature))
+            ),
+            {80.0, 75.0, 85.0},
+        )
+
+        temps = sess.execute(future_select(Report)).scalars().all()
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+        sess.execute(
+            update(Report)
+            .filter(Report.temperature >= 80)
+            .values({"temperature": Report.temperature + 6},)
+            .execution_options(synchronize_session="fetch")
+        )
+
+        eq_(
+            set(
+                row.temperature
+                for row in sess.execute(future_select(Report.temperature))
+            ),
+            {86.0, 75.0, 91.0},
+        )
+
+        # test synchronize session as well
+        eq_(set(t.temperature for t in temps), {86.0, 75.0, 91.0})
+
+    def test_bulk_delete_future_synchronize_evaluate(self):
+        sess = self._fixture_data()
+
+        temps = sess.execute(future_select(Report)).scalars().all()
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+        sess.execute(
+            delete(Report)
+            .filter(Report.temperature >= 80)
+            .execution_options(synchronize_session="evaluate")
+        )
+
+        eq_(
+            set(
+                row.temperature
+                for row in sess.execute(future_select(Report.temperature))
+            ),
+            {75.0},
+        )
+
+        # test synchronize session as well
+        for t in temps:
+            assert inspect(t).deleted is (t.temperature >= 80)
+
+    def test_bulk_delete_future_synchronize_fetch(self):
+        sess = self._fixture_data()
+
+        temps = sess.execute(future_select(Report)).scalars().all()
+        eq_(set(t.temperature for t in temps), {80.0, 75.0, 85.0})
+
+        sess.execute(
+            delete(Report)
+            .filter(Report.temperature >= 80)
+            .execution_options(synchronize_session="fetch")
+        )
+
+        eq_(
+            set(
+                row.temperature
+                for row in sess.execute(future_select(Report.temperature))
+            ),
+            {75.0},
+        )
+
+        # test synchronize session as well
+        for t in temps:
+            assert inspect(t).deleted is (t.temperature >= 80)
+
 
 class DistinctEngineShardTest(ShardTest, fixtures.TestBase):
     def _init_dbs(self):
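These shard test variants exercise the commit's headline behavior: plain update()/delete() constructs go through Session.execute() with the synchronize strategy carried as an execution option, and "fetch" now works under horizontal sharding. A self-contained single-database sketch of the same pattern, assuming the 1.4-series ORM API:

    from sqlalchemy import Column, Float, Integer, create_engine, update
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class Report(Base):
        __tablename__ = "report"
        id = Column(Integer, primary_key=True)
        temperature = Column(Float)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    session = Session(engine)
    session.add_all([Report(temperature=t) for t in (80.0, 75.0, 85.0)])
    session.commit()

    reports = session.query(Report).all()

    # bulk UPDATE routed through session.execute(); "fetch" selects the
    # affected primary keys first so already-loaded objects stay in sync
    session.execute(
        update(Report)
        .where(Report.temperature >= 80)
        .values(temperature=Report.temperature + 6)
        .execution_options(synchronize_session="fetch")
    )

    print(sorted(r.temperature for r in reports))  # [75.0, 86.0, 91.0]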
index 353c52c5c0d89c4bbbd16f00f1cd4eb2a82eb58d..fbac35f7eeabd6cfdd72acfb0a2068ccdb3a52e6 100644 (file)
@@ -9,9 +9,9 @@ from sqlalchemy import String
 from sqlalchemy.ext import hybrid
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import aliased
-from sqlalchemy.orm import persistence
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
+from sqlalchemy.sql import update
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import AssertsCompiledSQL
 from sqlalchemy.testing import eq_
@@ -588,15 +588,10 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
     def test_update_plain(self):
         Person = self.classes.Person
 
-        s = Session()
-        q = s.query(Person)
-
-        bulk_ud = persistence.BulkUpdate.factory(
-            q, False, {Person.fname: "Dr."}, {}
-        )
+        statement = update(Person).values({Person.fname: "Dr."})
 
         self.assert_compile(
-            bulk_ud,
+            statement,
             "UPDATE person SET first_name=:first_name",
             params={"first_name": "Dr."},
         )
@@ -604,15 +599,10 @@ class BulkUpdateTest(fixtures.DeclarativeMappedTest, AssertsCompiledSQL):
     def test_update_expr(self):
         Person = self.classes.Person
 
-        s = Session()
-        q = s.query(Person)
-
-        bulk_ud = persistence.BulkUpdate.factory(
-            q, False, {Person.name: "Dr. No"}, {}
-        )
+        statement = update(Person).values({Person.name: "Dr. No"})
 
         self.assert_compile(
-            bulk_ud,
+            statement,
             "UPDATE person SET first_name=:first_name, last_name=:last_name",
             params={"first_name": "Dr.", "last_name": "No"},
         )
index 0d679e6db52e9240076fa78423b69aed1a1c57e3..da9783dfd15d8bc0b48dd7a800b105413c373767 100644 (file)
@@ -4,12 +4,13 @@ from sqlalchemy import Integer
 from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
+from sqlalchemy import update
+from sqlalchemy.future import select as future_select
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm import composite
 from sqlalchemy.orm import CompositeProperty
 from sqlalchemy.orm import configure_mappers
 from sqlalchemy.orm import mapper
-from sqlalchemy.orm import persistence
 from sqlalchemy.orm import relationship
 from sqlalchemy.orm import Session
 from sqlalchemy.testing import assert_raises_message
@@ -231,17 +232,20 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
         sess = self._fixture()
 
-        e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one()
+        e1 = sess.execute(
+            future_select(Edge).filter(Edge.start == Point(14, 5))
+        ).scalar_one()
 
         eq_(e1.end, Point(2, 7))
 
-        q = sess.query(Edge).filter(Edge.start == Point(14, 5))
-        bulk_ud = persistence.BulkUpdate.factory(
-            q, False, {Edge.end: Point(16, 10)}, {}
+        stmt = (
+            update(Edge)
+            .filter(Edge.start == Point(14, 5))
+            .values({Edge.end: Point(16, 10)})
         )
 
         self.assert_compile(
-            bulk_ud,
+            stmt,
             "UPDATE edges SET x2=:x2, y2=:y2 WHERE edges.x1 = :x1_1 "
             "AND edges.y1 = :y1_1",
             params={"x2": 16, "x1_1": 14, "y2": 10, "y1_1": 5},
@@ -253,12 +257,18 @@ class PointTest(fixtures.MappedTest, testing.AssertsCompiledSQL):
 
         sess = self._fixture()
 
-        e1 = sess.query(Edge).filter(Edge.start == Point(14, 5)).one()
+        e1 = sess.execute(
+            future_select(Edge).filter(Edge.start == Point(14, 5))
+        ).scalar_one()
 
         eq_(e1.end, Point(2, 7))
 
-        q = sess.query(Edge).filter(Edge.start == Point(14, 5))
-        q.update({Edge.end: Point(16, 10)})
+        stmt = (
+            update(Edge)
+            .filter(Edge.start == Point(14, 5))
+            .values({Edge.end: Point(16, 10)})
+        )
+        sess.execute(stmt)
 
         eq_(e1.end, Point(16, 10))
 
index a26d0ae267e9cf1ccc384cad6f2b4c332efbdb4f..a4f084106038c408488e3cc097f17f23d4605ea6 100644 (file)
@@ -1633,6 +1633,12 @@ class RawSelectTest(QueryTest, AssertsCompiledSQL):
             checkparams={"id_1": 5, "name": "ed"},
         )
 
+        self.assert_compile(
+            update(User).values({User.name: "ed"}).where(User.id == 5),
+            "UPDATE users SET name=:name WHERE users.id = :id_1",
+            checkparams={"id_1": 5, "name": "ed"},
+        )
+
     def test_delete_from_entity(self):
         from sqlalchemy.sql import delete
 
index 299aba8099d13380684f71051cf54510a1b867a7..2e19b94354039ae950186ff9cc34bee232cd66f1 100644 (file)
@@ -1684,7 +1684,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
         eq_(upd.session, sess)
         eq_(
             canary.after_bulk_update_legacy.mock_calls,
-            [call(sess, upd.query, upd.context, upd.result)],
+            [call(sess, upd.query, None, upd.result)],
         )
 
     def test_on_bulk_delete_hook(self):
@@ -1714,7 +1714,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
         eq_(upd.session, sess)
         eq_(
             canary.after_bulk_delete_legacy.mock_calls,
-            [call(sess, upd.query, upd.context, upd.result)],
+            [call(sess, upd.query, None, upd.result)],
         )
 
 
index c1457289aac7afa4e41e9e1137b939abf36f862e..b68e0d2e652232f638e42bdadcf17e4f96e10b0e 100644 (file)
@@ -703,6 +703,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
                 s.refresh(a1)
         # joined eager load didn't continue
         eq_(len(a1.bs), 1)
+        s.close()
 
     @_combinations
     def test_flag_resolves_existing(self, target, event_name, fn):
@@ -715,6 +716,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
         s.expire(a1)
         event.listen(target, event_name, fn, restore_load_context=True)
         s.query(A).all()
+        s.close()
 
     @_combinations
     def test_flag_resolves(self, target, event_name, fn):
@@ -728,6 +730,7 @@ class RestoreLoadContextTest(fixtures.DeclarativeMappedTest):
             s.refresh(a1)
         # joined eager load continued
         eq_(len(a1.bs), 3)
+        s.close()
 
 
 class DeclarativeEventListenTest(
@@ -1768,6 +1771,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
 
         upd = canary.after_bulk_update.mock_calls[0][1][0]
         eq_(upd.session, sess)
+        eq_(upd.result.rowcount, 0)
 
     def test_on_bulk_delete_hook(self):
         User, users = self.classes.User, self.tables.users
@@ -1787,6 +1791,7 @@ class SessionEventsTest(_RemoveListeners, _fixtures.FixtureTest):
 
         upd = canary.after_bulk_delete.mock_calls[0][1][0]
         eq_(upd.session, sess)
+        eq_(upd.result.rowcount, 0)
 
     def test_connection_emits_after_begin(self):
         sess, canary = self._listener_fixture(bind=testing.db)
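The rowcount assertions added above depend on the bulk update/delete event context now carrying the real cursor result. A self-contained sketch of consuming that from a listener, assuming the 1.4-series session events:

    from sqlalchemy import Column, Integer, String, create_engine, event
    from sqlalchemy.ext.declarative import declarative_base
    from sqlalchemy.orm import Session

    Base = declarative_base()

    class User(Base):
        __tablename__ = "users"
        id = Column(Integer, primary_key=True)
        name = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    session = Session(engine)

    @event.listens_for(session, "after_bulk_update")
    def after_bulk_update(update_context):
        # the context's result is now a full CursorResult, so rowcount
        # reflects the rows matched by the UPDATE
        print("rows matched:", update_context.result.rowcount)

    session.query(User).filter(User.name == "x").update(
        {"name": "y"}, synchronize_session=False
    )  # the listener prints "rows matched: 0" against the empty table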
index 9017ca84eef59f685d6a4028aa89fb54d70d29d3..5430fbffc542d60875cdb0d416f8a431ff22ceaa 100644 (file)
@@ -1,6 +1,7 @@
 from sqlalchemy import Boolean
 from sqlalchemy import case
 from sqlalchemy import column
+from sqlalchemy import event
 from sqlalchemy import exc
 from sqlalchemy import ForeignKey
 from sqlalchemy import func
@@ -10,6 +11,8 @@ from sqlalchemy import select
 from sqlalchemy import String
 from sqlalchemy import testing
 from sqlalchemy import text
+from sqlalchemy import update
+from sqlalchemy.future import select as future_select
 from sqlalchemy.orm import backref
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm import mapper
@@ -20,10 +23,8 @@ from sqlalchemy.testing import assert_raises
 from sqlalchemy.testing import assert_raises_message
 from sqlalchemy.testing import eq_
 from sqlalchemy.testing import fixtures
-from sqlalchemy.testing import mock
 from sqlalchemy.testing.schema import Column
 from sqlalchemy.testing.schema import Table
-from sqlalchemy.util import collections_abc
 
 
 class UpdateDeleteTest(fixtures.MappedTest):
@@ -385,6 +386,58 @@ class UpdateDeleteTest(fixtures.MappedTest):
             list(zip([15, 27, 19, 27])),
         )
 
+    def test_update_future(self):
+        User, users = self.classes.User, self.tables.users
+
+        sess = Session()
+
+        john, jack, jill, jane = (
+            sess.execute(future_select(User).order_by(User.id)).scalars().all()
+        )
+
+        sess.execute(
+            update(User)
+            .where(User.age > 29)
+            .values({"age": User.age - 10})
+            .execution_options(synchronize_session="evaluate"),
+        )
+
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 37, 29, 27])
+        eq_(
+            sess.execute(future_select(User.age).order_by(User.id)).all(),
+            list(zip([25, 37, 29, 27])),
+        )
+
+        sess.execute(
+            update(User)
+            .where(User.age > 29)
+            .values({User.age: User.age - 10})
+            .execution_options(synchronize_session="evaluate")
+        )
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 29, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([25, 27, 29, 27])),
+        )
+
+        sess.query(User).filter(User.age > 27).update(
+            {users.c.age_int: User.age - 10}, synchronize_session="evaluate"
+        )
+        eq_([john.age, jack.age, jill.age, jane.age], [25, 27, 19, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([25, 27, 19, 27])),
+        )
+
+        sess.query(User).filter(User.age == 25).update(
+            {User.age: User.age - 10}, synchronize_session="fetch"
+        )
+        eq_([john.age, jack.age, jill.age, jane.age], [15, 27, 19, 27])
+        eq_(
+            sess.query(User.age).order_by(User.id).all(),
+            list(zip([15, 27, 19, 27])),
+        )
+
     def test_update_against_table_col(self):
         User, users = self.classes.User, self.tables.users
 
@@ -677,41 +730,111 @@ class UpdateDeleteTest(fixtures.MappedTest):
 
         # Do an update using unordered dict and check that the parameters used
         # are ordered in table order
+
+        m1 = testing.mock.Mock()
+
+        @event.listens_for(session, "after_bulk_update")
+        def do_orm_execute(bulk_ud):
+            m1(bulk_ud.result.context.compiled.compile_state.statement)
+
         q = session.query(User)
-        with mock.patch.object(q, "_execute_crud") as exec_:
-            q.filter(User.id == 15).update({"name": "foob", "id": 123})
-            # Confirm that parameters are a dict instead of tuple or list
-            params = exec_.mock_calls[0][1][0]._values
-            assert isinstance(params, collections_abc.Mapping)
+        q.filter(User.id == 15).update({"name": "foob", "age": 123})
+        assert m1.mock_calls[0][1][0]._values
 
-    def test_update_preserve_parameter_order(self):
+    def test_update_preserve_parameter_order_query(self):
         User = self.classes.User
         session = Session()
 
         # Do update using a tuple and check that order is preserved
-        q = session.query(User)
-        with mock.patch.object(q, "_execute_crud") as exec_:
-            q.filter(User.id == 15).update(
-                (("id", 123), ("name", "foob")),
-                update_args={"preserve_parameter_order": True},
-            )
+
+        m1 = testing.mock.Mock()
+
+        @event.listens_for(session, "after_bulk_update")
+        def do_orm_execute(bulk_ud):
+
             cols = [
-                c.key for c, v in exec_.mock_calls[0][1][0]._ordered_values
+                c.key
+                for c, v in (
+                    (
+                        bulk_ud.result.context
+                    ).compiled.compile_state.statement._ordered_values
+                )
             ]
-            eq_(["id", "name"], cols)
+            m1(cols)
 
-        # Now invert the order and use a list instead, and check that order is
-        # also preserved
         q = session.query(User)
-        with mock.patch.object(q, "_execute_crud") as exec_:
-            q.filter(User.id == 15).update(
-                [("name", "foob"), ("id", 123)],
-                update_args={"preserve_parameter_order": True},
+        q.filter(User.id == 15).update(
+            (("age", 123), ("name", "foob")),
+            update_args={"preserve_parameter_order": True},
+        )
+
+        eq_(m1.mock_calls[0][1][0], ["age_int", "name"])
+
+        m1.mock_calls = []
+
+        q = session.query(User)
+        q.filter(User.id == 15).update(
+            [("name", "foob"), ("age", 123)],
+            update_args={"preserve_parameter_order": True},
+        )
+        eq_(m1.mock_calls[0][1][0], ["name", "age_int"])
+
+    def test_update_multi_values_error_future(self):
+        User = self.classes.User
+        session = Session()
+
+        # A list of parameter sets passed to values() is rejected for UPDATE
+
+        stmt = (
+            update(User)
+            .filter(User.id == 15)
+            .values([("id", 123), ("name", "foob")])
+        )
+
+        assert_raises_message(
+            exc.InvalidRequestError,
+            "UPDATE construct does not support multiple parameter sets.",
+            session.execute,
+            stmt,
+        )
+
+    def test_update_preserve_parameter_order_future(self):
+        User = self.classes.User
+        session = Session()
+
+        # Do update using a tuple and check that order is preserved
+
+        stmt = (
+            update(User)
+            .filter(User.id == 15)
+            .ordered_values(("age", 123), ("name", "foob"))
+        )
+        result = session.execute(stmt)
+        cols = [
+            c.key
+            for c, v in (
+                (
+                    result.context
+                ).compiled.compile_state.statement._ordered_values
             )
-            cols = [
-                c.key for c, v in exec_.mock_calls[0][1][0]._ordered_values
-            ]
-            eq_(["name", "id"], cols)
+        ]
+        eq_(["age_int", "name"], cols)
+
+        # Now invert the order and use a list instead, and check that order is
+        # also preserved
+        stmt = (
+            update(User)
+            .filter(User.id == 15)
+            .ordered_values(("name", "foob"), ("age", 123),)
+        )
+        result = session.execute(stmt)
+        cols = [
+            c.key
+            for c, v in (
+                result.context
+            ).compiled.compile_state.statement._ordered_values
+        ]
+        eq_(["name", "age_int"], cols)
 
 
 class UpdateDeleteIgnoresLoadersTest(fixtures.MappedTest):
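The preserve_parameter_order tests above pair with a Core-level spelling: ordered_values() renders SET clauses in the order given rather than in table order. A compact sketch mirroring the users fixture's age/age_int key mapping:

    from sqlalchemy import Column, Integer, MetaData, String, Table, update

    metadata = MetaData()
    users = Table(
        "users",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("name", String(50)),
        Column("age_int", Integer, key="age"),
    )

    # SET clauses follow the given order, not table definition order
    stmt = update(users).ordered_values(
        (users.c.age, 123), (users.c.name, "foob")
    )
    print(stmt)  # UPDATE users SET age_int=:age_int, name=:name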
@@ -1103,16 +1226,23 @@ class ExpressionUpdateTest(fixtures.MappedTest):
 
     def test_update_args(self):
         Data = self.classes.Data
-        session = testing.mock.Mock(wraps=Session())
+        session = Session()
         update_args = {"mysql_limit": 1}
 
+        m1 = testing.mock.Mock()
+
+        @event.listens_for(session, "after_bulk_update")
+        def do_orm_execute(bulk_ud):
+            update_stmt = (
+                bulk_ud.result.context.compiled.compile_state.statement
+            )
+            m1(update_stmt)
+
         q = session.query(Data)
-        with testing.mock.patch.object(q, "_execute_crud") as exec_:
-            q.update({Data.cnt: Data.cnt + 1}, update_args=update_args)
-        eq_(exec_.call_count, 1)
-        args, kwargs = exec_.mock_calls[0][1:3]
-        eq_(len(args), 2)
-        update_stmt = args[0]
+        q.update({Data.cnt: Data.cnt + 1}, update_args=update_args)
+
+        update_stmt = m1.mock_calls[0][1][0]
+
         eq_(update_stmt.dialect_kwargs, update_args)
 
 
@@ -1163,18 +1293,22 @@ class InheritTest(fixtures.DeclarativeMappedTest):
         )
         s.commit()
 
-    def test_illegal_metadata(self):
+    @testing.only_on("mysql", "Multi table update")
+    def test_update_from_join_no_problem(self):
         person = self.classes.Person.__table__
         engineer = self.classes.Engineer.__table__
 
         sess = Session()
-        assert_raises_message(
-            exc.InvalidRequestError,
-            "This operation requires only one Table or entity be "
-            "specified as the target.",
-            sess.query(person.join(engineer)).update,
-            {},
+        sess.query(person.join(engineer)).filter(person.c.name == "e2").update(
+            {person.c.name: "updated", engineer.c.engineer_name: "e2a"},
         )
+        obj = sess.execute(
+            future_select(self.classes.Engineer).filter(
+                self.classes.Engineer.name == "updated"
+            )
+        ).scalar()
+        eq_(obj.name, "updated")
+        eq_(obj.engineer_name, "e2a")
 
     def test_update_subtable_only(self):
         Engineer = self.classes.Engineer
index 2d84ab6764cb761a6804db38c965160d277d3fc8..3f74bdbccc3dc2d5b58bb6bd695d8fee0ae1a90d 100644 (file)
@@ -87,6 +87,16 @@ table_a_2_bs = Table(
 
 table_b = Table("b", meta, Column("a", Integer), Column("b", Integer))
 
+table_b_b = Table(
+    "b_b",
+    meta,
+    Column("a", Integer),
+    Column("b", Integer),
+    Column("c", Integer),
+    Column("d", Integer),
+    Column("e", Integer),
+)
+
 table_c = Table("c", meta, Column("x", Integer), Column("y", Integer))
 
 table_d = Table("d", meta, Column("y", Integer), Column("z", Integer))
@@ -711,6 +721,54 @@ class CoreFixtures(object):
 
     fixtures.append(_statements_w_anonymous_col_names)
 
+    def _update_dml_w_dicts():
+        return (
+            table_b_b.update().values(
+                {
+                    table_b_b.c.a: 5,
+                    table_b_b.c.b: 5,
+                    table_b_b.c.c: 5,
+                    table_b_b.c.d: 5,
+                }
+            ),
+            # same contents in a different insertion order, exercising
+            # dict ordering in cache key generation / comparison
+            table_b_b.update().values(
+                {
+                    table_b_b.c.a: 5,
+                    table_b_b.c.c: 5,
+                    table_b_b.c.b: 5,
+                    table_b_b.c.d: 5,
+                }
+            ),
+            table_b_b.update().values(
+                {table_b_b.c.a: 5, table_b_b.c.b: 5, "c": 5, table_b_b.c.d: 5}
+            ),
+            table_b_b.update().values(
+                {
+                    table_b_b.c.a: 5,
+                    table_b_b.c.b: 5,
+                    table_b_b.c.c: 5,
+                    table_b_b.c.d: 5,
+                    table_b_b.c.e: 10,
+                }
+            ),
+            table_b_b.update()
+            .values(
+                {
+                    table_b_b.c.a: 5,
+                    table_b_b.c.b: 5,
+                    table_b_b.c.c: 5,
+                    table_b_b.c.d: 5,
+                    table_b_b.c.e: 10,
+                }
+            )
+            .where(table_b_b.c.c > 10),
+        )
+
+    if util.py37:
+        fixtures.append(_update_dml_w_dicts)
+
 
 class CacheKeyFixture(object):
     def _run_cache_key_fixture(self, fixture, compare_values):
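The fixtures above assert that two .values() dictionaries with the same pairs but different insertion orders yield distinct cache keys on Python 3.7+. A direct illustration, using the private _generate_cache_key() hook as it exists in the 1.4 series:

    import sys

    from sqlalchemy import Column, Integer, MetaData, Table

    meta = MetaData()
    t = Table("b_b", meta, Column("a", Integer), Column("b", Integer))

    s1 = t.update().values({t.c.a: 5, t.c.b: 5})
    s2 = t.update().values({t.c.b: 5, t.c.a: 5})  # same pairs, reordered

    if sys.version_info >= (3, 7):
        # insertion order participates in the key, so these differ;
        # on older interpreters caching is disabled for such statements
        print(s1._generate_cache_key() == s2._generate_cache_key())  # False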