]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
reorg bulk persistence into a separate module
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Aug 2022 14:20:49 +0000 (10:20 -0400)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 11 Aug 2022 15:04:01 +0000 (11:04 -0400)
This restores persistence.py to only functions that are used
by unitofwork.py, and all the "bulk" stuff gets its own
module bulk_persistence.py.  Also fixes up the ORM context
class hierarchy for bulk.

This is all ahead of the ORM-insert changes coming in, so that
the later review can be about logic and not about reorganization.

Change-Id: I035896e9e77fcece866d246edf30097cccad0182

lib/sqlalchemy/orm/bulk_persistence.py [new file with mode: 0644]
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/persistence.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py

diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py
new file mode 100644 (file)
index 0000000..225292d
--- /dev/null
@@ -0,0 +1,1111 @@
+# orm/bulk_persistence.py
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+
+"""additional ORM persistence classes related to "bulk" operations,
+specifically outside of the flush() process.
+
+"""
+
+from __future__ import annotations
+
+from typing import Any
+from typing import Dict
+from typing import Iterable
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
+from . import attributes
+from . import evaluator
+from . import exc as orm_exc
+from . import persistence
+from .base import NO_VALUE
+from .context import AbstractORMCompileState
+from .. import exc as sa_exc
+from .. import sql
+from .. import util
+from ..engine import Dialect
+from ..engine import result as _result
+from ..sql import coercions
+from ..sql import expression
+from ..sql import roles
+from ..sql import select
+from ..sql import sqltypes
+from ..sql.base import _entity_namespace_key
+from ..sql.base import CompileState
+from ..sql.base import Options
+from ..sql.dml import DeleteDMLState
+from ..sql.dml import InsertDMLState
+from ..sql.dml import UpdateDMLState
+from ..util import EMPTY_DICT
+from ..util.typing import Literal
+
+if TYPE_CHECKING:
+    from .mapper import Mapper
+    from .session import ORMExecuteState
+    from .session import SessionTransaction
+    from .state import InstanceState
+
+_O = TypeVar("_O", bound=object)
+
+
+_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
+
+
+def _bulk_insert(
+    mapper: Mapper[_O],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    return_defaults: bool,
+    render_nulls: bool,
+) -> None:
+    base_mapper = mapper.base_mapper
+
+    if session_transaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_insert()"
+        )
+
+    if isstates:
+        if return_defaults:
+            states = [(state, state.dict) for state in mappings]
+            mappings = [dict_ for (state, dict_) in states]
+        else:
+            mappings = [state.dict for state in mappings]
+    else:
+        mappings = list(mappings)
+
+    connection = session_transaction.connection(base_mapper)
+    for table, super_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(super_mapper):
+            continue
+
+        records = (
+            (
+                None,
+                state_dict,
+                params,
+                mapper,
+                connection,
+                value_params,
+                has_all_pks,
+                has_all_defaults,
+            )
+            for (
+                state,
+                state_dict,
+                params,
+                mp,
+                conn,
+                value_params,
+                has_all_pks,
+                has_all_defaults,
+            ) in persistence._collect_insert_commands(
+                table,
+                ((None, mapping, mapper, connection) for mapping in mappings),
+                bulk=True,
+                return_defaults=return_defaults,
+                render_nulls=render_nulls,
+            )
+        )
+        persistence._emit_insert_statements(
+            base_mapper,
+            None,
+            super_mapper,
+            table,
+            records,
+            bookkeeping=return_defaults,
+        )
+
+    if return_defaults and isstates:
+        identity_cls = mapper._identity_class
+        identity_props = [p.key for p in mapper._identity_key_props]
+        for state, dict_ in states:
+            state.key = (
+                identity_cls,
+                tuple([dict_[key] for key in identity_props]),
+            )
+
+
+def _bulk_update(
+    mapper: Mapper[Any],
+    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
+    session_transaction: SessionTransaction,
+    isstates: bool,
+    update_changed_only: bool,
+) -> None:
+    base_mapper = mapper.base_mapper
+
+    search_keys = mapper._primary_key_propkeys
+    if mapper._version_id_prop:
+        search_keys = {mapper._version_id_prop.key}.union(search_keys)
+
+    def _changed_dict(mapper, state):
+        return dict(
+            (k, v)
+            for k, v in state.dict.items()
+            if k in state.committed_state or k in search_keys
+        )
+
+    if isstates:
+        if update_changed_only:
+            mappings = [_changed_dict(mapper, state) for state in mappings]
+        else:
+            mappings = [state.dict for state in mappings]
+    else:
+        mappings = list(mappings)
+
+    if session_transaction.session.connection_callable:
+        raise NotImplementedError(
+            "connection_callable / per-instance sharding "
+            "not supported in bulk_update()"
+        )
+
+    connection = session_transaction.connection(base_mapper)
+
+    for table, super_mapper in base_mapper._sorted_tables.items():
+        if not mapper.isa(super_mapper):
+            continue
+
+        records = persistence._collect_update_commands(
+            None,
+            table,
+            (
+                (
+                    None,
+                    mapping,
+                    mapper,
+                    connection,
+                    (
+                        mapping[mapper._version_id_prop.key]
+                        if mapper._version_id_prop
+                        else None
+                    ),
+                )
+                for mapping in mappings
+            ),
+            bulk=True,
+        )
+
+        persistence._emit_update_statements(
+            base_mapper,
+            None,
+            super_mapper,
+            table,
+            records,
+            bookkeeping=False,
+        )
+
+
+class ORMDMLState(AbstractORMCompileState):
+    @classmethod
+    def get_entity_description(cls, statement):
+        ext_info = statement.table._annotations["parententity"]
+        mapper = ext_info.mapper
+        if ext_info.is_aliased_class:
+            _label_name = ext_info.name
+        else:
+            _label_name = mapper.class_.__name__
+
+        return {
+            "name": _label_name,
+            "type": mapper.class_,
+            "expr": ext_info.entity,
+            "entity": ext_info.entity,
+            "table": mapper.local_table,
+        }
+
+    @classmethod
+    def get_returning_column_descriptions(cls, statement):
+        def _ent_for_col(c):
+            return c._annotations.get("parententity", None)
+
+        def _attr_for_col(c, ent):
+            if ent is None:
+                return c
+            proxy_key = c._annotations.get("proxy_key", None)
+            if not proxy_key:
+                return c
+            else:
+                return getattr(ent.entity, proxy_key, c)
+
+        return [
+            {
+                "name": c.key,
+                "type": c.type,
+                "expr": _attr_for_col(c, ent),
+                "aliased": ent.is_aliased_class,
+                "entity": ent.entity,
+            }
+            for c, ent in [
+                (c, _ent_for_col(c)) for c in statement._all_selected_columns
+            ]
+        ]
+
+
+class BulkUDCompileState(ORMDMLState):
+    class default_update_options(Options):
+        _synchronize_session: _SynchronizeSessionArgument = "evaluate"
+        _is_delete_using = False
+        _is_update_from = False
+        _autoflush = True
+        _subject_mapper = None
+        _resolved_values = EMPTY_DICT
+        _resolved_keys_as_propnames = EMPTY_DICT
+        _value_evaluators = EMPTY_DICT
+        _matched_objects = None
+        _matched_rows = None
+        _refresh_identity_token = None
+
+    @classmethod
+    def can_use_returning(
+        cls,
+        dialect: Dialect,
+        mapper: Mapper[Any],
+        *,
+        is_multitable: bool = False,
+        is_update_from: bool = False,
+        is_delete_using: bool = False,
+    ) -> bool:
+        raise NotImplementedError()
+
+    @classmethod
+    def orm_pre_session_exec(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        is_reentrant_invoke,
+    ):
+        if is_reentrant_invoke:
+            return statement, execution_options
+
+        (
+            update_options,
+            execution_options,
+        ) = BulkUDCompileState.default_update_options.from_execution_options(
+            "_sa_orm_update_options",
+            {"synchronize_session", "is_delete_using", "is_update_from"},
+            execution_options,
+            statement._execution_options,
+        )
+
+        sync = update_options._synchronize_session
+        if sync is not None:
+            if sync not in ("evaluate", "fetch", False):
+                raise sa_exc.ArgumentError(
+                    "Valid strategies for session synchronization "
+                    "are 'evaluate', 'fetch', False"
+                )
+
+        bind_arguments["clause"] = statement
+        try:
+            plugin_subject = statement._propagate_attrs["plugin_subject"]
+        except KeyError:
+            assert False, "statement had 'orm' plugin but no plugin_subject"
+        else:
+            bind_arguments["mapper"] = plugin_subject.mapper
+
+        update_options += {"_subject_mapper": plugin_subject.mapper}
+
+        if update_options._autoflush:
+            session._autoflush()
+
+        statement = statement._annotate(
+            {
+                "synchronize_session": update_options._synchronize_session,
+                "is_delete_using": update_options._is_delete_using,
+                "is_update_from": update_options._is_update_from,
+            }
+        )
+
+        # this stage of the execution is called before the do_orm_execute event
+        # hook.  meaning for an extension like horizontal sharding, this step
+        # happens before the extension splits out into multiple backends and
+        # runs only once.  if we do pre_sync_fetch, we execute a SELECT
+        # statement, which the horizontal sharding extension splits amongst the
+        # shards and combines the results together.
+
+        if update_options._synchronize_session == "evaluate":
+            update_options = cls._do_pre_synchronize_evaluate(
+                session,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                update_options,
+            )
+        elif update_options._synchronize_session == "fetch":
+            update_options = cls._do_pre_synchronize_fetch(
+                session,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                update_options,
+            )
+
+        return (
+            statement,
+            util.immutabledict(execution_options).union(
+                {"_sa_orm_update_options": update_options}
+            ),
+        )
+
+    @classmethod
+    def orm_setup_cursor_result(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        result,
+    ):
+
+        # this stage of the execution is called after the
+        # do_orm_execute event hook.  meaning for an extension like
+        # horizontal sharding, this step happens *within* the horizontal
+        # sharding event handler which calls session.execute() re-entrantly
+        # and will occur for each backend individually.
+        # the sharding extension then returns its own merged result from the
+        # individual ones we return here.
+
+        update_options = execution_options["_sa_orm_update_options"]
+        if update_options._synchronize_session == "evaluate":
+            cls._do_post_synchronize_evaluate(session, result, update_options)
+        elif update_options._synchronize_session == "fetch":
+            cls._do_post_synchronize_fetch(session, result, update_options)
+
+        return result
+
+    @classmethod
+    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
+        """Apply extra criteria filtering.
+
+        For all distinct single-table-inheritance mappers represented in the
+        table being updated or deleted, produce additional WHERE criteria such
+        that only the appropriate subtypes are selected from the total results.
+
+        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
+        collected from the statement.
+
+        """
+
+        return_crit = ()
+
+        adapter = ext_info._adapter if ext_info.is_aliased_class else None
+
+        if (
+            "additional_entity_criteria",
+            ext_info.mapper,
+        ) in global_attributes:
+            return_crit += tuple(
+                ae._resolve_where_criteria(ext_info)
+                for ae in global_attributes[
+                    ("additional_entity_criteria", ext_info.mapper)
+                ]
+                if ae.include_aliases or ae.entity is ext_info
+            )
+
+        if ext_info.mapper._single_table_criterion is not None:
+            return_crit += (ext_info.mapper._single_table_criterion,)
+
+        if adapter:
+            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
+
+        return return_crit
+
+    @classmethod
+    def _interpret_returning_rows(cls, mapper, rows):
+        """translate from local inherited table columns to base mapper
+        primary key columns.
+
+        Joined inheritance mappers always establish the primary key in terms of
+        the base table.   When we UPDATE a sub-table, we can only get
+        RETURNING for the sub-table's columns.
+
+        Here, we create a lookup from the local sub table's primary key
+        columns to the base table PK columns so that we can get identity
+        key values from RETURNING that's against the joined inheritance
+        sub-table.
+
+        the complexity here is to support more than one level deep of
+        inheritance, where we have to link columns to each other across
+        the inheritance hierarchy.
+
+        """
+
+        if mapper.local_table is not mapper.base_mapper.local_table:
+            return rows
+
+        # this starts as a mapping of
+        # local_pk_col: local_pk_col.
+        # we will then iteratively rewrite the "value" of the dict with
+        # each successive superclass column
+        local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}
+
+        for mp in mapper.iterate_to_root():
+            if mp.inherits is None:
+                break
+            elif mp.local_table is mp.inherits.local_table:
+                continue
+
+            t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
+            col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
+            for pk, super_ in local_pk_to_base_pk.items():
+                local_pk_to_base_pk[pk] = col_to_col[super_]
+
+        lookup = {
+            local_pk_to_base_pk[lpk]: idx
+            for idx, lpk in enumerate(mapper.local_table.primary_key)
+        }
+        primary_key_convert = [
+            lookup[bpk] for bpk in mapper.base_mapper.primary_key
+        ]
+
+        return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
+
+    @classmethod
+    def _do_pre_synchronize_evaluate(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        update_options,
+    ):
+        mapper = update_options._subject_mapper
+        target_cls = mapper.class_
+
+        value_evaluators = resolved_keys_as_propnames = EMPTY_DICT
+
+        try:
+            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+            crit = ()
+            if statement._where_criteria:
+                crit += statement._where_criteria
+
+            global_attributes = {}
+            for opt in statement._with_options:
+                if opt._is_criteria_option:
+                    opt.get_global_criteria(global_attributes)
+
+            if global_attributes:
+                crit += cls._adjust_for_extra_criteria(
+                    global_attributes, mapper
+                )
+
+            if crit:
+                eval_condition = evaluator_compiler.process(*crit)
+            else:
+
+                def eval_condition(obj):
+                    return True
+
+        except evaluator.UnevaluatableError as err:
+            raise sa_exc.InvalidRequestError(
+                'Could not evaluate current criteria in Python: "%s". '
+                "Specify 'fetch' or False for the "
+                "synchronize_session execution option." % err
+            ) from err
+
+        if statement.__visit_name__ == "lambda_element":
+            # ._resolved is called on every LambdaElement in order to
+            # generate the cache key, so this access does not add
+            # additional expense
+            effective_statement = statement._resolved
+        else:
+            effective_statement = statement
+
+        if effective_statement.__visit_name__ == "update":
+            resolved_values = cls._get_resolved_values(
+                mapper, effective_statement
+            )
+            value_evaluators = {}
+            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+                mapper, resolved_values
+            )
+            for key, value in resolved_keys_as_propnames:
+                try:
+                    _evaluator = evaluator_compiler.process(
+                        coercions.expect(roles.ExpressionElementRole, value)
+                    )
+                except evaluator.UnevaluatableError:
+                    pass
+                else:
+                    value_evaluators[key] = _evaluator
+
+        # TODO: detect when the where clause is a trivial primary key match.
+        matched_objects = [
+            state.obj()
+            for state in session.identity_map.all_states()
+            if state.mapper.isa(mapper)
+            and not state.expired
+            and eval_condition(state.obj())
+            and (
+                update_options._refresh_identity_token is None
+                # TODO: coverage for the case where horizontal sharding
+                # invokes an update() or delete() given an explicit identity
+                # token up front
+                or state.identity_token
+                == update_options._refresh_identity_token
+            )
+        ]
+        return update_options + {
+            "_matched_objects": matched_objects,
+            "_value_evaluators": value_evaluators,
+            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+        }
+
+    @classmethod
+    def _get_resolved_values(cls, mapper, statement):
+        if statement._multi_values:
+            return []
+        elif statement._ordered_values:
+            return list(statement._ordered_values)
+        elif statement._values:
+            return list(statement._values.items())
+        else:
+            return []
+
+    @classmethod
+    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
+        values = []
+        for k, v in resolved_values:
+            if isinstance(k, attributes.QueryableAttribute):
+                values.append((k.key, v))
+                continue
+            elif hasattr(k, "__clause_element__"):
+                k = k.__clause_element__()
+
+            if mapper and isinstance(k, expression.ColumnElement):
+                try:
+                    attr = mapper._columntoproperty[k]
+                except orm_exc.UnmappedColumnError:
+                    pass
+                else:
+                    values.append((attr.key, v))
+            else:
+                raise sa_exc.InvalidRequestError(
+                    "Invalid expression type: %r" % k
+                )
+        return values
+
+    @classmethod
+    def _do_pre_synchronize_fetch(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        update_options,
+    ):
+        mapper = update_options._subject_mapper
+
+        select_stmt = (
+            select(*(mapper.primary_key + (mapper.select_identity_token,)))
+            .select_from(mapper)
+            .options(*statement._with_options)
+        )
+        select_stmt._where_criteria = statement._where_criteria
+
+        def skip_for_returning(orm_context: ORMExecuteState) -> Any:
+            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
+            if cls.can_use_returning(
+                bind.dialect,
+                mapper,
+                is_update_from=update_options._is_update_from,
+                is_delete_using=update_options._is_delete_using,
+            ):
+                return _result.null_result()
+            else:
+                return None
+
+        result = session.execute(
+            select_stmt,
+            params,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
+            _add_event=skip_for_returning,
+        )
+        matched_rows = result.fetchall()
+
+        value_evaluators = EMPTY_DICT
+
+        if statement.__visit_name__ == "lambda_element":
+            # ._resolved is called on every LambdaElement in order to
+            # generate the cache key, so this access does not add
+            # additional expense
+            effective_statement = statement._resolved
+        else:
+            effective_statement = statement
+
+        if effective_statement.__visit_name__ == "update":
+            target_cls = mapper.class_
+            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
+            resolved_values = cls._get_resolved_values(
+                mapper, effective_statement
+            )
+            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+                mapper, resolved_values
+            )
+
+            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
+                mapper, resolved_values
+            )
+            value_evaluators = {}
+            for key, value in resolved_keys_as_propnames:
+                try:
+                    _evaluator = evaluator_compiler.process(
+                        coercions.expect(roles.ExpressionElementRole, value)
+                    )
+                except evaluator.UnevaluatableError:
+                    pass
+                else:
+                    value_evaluators[key] = _evaluator
+
+        else:
+            resolved_keys_as_propnames = EMPTY_DICT
+
+        return update_options + {
+            "_value_evaluators": value_evaluators,
+            "_matched_rows": matched_rows,
+            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
+        }
+
+
+@CompileState.plugin_for("orm", "insert")
+class ORMInsert(ORMDMLState, InsertDMLState):
+    @classmethod
+    def orm_pre_session_exec(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        is_reentrant_invoke,
+    ):
+        bind_arguments["clause"] = statement
+        try:
+            plugin_subject = statement._propagate_attrs["plugin_subject"]
+        except KeyError:
+            assert False, "statement had 'orm' plugin but no plugin_subject"
+        else:
+            bind_arguments["mapper"] = plugin_subject.mapper
+
+        return (
+            statement,
+            util.immutabledict(execution_options),
+        )
+
+    @classmethod
+    def orm_setup_cursor_result(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        result,
+    ):
+        return result
+
+
+@CompileState.plugin_for("orm", "update")
+class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
+    @classmethod
+    def create_for_statement(cls, statement, compiler, **kw):
+
+        self = cls.__new__(cls)
+
+        ext_info = statement.table._annotations["parententity"]
+
+        self.mapper = mapper = ext_info.mapper
+
+        self.extra_criteria_entities = {}
+
+        self._resolved_values = cls._get_resolved_values(mapper, statement)
+
+        extra_criteria_attributes = {}
+
+        for opt in statement._with_options:
+            if opt._is_criteria_option:
+                opt.get_global_criteria(extra_criteria_attributes)
+
+        if statement._values:
+            self._resolved_values = dict(self._resolved_values)
+
+        new_stmt = sql.Update.__new__(sql.Update)
+        new_stmt.__dict__.update(statement.__dict__)
+        new_stmt.table = mapper.local_table
+
+        # note if the statement has _multi_values, these
+        # are passed through to the new statement, which will then raise
+        # InvalidRequestError because UPDATE doesn't support multi_values
+        # right now.
+        if statement._ordered_values:
+            new_stmt._ordered_values = self._resolved_values
+        elif statement._values:
+            new_stmt._values = self._resolved_values
+
+        new_crit = cls._adjust_for_extra_criteria(
+            extra_criteria_attributes, mapper
+        )
+        if new_crit:
+            new_stmt = new_stmt.where(*new_crit)
+
+        # if we are against a lambda statement we might not be the
+        # topmost object that received per-execute annotations
+
+        # do this first as we need to determine if there is
+        # UPDATE..FROM
+
+        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
+
+        if compiler._annotations.get(
+            "synchronize_session", None
+        ) == "fetch" and self.can_use_returning(
+            compiler.dialect, mapper, is_multitable=self.is_multitable
+        ):
+            if new_stmt._returning:
+                raise sa_exc.InvalidRequestError(
+                    "Can't use synchronize_session='fetch' "
+                    "with explicit returning()"
+                )
+            self.statement = self.statement.returning(
+                *mapper.local_table.primary_key
+            )
+
+        return self
+
+    @classmethod
+    def can_use_returning(
+        cls,
+        dialect: Dialect,
+        mapper: Mapper[Any],
+        *,
+        is_multitable: bool = False,
+        is_update_from: bool = False,
+        is_delete_using: bool = False,
+    ) -> bool:
+
+        # normal answer for "should we use RETURNING" at all.
+        normal_answer = (
+            dialect.update_returning and mapper.local_table.implicit_returning
+        )
+        if not normal_answer:
+            return False
+
+        # these workarounds are currently hypothetical for UPDATE,
+        # unlike DELETE where they impact MariaDB
+        if is_update_from:
+            return dialect.update_returning_multifrom
+
+        elif is_multitable and not dialect.update_returning_multifrom:
+            raise sa_exc.CompileError(
+                f'Dialect "{dialect.name}" does not support RETURNING '
+                "with UPDATE..FROM; for synchronize_session='fetch', "
+                "please add the additional execution option "
+                "'is_update_from=True' to the statement to indicate that "
+                "a separate SELECT should be used for this backend."
+            )
+
+        return True
+
+    @classmethod
+    def _get_crud_kv_pairs(cls, statement, kv_iterator):
+        plugin_subject = statement._propagate_attrs["plugin_subject"]
+
+        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
+
+        if not plugin_subject or not plugin_subject.mapper:
+            return core_get_crud_kv_pairs(statement, kv_iterator)
+
+        mapper = plugin_subject.mapper
+
+        values = []
+
+        for k, v in kv_iterator:
+            k = coercions.expect(roles.DMLColumnRole, k)
+
+            if isinstance(k, str):
+                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
+                if desc is NO_VALUE:
+                    values.append(
+                        (
+                            k,
+                            coercions.expect(
+                                roles.ExpressionElementRole,
+                                v,
+                                type_=sqltypes.NullType(),
+                                is_crud=True,
+                            ),
+                        )
+                    )
+                else:
+                    values.extend(
+                        core_get_crud_kv_pairs(
+                            statement, desc._bulk_update_tuples(v)
+                        )
+                    )
+            elif "entity_namespace" in k._annotations:
+                k_anno = k._annotations
+                attr = _entity_namespace_key(
+                    k_anno["entity_namespace"], k_anno["proxy_key"]
+                )
+                values.extend(
+                    core_get_crud_kv_pairs(
+                        statement, attr._bulk_update_tuples(v)
+                    )
+                )
+            else:
+                values.append(
+                    (
+                        k,
+                        coercions.expect(
+                            roles.ExpressionElementRole,
+                            v,
+                            type_=sqltypes.NullType(),
+                            is_crud=True,
+                        ),
+                    )
+                )
+        return values
+
+    @classmethod
+    def _do_post_synchronize_evaluate(cls, session, result, update_options):
+
+        states = set()
+        evaluated_keys = list(update_options._value_evaluators.keys())
+        values = update_options._resolved_keys_as_propnames
+        attrib = set(k for k, v in values)
+        for obj in update_options._matched_objects:
+
+            state, dict_ = (
+                attributes.instance_state(obj),
+                attributes.instance_dict(obj),
+            )
+
+            # the evaluated states were gathered across all identity tokens.
+            # however the post_sync events are called per identity token,
+            # so filter.
+            if (
+                update_options._refresh_identity_token is not None
+                and state.identity_token
+                != update_options._refresh_identity_token
+            ):
+                continue
+
+            # only evaluate unmodified attributes
+            to_evaluate = state.unmodified.intersection(evaluated_keys)
+            for key in to_evaluate:
+                if key in dict_:
+                    dict_[key] = update_options._value_evaluators[key](obj)
+
+            state.manager.dispatch.refresh(state, None, to_evaluate)
+
+            state._commit(dict_, list(to_evaluate))
+
+            to_expire = attrib.intersection(dict_).difference(to_evaluate)
+            if to_expire:
+                state._expire_attributes(dict_, to_expire)
+
+            states.add(state)
+        session._register_altered(states)
+
+    @classmethod
+    def _do_post_synchronize_fetch(cls, session, result, update_options):
+        target_mapper = update_options._subject_mapper
+
+        states = set()
+        evaluated_keys = list(update_options._value_evaluators.keys())
+
+        if result.returns_rows:
+            rows = cls._interpret_returning_rows(target_mapper, result.all())
+
+            matched_rows = [
+                tuple(row) + (update_options._refresh_identity_token,)
+                for row in rows
+            ]
+        else:
+            matched_rows = update_options._matched_rows
+
+        objs = [
+            session.identity_map[identity_key]
+            for identity_key in [
+                target_mapper.identity_key_from_primary_key(
+                    list(primary_key),
+                    identity_token=identity_token,
+                )
+                for primary_key, identity_token in [
+                    (row[0:-1], row[-1]) for row in matched_rows
+                ]
+                if update_options._refresh_identity_token is None
+                or identity_token == update_options._refresh_identity_token
+            ]
+            if identity_key in session.identity_map
+        ]
+
+        values = update_options._resolved_keys_as_propnames
+        attrib = set(k for k, v in values)
+
+        for obj in objs:
+            state, dict_ = (
+                attributes.instance_state(obj),
+                attributes.instance_dict(obj),
+            )
+
+            to_evaluate = state.unmodified.intersection(evaluated_keys)
+            for key in to_evaluate:
+                if key in dict_:
+                    dict_[key] = update_options._value_evaluators[key](obj)
+            state.manager.dispatch.refresh(state, None, to_evaluate)
+
+            state._commit(dict_, list(to_evaluate))
+
+            to_expire = attrib.intersection(dict_).difference(to_evaluate)
+            if to_expire:
+                state._expire_attributes(dict_, to_expire)
+
+            states.add(state)
+        session._register_altered(states)
+
+
+@CompileState.plugin_for("orm", "delete")
+class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
+    @classmethod
+    def create_for_statement(cls, statement, compiler, **kw):
+        self = cls.__new__(cls)
+
+        ext_info = statement.table._annotations["parententity"]
+        self.mapper = mapper = ext_info.mapper
+
+        self.extra_criteria_entities = {}
+
+        extra_criteria_attributes = {}
+
+        for opt in statement._with_options:
+            if opt._is_criteria_option:
+                opt.get_global_criteria(extra_criteria_attributes)
+
+        new_crit = cls._adjust_for_extra_criteria(
+            extra_criteria_attributes, mapper
+        )
+        if new_crit:
+            statement = statement.where(*new_crit)
+
+        # do this first as we need to determine if there is
+        # DELETE..FROM
+        DeleteDMLState.__init__(self, statement, compiler, **kw)
+
+        if compiler._annotations.get(
+            "synchronize_session", None
+        ) == "fetch" and self.can_use_returning(
+            compiler.dialect,
+            mapper,
+            is_multitable=self.is_multitable,
+            is_delete_using=compiler._annotations.get(
+                "is_delete_using", False
+            ),
+        ):
+            self.statement = statement.returning(*statement.table.primary_key)
+
+        return self
+
+    @classmethod
+    def can_use_returning(
+        cls,
+        dialect: Dialect,
+        mapper: Mapper[Any],
+        *,
+        is_multitable: bool = False,
+        is_update_from: bool = False,
+        is_delete_using: bool = False,
+    ) -> bool:
+
+        # normal answer for "should we use RETURNING" at all.
+        normal_answer = (
+            dialect.delete_returning and mapper.local_table.implicit_returning
+        )
+        if not normal_answer:
+            return False
+
+        # now get into special workarounds because MariaDB supports
+        # DELETE...RETURNING but not DELETE...USING...RETURNING.
+        if is_delete_using:
+            # is_delete_using hint was passed.   use
+            # additional dialect feature (True for PG, False for MariaDB)
+            return dialect.delete_returning_multifrom
+
+        elif is_multitable and not dialect.delete_returning_multifrom:
+            # is_delete_using hint was not passed, but we determined
+            # at compile time that this is in fact a DELETE..USING.
+            # it's too late to continue since we did not pre-SELECT.
+            # raise that we need that hint up front.
+
+            raise sa_exc.CompileError(
+                f'Dialect "{dialect.name}" does not support RETURNING '
+                "with DELETE..USING; for synchronize_session='fetch', "
+                "please add the additional execution option "
+                "'is_delete_using=True' to the statement to indicate that "
+                "a separate SELECT should be used for this backend."
+            )
+
+        return True
+
+    @classmethod
+    def _do_post_synchronize_evaluate(cls, session, result, update_options):
+
+        session._remove_newly_deleted(
+            [
+                attributes.instance_state(obj)
+                for obj in update_options._matched_objects
+            ]
+        )
+
+    @classmethod
+    def _do_post_synchronize_fetch(cls, session, result, update_options):
+        target_mapper = update_options._subject_mapper
+
+        if result.returns_rows:
+            rows = cls._interpret_returning_rows(target_mapper, result.all())
+
+            matched_rows = [
+                tuple(row) + (update_options._refresh_identity_token,)
+                for row in rows
+            ]
+        else:
+            matched_rows = update_options._matched_rows
+
+        for row in matched_rows:
+            primary_key = row[0:-1]
+            identity_token = row[-1]
+
+            # TODO: inline this and call remove_newly_deleted
+            # once
+            identity_key = target_mapper.identity_key_from_primary_key(
+                list(primary_key),
+                identity_token=identity_token,
+            )
+            if identity_key in session.identity_map:
+                session._remove_newly_deleted(
+                    [
+                        attributes.instance_state(
+                            session.identity_map[identity_key]
+                        )
+                    ]
+                )
index 20a03e9b47b289449c89f0aa1bbaa5a50aaab5ca..f8ea231395d7a83ec51e993b27170f7252740068 100644 (file)
@@ -206,7 +206,50 @@ _orm_load_exec_options = util.immutabledict(
 )
 
 
-class ORMCompileState(CompileState):
+class AbstractORMCompileState(CompileState):
+    @classmethod
+    def create_for_statement(
+        cls,
+        statement: Union[Select, FromStatement],
+        compiler: Optional[SQLCompiler],
+        **kw: Any,
+    ) -> ORMCompileState:
+        """Create a context for a statement given a :class:`.Compiler`.
+        This method is always invoked in the context of SQLCompiler.process().
+        For a Select object, this would be invoked from
+        SQLCompiler.visit_select(). For the special FromStatement object used
+        by Query to indicate "Query.from_statement()", this is called by
+        FromStatement._compiler_dispatch() that would be called by
+        SQLCompiler.process().
+        """
+        return super().create_for_statement(statement, compiler, **kw)
+
+    @classmethod
+    def orm_pre_session_exec(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        is_reentrant_invoke,
+    ):
+        raise NotImplementedError()
+
+    @classmethod
+    def orm_setup_cursor_result(
+        cls,
+        session,
+        statement,
+        params,
+        execution_options,
+        bind_arguments,
+        result,
+    ):
+        raise NotImplementedError()
+
+
+class ORMCompileState(AbstractORMCompileState):
     class default_compile_options(CacheableOptions):
         _cache_key_traversal = [
             ("_use_legacy_query_style", InternalTraversal.dp_boolean),
index 59a0a3d81d4a8b8476780229a9807f9141fb9f06..6310f5b1b622a2bca8557f55fef883a4764cc483 100644 (file)
@@ -21,199 +21,19 @@ from itertools import chain
 from itertools import groupby
 from itertools import zip_longest
 import operator
-from typing import Any
-from typing import Dict
-from typing import Iterable
-from typing import TYPE_CHECKING
-from typing import TypeVar
-from typing import Union
 
 from . import attributes
-from . import evaluator
 from . import exc as orm_exc
 from . import loading
 from . import sync
-from .base import NO_VALUE
 from .base import state_str
 from .. import exc as sa_exc
 from .. import future
 from .. import sql
 from .. import util
-from ..engine import Dialect
-from ..engine import result as _result
-from ..sql import coercions
-from ..sql import expression
 from ..sql import operators
-from ..sql import roles
-from ..sql import select
-from ..sql import sqltypes
-from ..sql.base import _entity_namespace_key
-from ..sql.base import CompileState
-from ..sql.base import Options
-from ..sql.dml import DeleteDMLState
-from ..sql.dml import InsertDMLState
-from ..sql.dml import UpdateDMLState
 from ..sql.elements import BooleanClauseList
 from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
-from ..util.typing import Literal
-
-if TYPE_CHECKING:
-    from .mapper import Mapper
-    from .session import ORMExecuteState
-    from .session import SessionTransaction
-    from .state import InstanceState
-
-_O = TypeVar("_O", bound=object)
-
-
-_SynchronizeSessionArgument = Literal[False, "evaluate", "fetch"]
-
-
-def _bulk_insert(
-    mapper: Mapper[_O],
-    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
-    session_transaction: SessionTransaction,
-    isstates: bool,
-    return_defaults: bool,
-    render_nulls: bool,
-) -> None:
-    base_mapper = mapper.base_mapper
-
-    if session_transaction.session.connection_callable:
-        raise NotImplementedError(
-            "connection_callable / per-instance sharding "
-            "not supported in bulk_insert()"
-        )
-
-    if isstates:
-        if return_defaults:
-            states = [(state, state.dict) for state in mappings]
-            mappings = [dict_ for (state, dict_) in states]
-        else:
-            mappings = [state.dict for state in mappings]
-    else:
-        mappings = list(mappings)
-
-    connection = session_transaction.connection(base_mapper)
-    for table, super_mapper in base_mapper._sorted_tables.items():
-        if not mapper.isa(super_mapper):
-            continue
-
-        records = (
-            (
-                None,
-                state_dict,
-                params,
-                mapper,
-                connection,
-                value_params,
-                has_all_pks,
-                has_all_defaults,
-            )
-            for (
-                state,
-                state_dict,
-                params,
-                mp,
-                conn,
-                value_params,
-                has_all_pks,
-                has_all_defaults,
-            ) in _collect_insert_commands(
-                table,
-                ((None, mapping, mapper, connection) for mapping in mappings),
-                bulk=True,
-                return_defaults=return_defaults,
-                render_nulls=render_nulls,
-            )
-        )
-        _emit_insert_statements(
-            base_mapper,
-            None,
-            super_mapper,
-            table,
-            records,
-            bookkeeping=return_defaults,
-        )
-
-    if return_defaults and isstates:
-        identity_cls = mapper._identity_class
-        identity_props = [p.key for p in mapper._identity_key_props]
-        for state, dict_ in states:
-            state.key = (
-                identity_cls,
-                tuple([dict_[key] for key in identity_props]),
-            )
-
-
-def _bulk_update(
-    mapper: Mapper[Any],
-    mappings: Union[Iterable[InstanceState[_O]], Iterable[Dict[str, Any]]],
-    session_transaction: SessionTransaction,
-    isstates: bool,
-    update_changed_only: bool,
-) -> None:
-    base_mapper = mapper.base_mapper
-
-    search_keys = mapper._primary_key_propkeys
-    if mapper._version_id_prop:
-        search_keys = {mapper._version_id_prop.key}.union(search_keys)
-
-    def _changed_dict(mapper, state):
-        return dict(
-            (k, v)
-            for k, v in state.dict.items()
-            if k in state.committed_state or k in search_keys
-        )
-
-    if isstates:
-        if update_changed_only:
-            mappings = [_changed_dict(mapper, state) for state in mappings]
-        else:
-            mappings = [state.dict for state in mappings]
-    else:
-        mappings = list(mappings)
-
-    if session_transaction.session.connection_callable:
-        raise NotImplementedError(
-            "connection_callable / per-instance sharding "
-            "not supported in bulk_update()"
-        )
-
-    connection = session_transaction.connection(base_mapper)
-
-    for table, super_mapper in base_mapper._sorted_tables.items():
-        if not mapper.isa(super_mapper):
-            continue
-
-        records = _collect_update_commands(
-            None,
-            table,
-            (
-                (
-                    None,
-                    mapping,
-                    mapper,
-                    connection,
-                    (
-                        mapping[mapper._version_id_prop.key]
-                        if mapper._version_id_prop
-                        else None
-                    ),
-                )
-                for mapping in mappings
-            ),
-            bulk=True,
-        )
-
-        _emit_update_statements(
-            base_mapper,
-            None,
-            super_mapper,
-            table,
-            records,
-            bookkeeping=False,
-        )
 
 
 def save_obj(base_mapper, states, uowtransaction, single=False):
@@ -1797,912 +1617,3 @@ def _sort_states(mapper, states):
         sorted(pending, key=operator.attrgetter("insert_order"))
         + persistent_sorted
     )
-
-
-_EMPTY_DICT = util.immutabledict()
-
-
-class BulkUDCompileState(CompileState):
-    class default_update_options(Options):
-        _synchronize_session = "evaluate"
-        _is_delete_using = False
-        _is_update_from = False
-        _autoflush = True
-        _subject_mapper = None
-        _resolved_values = _EMPTY_DICT
-        _resolved_keys_as_propnames = _EMPTY_DICT
-        _value_evaluators = _EMPTY_DICT
-        _matched_objects = None
-        _matched_rows = None
-        _refresh_identity_token = None
-
-    @classmethod
-    def can_use_returning(
-        cls,
-        dialect: Dialect,
-        mapper: Mapper[Any],
-        *,
-        is_multitable: bool = False,
-        is_update_from: bool = False,
-        is_delete_using: bool = False,
-    ) -> bool:
-        raise NotImplementedError()
-
-    @classmethod
-    def orm_pre_session_exec(
-        cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        is_reentrant_invoke,
-    ):
-        if is_reentrant_invoke:
-            return statement, execution_options
-
-        (
-            update_options,
-            execution_options,
-        ) = BulkUDCompileState.default_update_options.from_execution_options(
-            "_sa_orm_update_options",
-            {"synchronize_session", "is_delete_using", "is_update_from"},
-            execution_options,
-            statement._execution_options,
-        )
-
-        sync = update_options._synchronize_session
-        if sync is not None:
-            if sync not in ("evaluate", "fetch", False):
-                raise sa_exc.ArgumentError(
-                    "Valid strategies for session synchronization "
-                    "are 'evaluate', 'fetch', False"
-                )
-
-        bind_arguments["clause"] = statement
-        try:
-            plugin_subject = statement._propagate_attrs["plugin_subject"]
-        except KeyError:
-            assert False, "statement had 'orm' plugin but no plugin_subject"
-        else:
-            bind_arguments["mapper"] = plugin_subject.mapper
-
-        update_options += {"_subject_mapper": plugin_subject.mapper}
-
-        if update_options._autoflush:
-            session._autoflush()
-
-        statement = statement._annotate(
-            {
-                "synchronize_session": update_options._synchronize_session,
-                "is_delete_using": update_options._is_delete_using,
-                "is_update_from": update_options._is_update_from,
-            }
-        )
-
-        # this stage of the execution is called before the do_orm_execute event
-        # hook.  meaning for an extension like horizontal sharding, this step
-        # happens before the extension splits out into multiple backends and
-        # runs only once.  if we do pre_sync_fetch, we execute a SELECT
-        # statement, which the horizontal sharding extension splits amongst the
-        # shards and combines the results together.
-
-        if update_options._synchronize_session == "evaluate":
-            update_options = cls._do_pre_synchronize_evaluate(
-                session,
-                statement,
-                params,
-                execution_options,
-                bind_arguments,
-                update_options,
-            )
-        elif update_options._synchronize_session == "fetch":
-            update_options = cls._do_pre_synchronize_fetch(
-                session,
-                statement,
-                params,
-                execution_options,
-                bind_arguments,
-                update_options,
-            )
-
-        return (
-            statement,
-            util.immutabledict(execution_options).union(
-                {"_sa_orm_update_options": update_options}
-            ),
-        )
-
-    @classmethod
-    def orm_setup_cursor_result(
-        cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        result,
-    ):
-
-        # this stage of the execution is called after the
-        # do_orm_execute event hook.  meaning for an extension like
-        # horizontal sharding, this step happens *within* the horizontal
-        # sharding event handler which calls session.execute() re-entrantly
-        # and will occur for each backend individually.
-        # the sharding extension then returns its own merged result from the
-        # individual ones we return here.
-
-        update_options = execution_options["_sa_orm_update_options"]
-        if update_options._synchronize_session == "evaluate":
-            cls._do_post_synchronize_evaluate(session, result, update_options)
-        elif update_options._synchronize_session == "fetch":
-            cls._do_post_synchronize_fetch(session, result, update_options)
-
-        return result
-
-    @classmethod
-    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
-        """Apply extra criteria filtering.
-
-        For all distinct single-table-inheritance mappers represented in the
-        table being updated or deleted, produce additional WHERE criteria such
-        that only the appropriate subtypes are selected from the total results.
-
-        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
-        collected from the statement.
-
-        """
-
-        return_crit = ()
-
-        adapter = ext_info._adapter if ext_info.is_aliased_class else None
-
-        if (
-            "additional_entity_criteria",
-            ext_info.mapper,
-        ) in global_attributes:
-            return_crit += tuple(
-                ae._resolve_where_criteria(ext_info)
-                for ae in global_attributes[
-                    ("additional_entity_criteria", ext_info.mapper)
-                ]
-                if ae.include_aliases or ae.entity is ext_info
-            )
-
-        if ext_info.mapper._single_table_criterion is not None:
-            return_crit += (ext_info.mapper._single_table_criterion,)
-
-        if adapter:
-            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)
-
-        return return_crit
-
-    @classmethod
-    def _interpret_returning_rows(cls, mapper, rows):
-        """translate from local inherited table columns to base mapper
-        primary key columns.
-
-        Joined inheritance mappers always establish the primary key in terms of
-        the base table.   When we UPDATE a sub-table, we can only get
-        RETURNING for the sub-table's columns.
-
-        Here, we create a lookup from the local sub table's primary key
-        columns to the base table PK columns so that we can get identity
-        key values from RETURNING that's against the joined inheritance
-        sub-table.
-
-        the complexity here is to support more than one level deep of
-        inheritance, where we have to link columns to each other across
-        the inheritance hierarchy.
-
-        """
-
-        if mapper.local_table is not mapper.base_mapper.local_table:
-            return rows
-
-        # this starts as a mapping of
-        # local_pk_col: local_pk_col.
-        # we will then iteratively rewrite the "value" of the dict with
-        # each successive superclass column
-        local_pk_to_base_pk = {pk: pk for pk in mapper.local_table.primary_key}
-
-        for mp in mapper.iterate_to_root():
-            if mp.inherits is None:
-                break
-            elif mp.local_table is mp.inherits.local_table:
-                continue
-
-            t_to_e = dict(mp._table_to_equated[mp.inherits.local_table])
-            col_to_col = {sub_pk: super_pk for super_pk, sub_pk in t_to_e[mp]}
-            for pk, super_ in local_pk_to_base_pk.items():
-                local_pk_to_base_pk[pk] = col_to_col[super_]
-
-        lookup = {
-            local_pk_to_base_pk[lpk]: idx
-            for idx, lpk in enumerate(mapper.local_table.primary_key)
-        }
-        primary_key_convert = [
-            lookup[bpk] for bpk in mapper.base_mapper.primary_key
-        ]
-
-        return [tuple(row[idx] for idx in primary_key_convert) for row in rows]
-
-    @classmethod
-    def _do_pre_synchronize_evaluate(
-        cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        update_options,
-    ):
-        mapper = update_options._subject_mapper
-        target_cls = mapper.class_
-
-        value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT
-
-        try:
-            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
-            crit = ()
-            if statement._where_criteria:
-                crit += statement._where_criteria
-
-            global_attributes = {}
-            for opt in statement._with_options:
-                if opt._is_criteria_option:
-                    opt.get_global_criteria(global_attributes)
-
-            if global_attributes:
-                crit += cls._adjust_for_extra_criteria(
-                    global_attributes, mapper
-                )
-
-            if crit:
-                eval_condition = evaluator_compiler.process(*crit)
-            else:
-
-                def eval_condition(obj):
-                    return True
-
-        except evaluator.UnevaluatableError as err:
-            raise sa_exc.InvalidRequestError(
-                'Could not evaluate current criteria in Python: "%s". '
-                "Specify 'fetch' or False for the "
-                "synchronize_session execution option." % err
-            ) from err
-
-        if statement.__visit_name__ == "lambda_element":
-            # ._resolved is called on every LambdaElement in order to
-            # generate the cache key, so this access does not add
-            # additional expense
-            effective_statement = statement._resolved
-        else:
-            effective_statement = statement
-
-        if effective_statement.__visit_name__ == "update":
-            resolved_values = cls._get_resolved_values(
-                mapper, effective_statement
-            )
-            value_evaluators = {}
-            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
-                mapper, resolved_values
-            )
-            for key, value in resolved_keys_as_propnames:
-                try:
-                    _evaluator = evaluator_compiler.process(
-                        coercions.expect(roles.ExpressionElementRole, value)
-                    )
-                except evaluator.UnevaluatableError:
-                    pass
-                else:
-                    value_evaluators[key] = _evaluator
-
-        # TODO: detect when the where clause is a trivial primary key match.
-        matched_objects = [
-            state.obj()
-            for state in session.identity_map.all_states()
-            if state.mapper.isa(mapper)
-            and not state.expired
-            and eval_condition(state.obj())
-            and (
-                update_options._refresh_identity_token is None
-                # TODO: coverage for the case where horizontal sharding
-                # invokes an update() or delete() given an explicit identity
-                # token up front
-                or state.identity_token
-                == update_options._refresh_identity_token
-            )
-        ]
-        return update_options + {
-            "_matched_objects": matched_objects,
-            "_value_evaluators": value_evaluators,
-            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
-        }
-
-    @classmethod
-    def _get_resolved_values(cls, mapper, statement):
-        if statement._multi_values:
-            return []
-        elif statement._ordered_values:
-            return list(statement._ordered_values)
-        elif statement._values:
-            return list(statement._values.items())
-        else:
-            return []
-
-    @classmethod
-    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
-        values = []
-        for k, v in resolved_values:
-            if isinstance(k, attributes.QueryableAttribute):
-                values.append((k.key, v))
-                continue
-            elif hasattr(k, "__clause_element__"):
-                k = k.__clause_element__()
-
-            if mapper and isinstance(k, expression.ColumnElement):
-                try:
-                    attr = mapper._columntoproperty[k]
-                except orm_exc.UnmappedColumnError:
-                    pass
-                else:
-                    values.append((attr.key, v))
-            else:
-                raise sa_exc.InvalidRequestError(
-                    "Invalid expression type: %r" % k
-                )
-        return values
-
-    @classmethod
-    def _do_pre_synchronize_fetch(
-        cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        update_options,
-    ):
-        mapper = update_options._subject_mapper
-
-        select_stmt = (
-            select(*(mapper.primary_key + (mapper.select_identity_token,)))
-            .select_from(mapper)
-            .options(*statement._with_options)
-        )
-        select_stmt._where_criteria = statement._where_criteria
-
-        def skip_for_returning(orm_context: ORMExecuteState) -> Any:
-            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
-            if cls.can_use_returning(
-                bind.dialect,
-                mapper,
-                is_update_from=update_options._is_update_from,
-                is_delete_using=update_options._is_delete_using,
-            ):
-                return _result.null_result()
-            else:
-                return None
-
-        result = session.execute(
-            select_stmt,
-            params,
-            execution_options=execution_options,
-            bind_arguments=bind_arguments,
-            _add_event=skip_for_returning,
-        )
-        matched_rows = result.fetchall()
-
-        value_evaluators = _EMPTY_DICT
-
-        if statement.__visit_name__ == "lambda_element":
-            # ._resolved is called on every LambdaElement in order to
-            # generate the cache key, so this access does not add
-            # additional expense
-            effective_statement = statement._resolved
-        else:
-            effective_statement = statement
-
-        if effective_statement.__visit_name__ == "update":
-            target_cls = mapper.class_
-            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
-            resolved_values = cls._get_resolved_values(
-                mapper, effective_statement
-            )
-            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
-                mapper, resolved_values
-            )
-
-            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
-                mapper, resolved_values
-            )
-            value_evaluators = {}
-            for key, value in resolved_keys_as_propnames:
-                try:
-                    _evaluator = evaluator_compiler.process(
-                        coercions.expect(roles.ExpressionElementRole, value)
-                    )
-                except evaluator.UnevaluatableError:
-                    pass
-                else:
-                    value_evaluators[key] = _evaluator
-
-        else:
-            resolved_keys_as_propnames = _EMPTY_DICT
-
-        return update_options + {
-            "_value_evaluators": value_evaluators,
-            "_matched_rows": matched_rows,
-            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
-        }
-
-
-class ORMDMLState:
-    @classmethod
-    def get_entity_description(cls, statement):
-        ext_info = statement.table._annotations["parententity"]
-        mapper = ext_info.mapper
-        if ext_info.is_aliased_class:
-            _label_name = ext_info.name
-        else:
-            _label_name = mapper.class_.__name__
-
-        return {
-            "name": _label_name,
-            "type": mapper.class_,
-            "expr": ext_info.entity,
-            "entity": ext_info.entity,
-            "table": mapper.local_table,
-        }
-
-    @classmethod
-    def get_returning_column_descriptions(cls, statement):
-        def _ent_for_col(c):
-            return c._annotations.get("parententity", None)
-
-        def _attr_for_col(c, ent):
-            if ent is None:
-                return c
-            proxy_key = c._annotations.get("proxy_key", None)
-            if not proxy_key:
-                return c
-            else:
-                return getattr(ent.entity, proxy_key, c)
-
-        return [
-            {
-                "name": c.key,
-                "type": c.type,
-                "expr": _attr_for_col(c, ent),
-                "aliased": ent.is_aliased_class,
-                "entity": ent.entity,
-            }
-            for c, ent in [
-                (c, _ent_for_col(c)) for c in statement._all_selected_columns
-            ]
-        ]
-
-
-@CompileState.plugin_for("orm", "insert")
-class ORMInsert(ORMDMLState, InsertDMLState):
-    @classmethod
-    def orm_pre_session_exec(
-        cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        is_reentrant_invoke,
-    ):
-        bind_arguments["clause"] = statement
-        try:
-            plugin_subject = statement._propagate_attrs["plugin_subject"]
-        except KeyError:
-            assert False, "statement had 'orm' plugin but no plugin_subject"
-        else:
-            bind_arguments["mapper"] = plugin_subject.mapper
-
-        return (
-            statement,
-            util.immutabledict(execution_options),
-        )
-
-    @classmethod
-    def orm_setup_cursor_result(
-        cls,
-        session,
-        statement,
-        params,
-        execution_options,
-        bind_arguments,
-        result,
-    ):
-        return result
-
-
-@CompileState.plugin_for("orm", "update")
-class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
-    @classmethod
-    def create_for_statement(cls, statement, compiler, **kw):
-
-        self = cls.__new__(cls)
-
-        ext_info = statement.table._annotations["parententity"]
-
-        self.mapper = mapper = ext_info.mapper
-
-        self.extra_criteria_entities = {}
-
-        self._resolved_values = cls._get_resolved_values(mapper, statement)
-
-        extra_criteria_attributes = {}
-
-        for opt in statement._with_options:
-            if opt._is_criteria_option:
-                opt.get_global_criteria(extra_criteria_attributes)
-
-        if statement._values:
-            self._resolved_values = dict(self._resolved_values)
-
-        new_stmt = sql.Update.__new__(sql.Update)
-        new_stmt.__dict__.update(statement.__dict__)
-        new_stmt.table = mapper.local_table
-
-        # note if the statement has _multi_values, these
-        # are passed through to the new statement, which will then raise
-        # InvalidRequestError because UPDATE doesn't support multi_values
-        # right now.
-        if statement._ordered_values:
-            new_stmt._ordered_values = self._resolved_values
-        elif statement._values:
-            new_stmt._values = self._resolved_values
-
-        new_crit = cls._adjust_for_extra_criteria(
-            extra_criteria_attributes, mapper
-        )
-        if new_crit:
-            new_stmt = new_stmt.where(*new_crit)
-
-        # if we are against a lambda statement we might not be the
-        # topmost object that received per-execute annotations
-
-        # do this first as we need to determine if there is
-        # UPDATE..FROM
-
-        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)
-
-        if compiler._annotations.get(
-            "synchronize_session", None
-        ) == "fetch" and self.can_use_returning(
-            compiler.dialect, mapper, is_multitable=self.is_multitable
-        ):
-            if new_stmt._returning:
-                raise sa_exc.InvalidRequestError(
-                    "Can't use synchronize_session='fetch' "
-                    "with explicit returning()"
-                )
-            self.statement = self.statement.returning(
-                *mapper.local_table.primary_key
-            )
-
-        return self
-
-    @classmethod
-    def can_use_returning(
-        cls,
-        dialect: Dialect,
-        mapper: Mapper[Any],
-        *,
-        is_multitable: bool = False,
-        is_update_from: bool = False,
-        is_delete_using: bool = False,
-    ) -> bool:
-
-        # normal answer for "should we use RETURNING" at all.
-        normal_answer = (
-            dialect.update_returning and mapper.local_table.implicit_returning
-        )
-        if not normal_answer:
-            return False
-
-        # these workarounds are currently hypothetical for UPDATE,
-        # unlike DELETE where they impact MariaDB
-        if is_update_from:
-            return dialect.update_returning_multifrom
-
-        elif is_multitable and not dialect.update_returning_multifrom:
-            raise sa_exc.CompileError(
-                f'Dialect "{dialect.name}" does not support RETURNING '
-                "with UPDATE..FROM; for synchronize_session='fetch', "
-                "please add the additional execution option "
-                "'is_update_from=True' to the statement to indicate that "
-                "a separate SELECT should be used for this backend."
-            )
-
-        return True
-
-    @classmethod
-    def _get_crud_kv_pairs(cls, statement, kv_iterator):
-        plugin_subject = statement._propagate_attrs["plugin_subject"]
-
-        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs
-
-        if not plugin_subject or not plugin_subject.mapper:
-            return core_get_crud_kv_pairs(statement, kv_iterator)
-
-        mapper = plugin_subject.mapper
-
-        values = []
-
-        for k, v in kv_iterator:
-            k = coercions.expect(roles.DMLColumnRole, k)
-
-            if isinstance(k, str):
-                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
-                if desc is NO_VALUE:
-                    values.append(
-                        (
-                            k,
-                            coercions.expect(
-                                roles.ExpressionElementRole,
-                                v,
-                                type_=sqltypes.NullType(),
-                                is_crud=True,
-                            ),
-                        )
-                    )
-                else:
-                    values.extend(
-                        core_get_crud_kv_pairs(
-                            statement, desc._bulk_update_tuples(v)
-                        )
-                    )
-            elif "entity_namespace" in k._annotations:
-                k_anno = k._annotations
-                attr = _entity_namespace_key(
-                    k_anno["entity_namespace"], k_anno["proxy_key"]
-                )
-                values.extend(
-                    core_get_crud_kv_pairs(
-                        statement, attr._bulk_update_tuples(v)
-                    )
-                )
-            else:
-                values.append(
-                    (
-                        k,
-                        coercions.expect(
-                            roles.ExpressionElementRole,
-                            v,
-                            type_=sqltypes.NullType(),
-                            is_crud=True,
-                        ),
-                    )
-                )
-        return values
-
-    @classmethod
-    def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
-        states = set()
-        evaluated_keys = list(update_options._value_evaluators.keys())
-        values = update_options._resolved_keys_as_propnames
-        attrib = set(k for k, v in values)
-        for obj in update_options._matched_objects:
-
-            state, dict_ = (
-                attributes.instance_state(obj),
-                attributes.instance_dict(obj),
-            )
-
-            # the evaluated states were gathered across all identity tokens.
-            # however the post_sync events are called per identity token,
-            # so filter.
-            if (
-                update_options._refresh_identity_token is not None
-                and state.identity_token
-                != update_options._refresh_identity_token
-            ):
-                continue
-
-            # only evaluate unmodified attributes
-            to_evaluate = state.unmodified.intersection(evaluated_keys)
-            for key in to_evaluate:
-                if key in dict_:
-                    dict_[key] = update_options._value_evaluators[key](obj)
-
-            state.manager.dispatch.refresh(state, None, to_evaluate)
-
-            state._commit(dict_, list(to_evaluate))
-
-            to_expire = attrib.intersection(dict_).difference(to_evaluate)
-            if to_expire:
-                state._expire_attributes(dict_, to_expire)
-
-            states.add(state)
-        session._register_altered(states)
-
-    @classmethod
-    def _do_post_synchronize_fetch(cls, session, result, update_options):
-        target_mapper = update_options._subject_mapper
-
-        states = set()
-        evaluated_keys = list(update_options._value_evaluators.keys())
-
-        if result.returns_rows:
-            rows = cls._interpret_returning_rows(target_mapper, result.all())
-
-            matched_rows = [
-                tuple(row) + (update_options._refresh_identity_token,)
-                for row in rows
-            ]
-        else:
-            matched_rows = update_options._matched_rows
-
-        objs = [
-            session.identity_map[identity_key]
-            for identity_key in [
-                target_mapper.identity_key_from_primary_key(
-                    list(primary_key),
-                    identity_token=identity_token,
-                )
-                for primary_key, identity_token in [
-                    (row[0:-1], row[-1]) for row in matched_rows
-                ]
-                if update_options._refresh_identity_token is None
-                or identity_token == update_options._refresh_identity_token
-            ]
-            if identity_key in session.identity_map
-        ]
-
-        values = update_options._resolved_keys_as_propnames
-        attrib = set(k for k, v in values)
-
-        for obj in objs:
-            state, dict_ = (
-                attributes.instance_state(obj),
-                attributes.instance_dict(obj),
-            )
-
-            to_evaluate = state.unmodified.intersection(evaluated_keys)
-            for key in to_evaluate:
-                if key in dict_:
-                    dict_[key] = update_options._value_evaluators[key](obj)
-            state.manager.dispatch.refresh(state, None, to_evaluate)
-
-            state._commit(dict_, list(to_evaluate))
-
-            to_expire = attrib.intersection(dict_).difference(to_evaluate)
-            if to_expire:
-                state._expire_attributes(dict_, to_expire)
-
-            states.add(state)
-        session._register_altered(states)
-
-
-@CompileState.plugin_for("orm", "delete")
-class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState):
-    @classmethod
-    def create_for_statement(cls, statement, compiler, **kw):
-        self = cls.__new__(cls)
-
-        ext_info = statement.table._annotations["parententity"]
-        self.mapper = mapper = ext_info.mapper
-
-        self.extra_criteria_entities = {}
-
-        extra_criteria_attributes = {}
-
-        for opt in statement._with_options:
-            if opt._is_criteria_option:
-                opt.get_global_criteria(extra_criteria_attributes)
-
-        new_crit = cls._adjust_for_extra_criteria(
-            extra_criteria_attributes, mapper
-        )
-        if new_crit:
-            statement = statement.where(*new_crit)
-
-        # do this first as we need to determine if there is
-        # DELETE..FROM
-        DeleteDMLState.__init__(self, statement, compiler, **kw)
-
-        if compiler._annotations.get(
-            "synchronize_session", None
-        ) == "fetch" and self.can_use_returning(
-            compiler.dialect,
-            mapper,
-            is_multitable=self.is_multitable,
-            is_delete_using=compiler._annotations.get(
-                "is_delete_using", False
-            ),
-        ):
-            self.statement = statement.returning(*statement.table.primary_key)
-
-        return self
-
-    @classmethod
-    def can_use_returning(
-        cls,
-        dialect: Dialect,
-        mapper: Mapper[Any],
-        *,
-        is_multitable: bool = False,
-        is_update_from: bool = False,
-        is_delete_using: bool = False,
-    ) -> bool:
-
-        # normal answer for "should we use RETURNING" at all.
-        normal_answer = (
-            dialect.delete_returning and mapper.local_table.implicit_returning
-        )
-        if not normal_answer:
-            return False
-
-        # now get into special workarounds because MariaDB supports
-        # DELETE...RETURNING but not DELETE...USING...RETURNING.
-        if is_delete_using:
-            # is_delete_using hint was passed.   use
-            # additional dialect feature (True for PG, False for MariaDB)
-            return dialect.delete_returning_multifrom
-
-        elif is_multitable and not dialect.delete_returning_multifrom:
-            # is_delete_using hint was not passed, but we determined
-            # at compile time that this is in fact a DELETE..USING.
-            # it's too late to continue since we did not pre-SELECT.
-            # raise that we need that hint up front.
-
-            raise sa_exc.CompileError(
-                f'Dialect "{dialect.name}" does not support RETURNING '
-                "with DELETE..USING; for synchronize_session='fetch', "
-                "please add the additional execution option "
-                "'is_delete_using=True' to the statement to indicate that "
-                "a separate SELECT should be used for this backend."
-            )
-
-        return True
-
-    @classmethod
-    def _do_post_synchronize_evaluate(cls, session, result, update_options):
-
-        session._remove_newly_deleted(
-            [
-                attributes.instance_state(obj)
-                for obj in update_options._matched_objects
-            ]
-        )
-
-    @classmethod
-    def _do_post_synchronize_fetch(cls, session, result, update_options):
-        target_mapper = update_options._subject_mapper
-
-        if result.returns_rows:
-            rows = cls._interpret_returning_rows(target_mapper, result.all())
-
-            matched_rows = [
-                tuple(row) + (update_options._refresh_identity_token,)
-                for row in rows
-            ]
-        else:
-            matched_rows = update_options._matched_rows
-
-        for row in matched_rows:
-            primary_key = row[0:-1]
-            identity_token = row[-1]
-
-            # TODO: inline this and call remove_newly_deleted
-            # once
-            identity_key = target_mapper.identity_key_from_primary_key(
-                list(primary_key),
-                identity_token=identity_token,
-            )
-            if identity_key in session.identity_map:
-                session._remove_newly_deleted(
-                    [
-                        attributes.instance_state(
-                            session.identity_map[identity_key]
-                        )
-                    ]
-                )
index b17b053711869d13c13267021dcf03d2806ced6e..6d0f055e4bbb344c33eb27e7120d4566dccd73e4 100644 (file)
@@ -97,9 +97,9 @@ if TYPE_CHECKING:
     from ._typing import _EntityType
     from ._typing import _ExternalEntityType
     from ._typing import _InternalEntityType
+    from .bulk_persistence import _SynchronizeSessionArgument
     from .mapper import Mapper
     from .path_registry import PathRegistry
-    from .persistence import _SynchronizeSessionArgument
     from .session import _PKIdentityArgument
     from .session import Session
     from .state import InstanceState
index a518dfc05dee02a9b683f3ac5adc6fdf8cebb96a..45272caa248585fb0b6f51fe9cd7f122a834efa0 100644 (file)
@@ -34,12 +34,12 @@ from typing import Union
 import weakref
 
 from . import attributes
+from . import bulk_persistence
 from . import context
 from . import descriptor_props
 from . import exc
 from . import identity
 from . import loading
-from . import persistence
 from . import query
 from . import state as statelib
 from ._typing import _O
@@ -705,8 +705,8 @@ class ORMExecuteState(util.MemoizedSlots):
     def update_delete_options(
         self,
     ) -> Union[
-        persistence.BulkUDCompileState.default_update_options,
-        Type[persistence.BulkUDCompileState.default_update_options],
+        bulk_persistence.BulkUDCompileState.default_update_options,
+        Type[bulk_persistence.BulkUDCompileState.default_update_options],
     ]:
         """Return the update_delete_options that will be used for this
         execution."""
@@ -718,7 +718,7 @@ class ORMExecuteState(util.MemoizedSlots):
             )
         return self.execution_options.get(
             "_sa_orm_update_options",
-            persistence.BulkUDCompileState.default_update_options,
+            bulk_persistence.BulkUDCompileState.default_update_options,
         )
 
     @property
@@ -4275,7 +4275,7 @@ class Session(_SessionClassMethods, EventTarget):
         transaction = self.begin(_subtrans=True)
         try:
             if isupdate:
-                persistence._bulk_update(
+                bulk_persistence._bulk_update(
                     mapper,
                     mappings,
                     transaction,
@@ -4283,7 +4283,7 @@ class Session(_SessionClassMethods, EventTarget):
                     update_changed_only,
                 )
             else:
-                persistence._bulk_insert(
+                bulk_persistence._bulk_insert(
                     mapper,
                     mappings,
                     transaction,