]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
reorganize pre_session_exec around do_orm_execute
authorMike Bayer <mike_mp@zzzcomputing.com>
Sun, 18 Dec 2022 21:33:22 +0000 (16:33 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Mon, 26 Dec 2022 18:48:55 +0000 (13:48 -0500)
Allow do_orm_execute() events to both receive the complete
state of bind_argments, load_options, update_delete_options
as they do already, but also allow them to *change* all those
things via new execution options.   Options like autoflush,
populate_existing etc. can now be updated within a
do_orm_execute() hook and those changes will take effect
all the way through.

Took a few tries to get something that covers every case here,
in particular horizontal sharding which is consuming those
options as well as using context.invoke(), without excess
complexity.  The good news seems to be that a simple
reorg and replacing the "reentrant" boolean with
"is this before do_orm_execute is invoked" was all that was
needed.

As part of this we add a new "identity_token" option allowing
this option to be controlled from do_orm_execute() as well
as from the outside.

WIP

Fixes: #7837
Change-Id: I087728215edec8d1b1712322ab389e3f52ff76ba

17 files changed:
doc/build/changelog/unreleased_20/7837.rst [new file with mode: 0644]
doc/build/glossary.rst
doc/build/orm/queryguide/api.rst
examples/sharding/separate_databases.py
examples/sharding/separate_schema_translates.py
examples/sharding/separate_tables.py
lib/sqlalchemy/ext/horizontal_shard.py
lib/sqlalchemy/orm/bulk_persistence.py
lib/sqlalchemy/orm/context.py
lib/sqlalchemy/orm/loading.py
lib/sqlalchemy/orm/query.py
lib/sqlalchemy/orm/session.py
test/ext/test_deprecations.py
test/ext/test_horizontal_shard.py
test/orm/test_events.py
test/orm/test_query.py
test/orm/test_session.py

diff --git a/doc/build/changelog/unreleased_20/7837.rst b/doc/build/changelog/unreleased_20/7837.rst
new file mode 100644 (file)
index 0000000..1abb3e1
--- /dev/null
@@ -0,0 +1,40 @@
+.. change::
+    :tags: usecase, orm
+    :tickets: 7837
+
+    Adjustments to the :class:`_orm.Session` in terms of extensibility,
+    as well as updates to the :class:`.ShardedSession` extension:
+
+    * :meth:`_orm.Session.get` now accepts
+      :paramref:`_orm.Session.get.bind_arguments`, which in particular may be
+      useful when using the horizontal sharding extension.
+
+    * :meth:`_orm.Session.get_bind` accepts arbitrary kw arguments, which
+      assists in developing code that uses a :class:`_orm.Session` class which
+      overrides this method with additional arguments.
+
+    * Added a new ORM execution option ``identity_token`` which may be used
+      to directly affect the "identity token" that will be associated with
+      newly loaded ORM objects.  This token is how sharding approaches
+      (namely the :class:`.ShardedSession`, but can be used in other cases
+      as well) separate object identities across different "shards".
+
+      .. seealso::
+
+          :ref:`queryguide_identity_token`
+
+    * The :meth:`_orm.SessionEvents.do_orm_execute` event hook may now be used
+      to affect all ORM-related options, including ``autoflush``,
+      ``populate_existing``, and ``yield_per``; these options are re-consumed
+      subsequent to event hooks being invoked before they are acted upon.
+      Previously, options like ``autoflush`` would have been already evaluated
+      at this point. The new ``identity_token`` option is also supported in
+      this mode and is now used by the horizontal sharding extension.
+
+
+    * The :class:`.ShardedSession` class replaces the
+      :paramref:`.ShardedSession.id_chooser` hook with a new hook
+      :paramref:`.ShardedSession.identity_chooser`, which no longer relies upon
+      the legacy :class:`_orm.Query` object.
+      :paramref:`.ShardedSession.id_chooser` is still accepted in place of
+      :paramref:`.ShardedSession.identity_chooser` with a deprecation warning.
index d0bc4f8148c572527baec850ba5eb5369aaeb2ba..70eb05e6445b61ca77e5017ff521b18a1693a067 100644 (file)
@@ -488,6 +488,19 @@ Glossary
         primary key identity within the database, as well as their unique
         identity within a :class:`_orm.Session` :term:`identity map`.
 
+        In SQLAlchemy, you can view the identity key for an ORM object
+        using the :func:`_sa.inspect` API to return the :class:`_orm.InstanceState`
+        tracking object, then looking at the :attr:`_orm.InstanceState.key`
+        attribute::
+
+            >>> from sqlalchemy import inspect
+            >>> inspect(some_object).key
+            (<class '__main__.MyTable'>, (1,), None)
+
+        .. seealso::
+
+           :term:`identity map`
+
     identity map
         A mapping between Python objects and their database identities.
         The identity map is a collection that's associated with an
@@ -505,6 +518,9 @@ Glossary
 
             `Identity Map (via Martin Fowler) <https://martinfowler.com/eaaCatalog/identityMap.html>`_
 
+            :ref:`session_get` - how to look up an object in the identity map
+            by primary key
+
     lazy initialization
         A tactic of delaying some initialization action, such as creating objects,
         populating data, or establishing connectivity to other services, until
index 136b4b39bbbdaf4fffd161302d00dbe9f29a089a..35259a3b38bf300874123b264795c4fb77e9706f 100644 (file)
@@ -280,6 +280,138 @@ will have the same result as that of the ``yield_per`` execution option.
 
     :ref:`engine_stream_results`
 
+.. _queryguide_identity_token:
+
+Identity Token
+^^^^^^^^^^^^^^
+
+.. doctest-disable:
+
+.. deepalchemy::   This option is an advanced-use feature mostly intended
+   to be used with the :ref:`horizontal_sharding_toplevel` extension. For
+   typical cases of loading objects with identical primary keys from different
+   "shards" or partitions, consider using individual :class:`_orm.Session`
+   objects per shard first.
+
+
+The "identity token" is an arbitrary value that can be associated within
+the :term:`identity key` of newly loaded objects.   This element exists
+first and foremost to support extensions which perform per-row "sharding",
+where objects may be loaded from any number of replicas of a particular
+database table that nonetheless have overlapping primary key values.
+The primary consumer of "identity token" is the
+:ref:`horizontal_sharding_toplevel` extension, which supplies a general
+framework for persisting objects among multiple "shards" of a particular
+database table.
+
+The ``identity_token`` execution option may be used on a per-query basis
+to directly affect this token.   Using it directly, one can populate a
+:class:`_orm.Session` with multiple instances of an object that have the
+same primary key and source table, but different "identities".
+
+One such example is to populate a :class:`_orm.Session` with objects that
+come from same-named tables in different schemas, using the
+:ref:`schema_translating` feature which can affect the choice of schema
+within the scope of queries.  Given a mapping as:
+
+.. sourcecode:: python
+
+    from sqlalchemy.orm import DeclarativeBase
+    from sqlalchemy.orm import Mapped
+    from sqlalchemy.orm import mapped_column
+
+
+    class Base(DeclarativeBase):
+        pass
+
+
+    class MyTable(Base):
+        __tablename__ = "my_table"
+
+        id: Mapped[int] = mapped_column(primary_key=True)
+        name: Mapped[str]
+
+The default "schema" name for the class above is ``None``, meaning, no
+schema qualification will be written into SQL statements.  However,
+if we make use of :paramref:`_engine.Connection.execution_options.schema_translate_map`,
+mapping ``None`` to an alternate schema, we can place instances of
+``MyTable`` into two different schemas:
+
+.. sourcecode:: python
+
+    engine = create_engine(
+        "postgresql+psycopg://scott:tiger@localhost/test",
+    )
+
+    with Session(
+        engine.execution_options(schema_translate_map={None: "test_schema"})
+    ) as sess:
+        sess.add(MyTable(name="this is schema one"))
+        sess.commit()
+
+    with Session(
+        engine.execution_options(schema_translate_map={None: "test_schema_2"})
+    ) as sess:
+        sess.add(MyTable(name="this is schema two"))
+        sess.commit()
+
+The above two blocks create a :class:`_orm.Session` object linked to a different
+schema translate map each time, and an instance of ``MyTable`` is persisted
+into both ``test_schema.my_table`` as well as ``test_schema_2.my_table``.
+
+The :class:`_orm.Session` objects above are independent.  If we wanted to
+persist both objects in one transaction, we would need to use the
+:ref:`horizontal_sharding_toplevel` extension to do this.
+
+However, we can illustrate querying for these objects in one session as follows:
+
+.. sourcecode:: python
+
+    with Session(engine) as sess:
+        obj1 = sess.scalar(
+            select(MyTable)
+            .where(MyTable.id == 1)
+            .execution_options(
+                schema_translate_map={None: "test_schema"},
+                identity_token="test_schema",
+            )
+        )
+        obj2 = sess.scalar(
+            select(MyTable)
+            .where(MyTable.id == 1)
+            .execution_options(
+                schema_translate_map={None: "test_schema_2"},
+                identity_token="test_schema_2",
+            )
+        )
+
+Both ``obj1`` and ``obj2`` are distinct from each other.  However, they both
+refer to primary key id 1 for the ``MyTable`` class, yet are distinct.
+This is how the ``identity_token`` comes into play, which we can see in the
+inspection of each object, where we look at :attr:`_orm.InstanceState.key`
+to view the two distinct identity tokens::
+
+    >>> from sqlalchemy import inspect
+    >>> inspect(obj1).key
+    (<class '__main__.MyTable'>, (1,), 'test_schema')
+    >>> inspect(obj2).key
+    (<class '__main__.MyTable'>, (1,), 'test_schema_2')
+
+
+The above logic takes place automatically when using the
+:ref:`horizontal_sharding_toplevel` extension.
+
+.. versionadded:: 2.0.0b5 - added the ``identity_token`` ORM level execution
+   option.
+
+.. seealso::
+
+    :ref:`examples_sharding` - in the :ref:`examples_toplevel` section.
+    See the script ``separate_schema_translates.py`` for a demonstration of
+    the above use case using the full sharding API.
+
+
+.. doctest-enable:
 
 .. _queryguide_inspection:
 
index a45182f42db83399982654cb75131f38288a087b..fe92fd3bac43b6c2fc7c434c3aecb9c2fcc80d6c 100644 (file)
@@ -135,8 +135,8 @@ def shard_chooser(mapper, instance, clause=None):
         return shard_chooser(mapper, instance.location)
 
 
-def id_chooser(query, ident):
-    """id chooser.
+def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw):
+    """identity chooser.
 
     given a primary key, returns a list of shards
     to search.  here, we don't have any particular information from a
@@ -145,11 +145,11 @@ def id_chooser(query, ident):
     distributed among DBs.
 
     """
-    if query.lazy_loaded_from:
+    if lazy_loaded_from:
         # if we are in a lazy load, we can look at the parent object
         # and limit our search to that same shard, assuming that's how we've
         # set things up.
-        return [query.lazy_loaded_from.identity_token]
+        return [lazy_loaded_from.identity_token]
     else:
         return ["north_america", "asia", "europe", "south_america"]
 
@@ -237,7 +237,7 @@ def _get_select_comparisons(statement):
 # further configure create_session to use these functions
 Session.configure(
     shard_chooser=shard_chooser,
-    id_chooser=id_chooser,
+    identity_chooser=identity_chooser,
     execute_chooser=execute_chooser,
 )
 
index 2d4c2a0464fb6ac35f6afc89ed61af4f9a6de71c..f7bdc62500eb80acff89ce2ff9041af835f95761 100644 (file)
@@ -130,21 +130,20 @@ def shard_chooser(mapper, instance, clause=None):
         return shard_chooser(mapper, instance.location)
 
 
-def id_chooser(query, ident):
-    """id chooser.
+def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw):
+    """identity chooser.
 
-    given a primary key identity and a legacy :class:`_orm.Query`,
-    return which shard we should look at.
+    given a primary key identity, return which shard we should look at.
 
     in this case, we only want to support this for lazy-loaded items;
     any primary query should have shard id set up front.
 
     """
-    if query.lazy_loaded_from:
+    if lazy_loaded_from:
         # if we are in a lazy load, we can look at the parent object
         # and limit our search to that same shard, assuming that's how we've
         # set things up.
-        return [query.lazy_loaded_from.identity_token]
+        return [lazy_loaded_from.identity_token]
     else:
         raise NotImplementedError()
 
@@ -169,7 +168,7 @@ def execute_chooser(context):
 # configure shard chooser
 Session.configure(
     shard_chooser=shard_chooser,
-    id_chooser=id_chooser,
+    identity_chooser=identity_chooser,
     execute_chooser=execute_chooser,
 )
 
index 8f39471e888534a5935ffcf53b8584138caa2e05..97c6a07f6a1047f3c1e742753456ae9e3ab59276 100644 (file)
@@ -149,8 +149,8 @@ def shard_chooser(mapper, instance, clause=None):
         return shard_chooser(mapper, instance.location)
 
 
-def id_chooser(query, ident):
-    """id chooser.
+def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw):
+    """identity chooser.
 
     given a primary key, returns a list of shards
     to search.  here, we don't have any particular information from a
@@ -159,11 +159,11 @@ def id_chooser(query, ident):
     distributed among DBs.
 
     """
-    if query.lazy_loaded_from:
+    if lazy_loaded_from:
         # if we are in a lazy load, we can look at the parent object
         # and limit our search to that same shard, assuming that's how we've
         # set things up.
-        return [query.lazy_loaded_from.identity_token]
+        return [lazy_loaded_from.identity_token]
     else:
         return ["north_america", "asia", "europe", "south_america"]
 
@@ -251,7 +251,7 @@ def _get_select_comparisons(statement):
 # further configure create_session to use these functions
 Session.configure(
     shard_chooser=shard_chooser,
-    id_chooser=id_chooser,
+    identity_chooser=identity_chooser,
     execute_chooser=execute_chooser,
 )
 
index 69767ad6cbf74b2195c86322edd52b3f08cd1d20..fd53c60468a5e84b403a581faf1862676bf4225d 100644 (file)
@@ -13,11 +13,14 @@ distribute queries and persistence operations across multiple databases.
 For a usage example, see the :ref:`examples_sharding` example included in
 the source distribution.
 
-.. legacy:: The horizontal sharding API is not fully updated for the
-   SQLAlchemy 2.0 API, and still relies in part on the
-   legacy :class:`.Query` architecture, in particular as part of the
-   signature for the :paramref:`.ShardedSession.id_chooser` parameter.
-   This may change in a future release.
+.. deepalchemy:: The horizontal sharding extension is an advanced feature,
+   involving a complex statement -> database interaction as well as
+   use of semi-public APIs for non-trivial cases.   Simpler approaches to
+   refering to multiple database "shards", most commonly using a distinct
+   :class:`_orm.Session` per "shard", should always be considered first
+   before using this more complex and less-production-tested system.
+
+
 
 """
 from __future__ import annotations
@@ -38,8 +41,11 @@ from .. import exc
 from .. import inspect
 from .. import util
 from ..orm import PassiveFlag
+from ..orm._typing import OrmExecuteOptionsParameter
 from ..orm.mapper import Mapper
 from ..orm.query import Query
+from ..orm.session import _BindArguments
+from ..orm.session import _PKIdentityArgument
 from ..orm.session import Session
 from ..util.typing import Protocol
 
@@ -80,6 +86,20 @@ class ShardChooser(Protocol):
         ...
 
 
+class IdentityChooser(Protocol):
+    def __call__(
+        self,
+        mapper: Mapper[_T],
+        primary_key: _PKIdentityArgument,
+        *,
+        lazy_loaded_from: Optional[InstanceState[Any]],
+        execution_options: OrmExecuteOptionsParameter,
+        bind_arguments: _BindArguments,
+        **kw: Any,
+    ) -> Any:
+        ...
+
+
 class ShardedQuery(Query[_T]):
     """Query class used with :class:`.ShardedSession`.
 
@@ -94,8 +114,7 @@ class ShardedQuery(Query[_T]):
         super().__init__(*args, **kwargs)
         assert isinstance(self.session, ShardedSession)
 
-        self.id_chooser = self.session.id_chooser
-        self.query_chooser = self.session.query_chooser
+        self.identity_chooser = self.session.identity_chooser
         self.execute_chooser = self.session.execute_chooser
         self._shard_id = None
 
@@ -119,19 +138,22 @@ class ShardedQuery(Query[_T]):
 
 class ShardedSession(Session):
     shard_chooser: ShardChooser
-    id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]]
+    identity_chooser: IdentityChooser
     execute_chooser: Callable[[ORMExecuteState], Iterable[Any]]
 
     def __init__(
         self,
         shard_chooser: ShardChooser,
-        id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]],
+        identity_chooser: Optional[IdentityChooser] = None,
         execute_chooser: Optional[
             Callable[[ORMExecuteState], Iterable[Any]]
         ] = None,
         shards: Optional[Dict[str, Any]] = None,
         query_cls: Type[Query[_T]] = ShardedQuery,
         *,
+        id_chooser: Optional[
+            Callable[[Query[_T], Iterable[_T]], Iterable[Any]]
+        ] = None,
         query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None,
         **kwargs: Any,
     ) -> None:
@@ -171,12 +193,41 @@ class ShardedSession(Session):
             self, "do_orm_execute", execute_and_instances, retval=True
         )
         self.shard_chooser = shard_chooser
-        self.id_chooser = id_chooser
+
+        if id_chooser:
+            _id_chooser = id_chooser
+            util.warn_deprecated(
+                "The ``id_chooser`` parameter is deprecated; "
+                "please use ``identity_chooser``.",
+                "2.0",
+            )
+
+            def _legacy_identity_chooser(
+                mapper: Mapper[_T],
+                primary_key: _PKIdentityArgument,
+                *,
+                lazy_loaded_from: Optional[InstanceState[Any]],
+                execution_options: OrmExecuteOptionsParameter,
+                bind_arguments: _BindArguments,
+                **kw: Any,
+            ) -> Any:
+                q = self.query(mapper)
+                if lazy_loaded_from:
+                    q = q._set_lazyload_from(lazy_loaded_from)
+                return _id_chooser(q, primary_key)
+
+            self.identity_chooser = _legacy_identity_chooser
+        elif identity_chooser:
+            self.identity_chooser = identity_chooser
+        else:
+            raise exc.ArgumentError(
+                "identity_chooser or id_chooser is required"
+            )
 
         if query_chooser:
             _query_chooser = query_chooser
             util.warn_deprecated(
-                "The ``query_choser`` parameter is deprecated; "
+                "The ``query_chooser`` parameter is deprecated; "
                 "please use ``execute_chooser``.",
                 "1.4",
             )
@@ -199,7 +250,6 @@ class ShardedSession(Session):
                 "execute_chooser or query_chooser is required"
             )
         self.execute_chooser = execute_chooser
-        self.query_chooser = query_chooser
         self.__shards: Dict[_ShardKey, _SessionBind] = {}
         if shards is not None:
             for k in shards:
@@ -212,6 +262,8 @@ class ShardedSession(Session):
         identity_token: Optional[Any] = None,
         passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
         lazy_loaded_from: Optional[InstanceState[Any]] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
         **kw: Any,
     ) -> Union[Optional[_O], LoaderCallableStatus]:
         """override the default :meth:`.Session._identity_lookup` method so
@@ -233,10 +285,13 @@ class ShardedSession(Session):
 
             return obj
         else:
-            q = self.query(mapper)
-            if lazy_loaded_from:
-                q = q._set_lazyload_from(lazy_loaded_from)
-            for shard_id in self.id_chooser(q, primary_key_identity):
+            for shard_id in self.identity_chooser(
+                mapper,
+                primary_key_identity,
+                lazy_loaded_from=lazy_loaded_from,
+                execution_options=execution_options,
+                bind_arguments=dict(bind_arguments) if bind_arguments else {},
+            ):
                 obj2 = super()._identity_lookup(
                     mapper,
                     primary_key_identity,
@@ -325,11 +380,6 @@ class ShardedSession(Session):
 def execute_and_instances(
     orm_context: ORMExecuteState,
 ) -> Union[Result[_T], IteratorResult[_TP]]:
-    update_options: Union[
-        None,
-        BulkUDCompileState.default_update_options,
-        Type[BulkUDCompileState.default_update_options],
-    ]
     active_options: Union[
         None,
         QueryContext.default_load_options,
@@ -337,58 +387,30 @@ def execute_and_instances(
         BulkUDCompileState.default_update_options,
         Type[BulkUDCompileState.default_update_options],
     ]
-    load_options: Union[
-        None,
-        QueryContext.default_load_options,
-        Type[QueryContext.default_load_options],
-    ]
 
     if orm_context.is_select:
-        load_options = active_options = orm_context.load_options
-        update_options = None
+        active_options = orm_context.load_options
 
     elif orm_context.is_update or orm_context.is_delete:
-        load_options = None
-        update_options = active_options = orm_context.update_delete_options
+        active_options = orm_context.update_delete_options
     else:
-        load_options = update_options = active_options = None
+        active_options = None
 
     session = orm_context.session
     assert isinstance(session, ShardedSession)
 
     def iter_for_shard(
         shard_id: str,
-        load_options: Union[
-            None,
-            QueryContext.default_load_options,
-            Type[QueryContext.default_load_options],
-        ],
-        update_options: Union[
-            None,
-            BulkUDCompileState.default_update_options,
-            Type[BulkUDCompileState.default_update_options],
-        ],
     ) -> Union[Result[_T], IteratorResult[_TP]]:
-        execution_options = dict(orm_context.local_execution_options)
 
         bind_arguments = dict(orm_context.bind_arguments)
         bind_arguments["shard_id"] = shard_id
 
-        if orm_context.is_select:
-            assert load_options is not None
-            load_options += {"_refresh_identity_token": shard_id}
-            execution_options["_sa_orm_load_options"] = load_options
-        elif orm_context.is_update or orm_context.is_delete:
-            assert update_options is not None
-            update_options += {"_refresh_identity_token": shard_id}
-            execution_options["_sa_orm_update_options"] = update_options
-
-        return orm_context.invoke_statement(
-            bind_arguments=bind_arguments, execution_options=execution_options
-        )
+        orm_context.update_execution_options(identity_token=shard_id)
+        return orm_context.invoke_statement(bind_arguments=bind_arguments)
 
-    if active_options and active_options._refresh_identity_token is not None:
-        shard_id = active_options._refresh_identity_token
+    if active_options and active_options._identity_token is not None:
+        shard_id = active_options._identity_token
     elif "_sa_shard_id" in orm_context.execution_options:
         shard_id = orm_context.execution_options["_sa_shard_id"]
     elif "shard_id" in orm_context.bind_arguments:
@@ -397,10 +419,10 @@ def execute_and_instances(
         shard_id = None
 
     if shard_id is not None:
-        return iter_for_shard(shard_id, load_options, update_options)
+        return iter_for_shard(shard_id)
     else:
         partial = []
         for shard_id in session.execute_chooser(orm_context):
-            result_ = iter_for_shard(shard_id, load_options, update_options)
+            result_ = iter_for_shard(shard_id)
             partial.append(result_)
         return partial[0].merge(*partial[1:])
index 181dbd4a283ff658264f7d1b1da07fdef8e2e64f..805bfdc65e90fc4e70037a455dfceda9257790f3 100644 (file)
@@ -555,7 +555,7 @@ class BulkUDCompileState(ORMDMLState):
         _resolved_values = EMPTY_DICT
         _eval_condition = None
         _matched_rows = None
-        _refresh_identity_token = None
+        _identity_token = None
 
     @classmethod
     def can_use_returning(
@@ -577,10 +577,8 @@ class BulkUDCompileState(ORMDMLState):
         params,
         execution_options,
         bind_arguments,
-        is_reentrant_invoke,
+        is_pre_event,
     ):
-        if is_reentrant_invoke:
-            return statement, execution_options
 
         (
             update_options,
@@ -590,6 +588,7 @@ class BulkUDCompileState(ORMDMLState):
             {
                 "synchronize_session",
                 "autoflush",
+                "identity_token",
                 "is_delete_using",
                 "is_update_from",
                 "dml_strategy",
@@ -637,55 +636,56 @@ class BulkUDCompileState(ORMDMLState):
                     "for 'bulk' ORM updates (i.e. multiple parameter sets)"
                 )
 
-        if update_options._autoflush:
-            session._autoflush()
-
-        if update_options._dml_strategy == "orm":
+        if not is_pre_event:
+            if update_options._autoflush:
+                session._autoflush()
 
-            if update_options._synchronize_session == "auto":
-                update_options = cls._do_pre_synchronize_auto(
-                    session,
-                    statement,
-                    params,
-                    execution_options,
-                    bind_arguments,
-                    update_options,
-                )
-            elif update_options._synchronize_session == "evaluate":
-                update_options = cls._do_pre_synchronize_evaluate(
-                    session,
-                    statement,
-                    params,
-                    execution_options,
-                    bind_arguments,
-                    update_options,
-                )
-            elif update_options._synchronize_session == "fetch":
-                update_options = cls._do_pre_synchronize_fetch(
-                    session,
-                    statement,
-                    params,
-                    execution_options,
-                    bind_arguments,
-                    update_options,
-                )
-        elif update_options._dml_strategy == "bulk":
-            if update_options._synchronize_session == "auto":
-                update_options += {"_synchronize_session": "evaluate"}
+            if update_options._dml_strategy == "orm":
 
-        # indicators from the "pre exec" step that are then
-        # added to the DML statement, which will also be part of the cache
-        # key.  The compile level create_for_statement() method will then
-        # consume these at compiler time.
-        statement = statement._annotate(
-            {
-                "synchronize_session": update_options._synchronize_session,
-                "is_delete_using": update_options._is_delete_using,
-                "is_update_from": update_options._is_update_from,
-                "dml_strategy": update_options._dml_strategy,
-                "can_use_returning": update_options._can_use_returning,
-            }
-        )
+                if update_options._synchronize_session == "auto":
+                    update_options = cls._do_pre_synchronize_auto(
+                        session,
+                        statement,
+                        params,
+                        execution_options,
+                        bind_arguments,
+                        update_options,
+                    )
+                elif update_options._synchronize_session == "evaluate":
+                    update_options = cls._do_pre_synchronize_evaluate(
+                        session,
+                        statement,
+                        params,
+                        execution_options,
+                        bind_arguments,
+                        update_options,
+                    )
+                elif update_options._synchronize_session == "fetch":
+                    update_options = cls._do_pre_synchronize_fetch(
+                        session,
+                        statement,
+                        params,
+                        execution_options,
+                        bind_arguments,
+                        update_options,
+                    )
+            elif update_options._dml_strategy == "bulk":
+                if update_options._synchronize_session == "auto":
+                    update_options += {"_synchronize_session": "evaluate"}
+
+            # indicators from the "pre exec" step that are then
+            # added to the DML statement, which will also be part of the cache
+            # key.  The compile level create_for_statement() method will then
+            # consume these at compiler time.
+            statement = statement._annotate(
+                {
+                    "synchronize_session": update_options._synchronize_session,
+                    "is_delete_using": update_options._is_delete_using,
+                    "is_update_from": update_options._is_update_from,
+                    "dml_strategy": update_options._dml_strategy,
+                    "can_use_returning": update_options._can_use_returning,
+                }
+            )
 
         return (
             statement,
@@ -836,7 +836,7 @@ class BulkUDCompileState(ORMDMLState):
             if state.mapper.isa(mapper) and not state.expired
         ]
 
-        identity_token = update_options._refresh_identity_token
+        identity_token = update_options._identity_token
         if identity_token is not None:
             raw_data = [
                 (obj, state, dict_)
@@ -1091,7 +1091,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
         params,
         execution_options,
         bind_arguments,
-        is_reentrant_invoke,
+        is_pre_event,
     ):
 
         (
@@ -1143,7 +1143,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState):
                     context._orm_load_exec_options
                 )
 
-        if insert_options._autoflush:
+        if not is_pre_event and insert_options._autoflush:
             session._autoflush()
 
         statement = statement._annotate(
@@ -1577,7 +1577,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
         for param in params:
             identity_key = mapper.identity_key_from_primary_key(
                 (param[key] for key in pk_keys),
-                update_options._refresh_identity_token,
+                update_options._identity_token,
             )
             state = identity_map.fast_get_state(identity_key)
             if not state:
@@ -1635,7 +1635,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
             )
 
             matched_rows = [
-                tuple(row) + (update_options._refresh_identity_token,)
+                tuple(row) + (update_options._identity_token,)
                 for row in pk_rows
             ]
         else:
@@ -1651,8 +1651,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState):
                 for primary_key, identity_token in [
                     (row[0:-1], row[-1]) for row in matched_rows
                 ]
-                if update_options._refresh_identity_token is None
-                or identity_token == update_options._refresh_identity_token
+                if update_options._identity_token is None
+                or identity_token == update_options._identity_token
             ]
             if identity_key in session.identity_map
         ]
@@ -1912,7 +1912,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState):
             )
 
             matched_rows = [
-                tuple(row) + (update_options._refresh_identity_token,)
+                tuple(row) + (update_options._identity_token,)
                 for row in pk_rows
             ]
         else:
index 3bd8b02a71be5bc5909f84e0a88acbb41a1facc6..b3478b83e14511f6a39a5926ca1d9d3cec5ee027 100644 (file)
@@ -135,7 +135,7 @@ class QueryContext:
         _version_check = False
         _invoke_all_eagers = True
         _autoflush = True
-        _refresh_identity_token = None
+        _identity_token = None
         _yield_per = None
         _refresh_state = None
         _lazy_loaded_from = None
@@ -194,14 +194,14 @@ class QueryContext:
         self.version_check = load_options._version_check
         self.refresh_state = load_options._refresh_state
         self.yield_per = load_options._yield_per
-        self.identity_token = load_options._refresh_identity_token
+        self.identity_token = load_options._identity_token
 
     def _get_top_level_context(self) -> QueryContext:
         return self.top_level_context or self
 
 
 _orm_load_exec_options = util.immutabledict(
-    {"_result_disable_adapt_to_context": True, "future_result": True}
+    {"_result_disable_adapt_to_context": True}
 )
 
 
@@ -235,7 +235,7 @@ class AbstractORMCompileState(CompileState):
         params,
         execution_options,
         bind_arguments,
-        is_reentrant_invoke,
+        is_pre_event,
     ):
         raise NotImplementedError()
 
@@ -384,11 +384,11 @@ class ORMCompileState(AbstractORMCompileState):
         params,
         execution_options,
         bind_arguments,
-        is_reentrant_invoke,
+        is_pre_event,
     ):
-        if is_reentrant_invoke:
-            return statement, execution_options
 
+        # consume result-level load_options.  These may have been set up
+        # in an ORMExecuteState hook
         (
             load_options,
             execution_options,
@@ -398,26 +398,24 @@ class ORMCompileState(AbstractORMCompileState):
                 "populate_existing",
                 "autoflush",
                 "yield_per",
+                "identity_token",
                 "sa_top_level_orm_context",
             },
             execution_options,
             statement._execution_options,
         )
+
         # default execution options for ORM results:
         # 1. _result_disable_adapt_to_context=True
         #    this will disable the ResultSetMetadata._adapt_to_context()
         #    step which we don't need, as we have result processors cached
         #    against the original SELECT statement before caching.
-        # 2. future_result=True.  The ORM should **never** resolve columns
-        #    in a result set based on names, only on Column objects that
-        #    are correctly adapted to the context.   W the legacy result
-        #    it will still attempt name-based resolution and also emit a
-        #    warning.
         if not execution_options:
             execution_options = _orm_load_exec_options
         else:
             execution_options = execution_options.union(_orm_load_exec_options)
 
+        # would have been placed here by legacy Query only
         if load_options._yield_per:
             execution_options = execution_options.union(
                 {"yield_per": load_options._yield_per}
@@ -457,7 +455,7 @@ class ORMCompileState(AbstractORMCompileState):
             if plugin_subject:
                 bind_arguments["mapper"] = plugin_subject.mapper
 
-        if load_options._autoflush:
+        if not is_pre_event and load_options._autoflush:
             session._autoflush()
 
         return statement, execution_options
@@ -483,6 +481,7 @@ class ORMCompileState(AbstractORMCompileState):
         load_options = execution_options.get(
             "_sa_orm_load_options", QueryContext.default_load_options
         )
+
         if compile_state.compile_options._is_star:
             return result
 
@@ -3119,6 +3118,6 @@ class _IdentityTokenEntity(_ORMColumnEntity):
 
     def row_processor(self, context, result):
         def getter(row):
-            return context.load_options._refresh_identity_token
+            return context.load_options._identity_token
 
         return getter, self._label_name, self._extra_entities
index 6e7695f8610a2f08b358cf916b939e3d3d75dce2..f331cd63b0186f4dce39aeaf1acf3e04a07954b9 100644 (file)
@@ -701,7 +701,7 @@ def _set_get_options(
     if only_load_props:
         compile_options["_only_load_props"] = frozenset(only_load_props)
     if identity_token:
-        load_options["_refresh_identity_token"] = identity_token
+        load_options["_identity_token"] = identity_token
 
     if load_options:
         load_opt += load_options
index 01db08eb4616567d01e0c8aaf88406853ef83a47..d2bd930ff413de4e3c7715faafc3a1a68fee21f9 100644 (file)
@@ -470,7 +470,7 @@ class Query(
         if only_load_props:
             compile_options["_only_load_props"] = frozenset(only_load_props)
         if identity_token:
-            load_options["_refresh_identity_token"] = identity_token
+            load_options["_identity_token"] = identity_token
 
         if load_options:
             self.load_options += load_options
index bf3df05990699e5d8181227dc041b8e5a43e88b0..8b5f7c88ab8866dbcbef22cb3e6ab700dce0302f 100644 (file)
@@ -267,6 +267,7 @@ class ORMExecuteState(util.MemoizedSlots):
         "execution_options",
         "local_execution_options",
         "bind_arguments",
+        "identity_token",
         "_compile_state_cls",
         "_starting_event_idx",
         "_events_todo",
@@ -579,9 +580,8 @@ class ORMExecuteState(util.MemoizedSlots):
     def _is_crud(self) -> bool:
         return isinstance(self.statement, (dml.Update, dml.Delete))
 
-    def update_execution_options(self, **opts: _ExecuteOptions) -> None:
+    def update_execution_options(self, **opts: Any) -> None:
         """Update the local execution options with new values."""
-        # TODO: no coverage
         self.local_execution_options = self.local_execution_options.union(opts)
 
     def _orm_compile_options(
@@ -1912,27 +1912,10 @@ class Session(_SessionClassMethods, EventTarget):
                 )
         else:
             compile_state_cls = None
+            bind_arguments.setdefault("clause", statement)
 
         execution_options = util.coerce_to_immutabledict(execution_options)
 
-        if compile_state_cls is not None:
-            (
-                statement,
-                execution_options,
-            ) = compile_state_cls.orm_pre_session_exec(
-                self,
-                statement,
-                params,
-                execution_options,
-                bind_arguments,
-                _parent_execute_state is not None,
-            )
-        else:
-            bind_arguments.setdefault("clause", statement)
-            execution_options = execution_options.union(
-                {"future_result": True}
-            )
-
         if _parent_execute_state:
             events_todo = _parent_execute_state._remaining_events()
         else:
@@ -1941,6 +1924,25 @@ class Session(_SessionClassMethods, EventTarget):
                 events_todo = list(events_todo) + [_add_event]
 
         if events_todo:
+            if compile_state_cls is not None:
+                # for event handlers, do the orm_pre_session_exec
+                # pass ahead of the event handlers, so that things like
+                # .load_options, .update_delete_options etc. are populated.
+                # is_pre_event=True allows the hook to hold off on things
+                # it doesn't want to do twice, including autoflush as well
+                # as "pre fetch" for DML, etc.
+                (
+                    statement,
+                    execution_options,
+                ) = compile_state_cls.orm_pre_session_exec(
+                    self,
+                    statement,
+                    params,
+                    execution_options,
+                    bind_arguments,
+                    True,
+                )
+
             orm_exec_state = ORMExecuteState(
                 self,
                 statement,
@@ -1962,6 +1964,24 @@ class Session(_SessionClassMethods, EventTarget):
             statement = orm_exec_state.statement
             execution_options = orm_exec_state.local_execution_options
 
+        if compile_state_cls is not None:
+            # now run orm_pre_session_exec() "for real".   if there were
+            # event hooks, this will re-run the steps that interpret
+            # new execution_options into load_options / update_delete_options,
+            # which we assume the event hook might have updated.
+            # autoflush will also be invoked in this step if enabled.
+            (
+                statement,
+                execution_options,
+            ) = compile_state_cls.orm_pre_session_exec(
+                self,
+                statement,
+                params,
+                execution_options,
+                bind_arguments,
+                False,
+            )
+
         bind = self.get_bind(**bind_arguments)
 
         conn = self._connection_for_bind(bind)
@@ -2379,6 +2399,7 @@ class Session(_SessionClassMethods, EventTarget):
         bind: Optional[_SessionBind] = None,
         _sa_skip_events: Optional[bool] = None,
         _sa_skip_for_implicit_returning: bool = False,
+        **kw: Any,
     ) -> Union[Engine, Connection]:
         """Return a "bind" to which this :class:`.Session` is bound.
 
@@ -2653,6 +2674,8 @@ class Session(_SessionClassMethods, EventTarget):
         identity_token: Any = None,
         passive: PassiveFlag = PassiveFlag.PASSIVE_OFF,
         lazy_loaded_from: Optional[InstanceState[Any]] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
     ) -> Union[Optional[_O], LoaderCallableStatus]:
         """Locate an object in the identity map.
 
@@ -3262,6 +3285,7 @@ class Session(_SessionClassMethods, EventTarget):
         with_for_update: Optional[ForUpdateArg] = None,
         identity_token: Optional[Any] = None,
         execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
     ) -> Optional[_O]:
         """Return an instance based on the given primary key identifier,
         or ``None`` if not found.
@@ -3355,6 +3379,13 @@ class Session(_SessionClassMethods, EventTarget):
             :ref:`orm_queryguide_execution_options` - ORM-specific execution
             options
 
+        :param bind_arguments: dictionary of additional arguments to determine
+         the bind.  May include "mapper", "bind", or other custom arguments.
+         Contents of this dictionary are passed to the
+         :meth:`.Session.get_bind` method.
+
+         .. versionadded: 2.0.0b5
+
         :return: The object instance, or ``None``.
 
         """
@@ -3367,6 +3398,7 @@ class Session(_SessionClassMethods, EventTarget):
             with_for_update=with_for_update,
             identity_token=identity_token,
             execution_options=execution_options,
+            bind_arguments=bind_arguments,
         )
 
     def _get_impl(
@@ -3379,7 +3411,8 @@ class Session(_SessionClassMethods, EventTarget):
         populate_existing: bool = False,
         with_for_update: Optional[ForUpdateArg] = None,
         identity_token: Optional[Any] = None,
-        execution_options: Optional[OrmExecuteOptionsParameter] = None,
+        execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT,
+        bind_arguments: Optional[_BindArguments] = None,
     ) -> Optional[_O]:
 
         # convert composite types to individual args
@@ -3453,7 +3486,11 @@ class Session(_SessionClassMethods, EventTarget):
         ):
 
             instance = self._identity_lookup(
-                mapper, primary_key_identity, identity_token=identity_token
+                mapper,
+                primary_key_identity,
+                identity_token=identity_token,
+                execution_options=execution_options,
+                bind_arguments=bind_arguments,
             )
 
             if instance is not None:
@@ -3484,13 +3521,14 @@ class Session(_SessionClassMethods, EventTarget):
 
         if options:
             statement = statement.options(*options)
-        if execution_options:
-            statement = statement.execution_options(**execution_options)
         return db_load_fn(
             self,
             statement,
             primary_key_identity,
             load_options=load_options,
+            identity_token=identity_token,
+            execution_options=execution_options,
+            bind_arguments=bind_arguments,
         )
 
     def merge(
index 09f904487adea08bebe6035870b061ab58d4777f..97c4172ba7be4686a11f34e80c8e2d14500b7250 100644 (file)
@@ -1,3 +1,5 @@
+from sqlalchemy import Column
+from sqlalchemy import Integer
 from sqlalchemy import testing
 from sqlalchemy.ext.automap import automap_base
 from sqlalchemy.ext.horizontal_shard import ShardedSession
@@ -68,7 +70,7 @@ class HorizontalShardTest(fixtures.TestBase):
         m1 = mock.Mock()
 
         with testing.expect_deprecated(
-            "The ``query_choser`` parameter is deprecated; please use"
+            "The ``query_chooser`` parameter is deprecated; please use"
         ):
             s = ShardedSession(
                 shard_chooser=m1.shard_chooser,
@@ -80,3 +82,30 @@ class HorizontalShardTest(fixtures.TestBase):
         s.execute_chooser(m2)
 
         eq_(m1.mock_calls, [mock.call.query_chooser(m2.statement)])
+
+    def test_id_chooser(self, decl_base):
+        class A(decl_base):
+            __tablename__ = "a"
+            id = Column(Integer, primary_key=True)
+
+        m1 = mock.Mock()
+
+        with testing.expect_deprecated(
+            "The ``id_chooser`` parameter is deprecated; please use"
+        ):
+            s = ShardedSession(
+                shard_chooser=m1.shard_chooser,
+                id_chooser=m1.id_chooser,
+                execute_chooser=m1.execute_chooser,
+            )
+
+        m2 = mock.Mock()
+        s.identity_chooser(
+            A.__mapper__,
+            m2.primary_key,
+            lazy_loaded_from=m2.lazy_loaded_from,
+            execution_options=m2.execution_options,
+            bind_arguments=m2.bind_arguments,
+        )
+
+        eq_(m1.mock_calls, [mock.call.id_chooser(mock.ANY, m2.primary_key)])
index ab4a24f71c1e10c150354c3fde0197fe51c6b54e..8e5d09cab0a6487854f3d19d893b3b672da1d74f 100644 (file)
@@ -28,6 +28,7 @@ from sqlalchemy.pool import SingletonThreadPool
 from sqlalchemy.sql import operators
 from sqlalchemy.sql import Select
 from sqlalchemy.testing import eq_
+from sqlalchemy.testing import expect_deprecated
 from sqlalchemy.testing import fixtures
 from sqlalchemy.testing import is_
 from sqlalchemy.testing import provision
@@ -109,7 +110,15 @@ class ShardTest:
             else:
                 return shard_chooser(mapper, instance.location)
 
-        def id_chooser(query, ident):
+        def identity_chooser(
+            mapper,
+            primary_key,
+            *,
+            lazy_loaded_from,
+            execution_options,
+            bind_arguments,
+            **kw,
+        ):
             return ["north_america", "asia", "europe", "south_america"]
 
         def execute_chooser(orm_context):
@@ -144,7 +153,7 @@ class ShardTest:
                 "south_america": db4,
             },
             shard_chooser=shard_chooser,
-            id_chooser=id_chooser,
+            identity_chooser=identity_chooser,
             execute_chooser=execute_chooser,
         )
 
@@ -189,7 +198,7 @@ class ShardTest:
         tokyo.reports.append(Report(80.0, id_=1))
         newyork.reports.append(Report(75, id_=1))
         quito.reports.append(Report(85))
-        sess = sharded_session(future=True)
+        sess = sharded_session()
         for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]:
             sess.add(c)
         sess.flush()
@@ -589,6 +598,68 @@ class DistinctEngineShardTest(ShardTest, fixtures.MappedTest):
         )
 
 
+class LegacyAPIShardTest(DistinctEngineShardTest):
+    @classmethod
+    def setup_session(cls):
+        global sharded_session
+        shard_lookup = {
+            "North America": "north_america",
+            "Asia": "asia",
+            "Europe": "europe",
+            "South America": "south_america",
+        }
+
+        def shard_chooser(mapper, instance, clause=None):
+            if isinstance(instance, WeatherLocation):
+                return shard_lookup[instance.continent]
+            else:
+                return shard_chooser(mapper, instance.location)
+
+        def id_chooser(query, primary_key):
+            return ["north_america", "asia", "europe", "south_america"]
+
+        def query_chooser(query):
+            ids = []
+
+            class FindContinent(sql.ClauseVisitor):
+                def visit_binary(self, binary):
+                    if binary.left.shares_lineage(
+                        weather_locations.c.continent
+                    ):
+                        if binary.operator == operators.eq:
+                            ids.append(shard_lookup[binary.right.value])
+                        elif binary.operator == operators.in_op:
+                            for value in binary.right.value:
+                                ids.append(shard_lookup[value])
+
+            if isinstance(query, Select) and query.whereclause is not None:
+                FindContinent().traverse(query.whereclause)
+            if len(ids) == 0:
+                return ["north_america", "asia", "europe", "south_america"]
+            else:
+                return ids
+
+        sm = sessionmaker(class_=ShardedSession, autoflush=True)
+        sm.configure(
+            shards={
+                "north_america": db1,
+                "asia": db2,
+                "europe": db3,
+                "south_america": db4,
+            },
+            shard_chooser=shard_chooser,
+            id_chooser=id_chooser,
+            query_chooser=query_chooser,
+        )
+
+        def sharded_session():
+            with expect_deprecated(
+                "The ``id_chooser`` parameter is deprecated",
+                "The ``query_chooser`` parameter is deprecated",
+            ):
+                return sm()
+
+
 class AttachedFileShardTest(ShardTest, fixtures.MappedTest):
     """Use modern schema conventions along with SQLite ATTACH."""
 
@@ -723,7 +794,7 @@ class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest):
         session = ShardedSession(
             shards={"test": testing.db},
             shard_chooser=lambda *args: "test",
-            id_chooser=lambda *args: None,
+            identity_chooser=lambda *args: None,
             execute_chooser=lambda *args: ["test"],
         )
 
@@ -764,7 +835,7 @@ class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest):
         return ShardedSession(
             shards={"main": testing.db},
             shard_chooser=lambda *args: "main",
-            id_chooser=lambda *args: ["fake", "main"],
+            identity_chooser=lambda *args: ["fake", "main"],
             execute_chooser=lambda *args: ["fake", "main"],
             **kw,
         )
@@ -843,15 +914,23 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
             else:
                 assert False
 
-        def id_chooser(query, ident):
-            assert query.lazy_loaded_from
-            if isinstance(query.lazy_loaded_from.obj(), Book):
-                token = shard_for_book(query.lazy_loaded_from.obj())
-                assert query.lazy_loaded_from.identity_token == token
+        def identity_chooser(
+            mapper,
+            primary_key,
+            *,
+            lazy_loaded_from,
+            execution_options,
+            bind_arguments,
+            **kw,
+        ):
+            assert lazy_loaded_from
+            if isinstance(lazy_loaded_from.obj(), Book):
+                token = shard_for_book(lazy_loaded_from.obj())
+                assert lazy_loaded_from.identity_token == token
 
-            return [query.lazy_loaded_from.identity_token]
+            return [lazy_loaded_from.identity_token]
 
-        def no_query_chooser(orm_context):
+        def execute_chooser(orm_context):
             if (
                 orm_context.statement.column_descriptions[0]["type"] is Book
                 and lazy_load_book
@@ -878,8 +957,8 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest):
         session = ShardedSession(
             shards={"test": db1, "test2": db2},
             shard_chooser=shard_chooser,
-            id_chooser=id_chooser,
-            execute_chooser=no_query_chooser,
+            identity_chooser=identity_chooser,
+            execute_chooser=execute_chooser,
         )
 
         return session
index 56d2815fa52de86ef2c0dc2cb3eb7c494c6674bc..05d5d376dea8f41996ac759201c91476df854184 100644 (file)
@@ -1,6 +1,7 @@
 from unittest.mock import ANY
 from unittest.mock import call
 from unittest.mock import Mock
+from unittest.mock import patch
 
 import sqlalchemy as sa
 from sqlalchemy import bindparam
@@ -375,7 +376,6 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
             result.context.execution_options,
             {
                 "four": True,
-                "future_result": True,
                 "one": True,
                 "three": True,
                 "two": True,
@@ -741,7 +741,6 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
             {
                 "statement_two": True,
                 "statement_four": True,
-                "future_result": True,
                 "one": True,
                 "two": True,
                 "three": True,
@@ -751,6 +750,42 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest):
             },
         )
 
+    @testing.variation("session_start", [True, False])
+    @testing.variation("dest_autoflush", [True, False])
+    @testing.variation("stmt_type", ["select", "bulk", "dml"])
+    def test_autoflush_change(self, session_start, dest_autoflush, stmt_type):
+        User = self.classes.User
+
+        sess = fixture_session(autoflush=session_start)
+
+        @event.listens_for(sess, "do_orm_execute")
+        def do_orm_execute(ctx):
+            ctx.update_execution_options(autoflush=dest_autoflush)
+
+        with patch.object(sess, "_autoflush") as m1:
+            if stmt_type.select:
+                sess.execute(select(User))
+            elif stmt_type.bulk:
+                sess.execute(
+                    insert(User),
+                    [
+                        {"id": 1, "name": "n1"},
+                        {"id": 2, "name": "n2"},
+                        {"id": 3, "name": "n3"},
+                    ],
+                )
+            elif stmt_type.dml:
+                sess.execute(
+                    update(User).where(User.id == 2).values(name="nn")
+                )
+            else:
+                stmt_type.fail()
+
+        if dest_autoflush:
+            eq_(m1.mock_calls, [call()])
+        else:
+            eq_(m1.mock_calls, [])
+
 
 class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest):
     run_inserts = None
index 7966006cf56d972ff92af6c5251f4c8733254b27..9e303a778b0e64498f8793db1479fe9735637703 100644 (file)
@@ -5507,6 +5507,7 @@ class YieldTest(_fixtures.FixtureTest):
         @event.listens_for(sess, "do_orm_execute")
         def check(ctx):
             eq_(ctx.load_options._yield_per, 15)
+            return
             eq_(
                 {
                     k: v
@@ -5516,7 +5517,6 @@ class YieldTest(_fixtures.FixtureTest):
                 {
                     "yield_per": 15,
                     "foo": "bar",
-                    "future_result": True,
                 },
             )
 
@@ -5535,6 +5535,7 @@ class YieldTest(_fixtures.FixtureTest):
         @event.listens_for(sess, "do_orm_execute")
         def check(ctx):
             eq_(ctx.load_options._yield_per, 15)
+
             eq_(
                 {
                     k: v
@@ -5543,7 +5544,6 @@ class YieldTest(_fixtures.FixtureTest):
                 },
                 {
                     "yield_per": 15,
-                    "future_result": True,
                 },
             )
 
@@ -5553,8 +5553,8 @@ class YieldTest(_fixtures.FixtureTest):
         assert isinstance(
             result.raw.cursor_strategy, _cursor.BufferedRowCursorFetchStrategy
         )
+        eq_(result._yield_per, 15)
         eq_(result.raw.cursor_strategy._max_row_buffer, 15)
-
         eq_(len(result.all()), 4)
 
     def test_no_joinedload_opt(self):
@@ -7515,23 +7515,80 @@ class ExecutionOptionsTest(QueryTest):
         assert u.addresses[0].email_address == "jack@bean.com"
         assert u.orders[1].items[2].description == "item 5"
 
-    def test_option_transfer_future(self):
+    @testing.variation("source", ["statement", "do_orm_exec"])
+    def test_execution_options_to_load_options(self, source):
         User = self.classes.User
-        stmt = select(User).execution_options(
-            populate_existing=True, autoflush=False, yield_per=10
-        )
+
+        stmt = select(User)
+
+        if source.statement:
+            stmt = stmt.execution_options(
+                populate_existing=True,
+                autoflush=False,
+                yield_per=10,
+                identity_token="some_token",
+            )
         s = fixture_session()
 
         m1 = mock.Mock()
 
-        event.listen(s, "do_orm_execute", m1)
+        def do_orm_execute(ctx):
+            m1(ctx)
+            if source.do_orm_exec:
+                ctx.update_execution_options(
+                    autoflush=False,
+                    populate_existing=True,
+                    yield_per=10,
+                    identity_token="some_token",
+                )
+
+        event.listen(s, "do_orm_execute", do_orm_execute)
+
+        from sqlalchemy.orm import loading
+
+        with mock.patch.object(loading, "instances") as m2:
+            s.execute(stmt)
+
+        if source.do_orm_exec:
+            # in do_orm_exec version, load options are empty, our new
+            # execution options have not yet been transferred.
+            eq_(
+                m1.mock_calls[0][1][0].load_options,
+                QueryContext.default_load_options,
+            )
+        elif source.statement:
+            # in statement version, the incoming exc options have been
+            # transferred, because the fact that do_orm_exec is used
+            # means the options were set up up front for the benefit
+            # of the do_orm_exec hook itself.
+            eq_(
+                m1.mock_calls[0][1][0].load_options,
+                QueryContext.default_load_options(
+                    _autoflush=False,
+                    _populate_existing=True,
+                    _yield_per=10,
+                    _identity_token="some_token",
+                ),
+            )
+
+        # py37 mock does not have .args
+        call_args = m2.mock_calls[0][1]
 
-        s.execute(stmt)
+        cursor = call_args[0]
+        cursor.all()
 
+        # the orm_pre_session_exec() method
+        # was called unconditionally after the event handler
+        # in both cases (i.e. a second time) so options were transferred
+        # even if we set them up in the do_orm_exec hook only.
+        query_context = call_args[1]
         eq_(
-            m1.mock_calls[0][1][0].load_options,
+            query_context.load_options,
             QueryContext.default_load_options(
-                _autoflush=False, _populate_existing=True, _yield_per=10
+                _autoflush=False,
+                _populate_existing=True,
+                _yield_per=10,
+                _identity_token="some_token",
             ),
         )
 
index 79ea5d17031521f28907cba50b06f76776dd1d88..921c55f74a74efb6d0694d773f4dd42f3b494f47 100644 (file)
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import inspect as _py_inspect
 import pickle
+from typing import TYPE_CHECKING
 
 import sqlalchemy as sa
 from sqlalchemy import delete
@@ -48,6 +51,9 @@ from sqlalchemy.testing.util import gc_collect
 from sqlalchemy.util.compat import inspect_getfullargspec
 from test.orm import _fixtures
 
+if TYPE_CHECKING:
+    from sqlalchemy.orm import ORMExecuteState
+
 
 class ExecutionTest(_fixtures.FixtureTest):
     run_inserts = None
@@ -563,7 +569,10 @@ class SessionUtilTest(_fixtures.FixtureTest):
             u1,
         )
 
-    def test_get_execution_option(self):
+    @testing.variation(
+        "arg", ["execution_options", "identity_token", "bind_arguments"]
+    )
+    def test_get_arguments(self, arg: testing.Variation) -> None:
         users, User = self.tables.users, self.classes.User
 
         self.mapper_registry.map_imperatively(User, users)
@@ -571,12 +580,28 @@ class SessionUtilTest(_fixtures.FixtureTest):
         called = False
 
         @event.listens_for(sess, "do_orm_execute")
-        def check(ctx):
+        def check(ctx: ORMExecuteState) -> None:
             nonlocal called
             called = True
-            eq_(ctx.execution_options["foo"], "bar")
 
-        sess.get(User, 42, execution_options={"foo": "bar"})
+            if arg.execution_options:
+                eq_(ctx.execution_options["foo"], "bar")
+            elif arg.bind_arguments:
+                eq_(ctx.bind_arguments["foo"], "bar")
+            elif arg.identity_token:
+                eq_(ctx.load_options._identity_token, "foobar")
+            else:
+                arg.fail()
+
+        if arg.execution_options:
+            sess.get(User, 42, execution_options={"foo": "bar"})
+        elif arg.bind_arguments:
+            sess.get(User, 42, bind_arguments={"foo": "bar"})
+        elif arg.identity_token:
+            sess.get(User, 42, identity_token="foobar")
+        else:
+            arg.fail()
+
         sess.close()
 
         is_true(called)