From 4338213935b4133e36d593ceec75f7fe36c13f66 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Sun, 18 Dec 2022 16:33:22 -0500 Subject: [PATCH] reorganize pre_session_exec around do_orm_execute Allow do_orm_execute() events to both receive the complete state of bind_argments, load_options, update_delete_options as they do already, but also allow them to *change* all those things via new execution options. Options like autoflush, populate_existing etc. can now be updated within a do_orm_execute() hook and those changes will take effect all the way through. Took a few tries to get something that covers every case here, in particular horizontal sharding which is consuming those options as well as using context.invoke(), without excess complexity. The good news seems to be that a simple reorg and replacing the "reentrant" boolean with "is this before do_orm_execute is invoked" was all that was needed. As part of this we add a new "identity_token" option allowing this option to be controlled from do_orm_execute() as well as from the outside. WIP Fixes: #7837 Change-Id: I087728215edec8d1b1712322ab389e3f52ff76ba --- doc/build/changelog/unreleased_20/7837.rst | 40 +++++ doc/build/glossary.rst | 16 ++ doc/build/orm/queryguide/api.rst | 132 +++++++++++++++++ examples/sharding/separate_databases.py | 10 +- .../sharding/separate_schema_translates.py | 13 +- examples/sharding/separate_tables.py | 10 +- lib/sqlalchemy/ext/horizontal_shard.py | 138 ++++++++++-------- lib/sqlalchemy/orm/bulk_persistence.py | 118 +++++++-------- lib/sqlalchemy/orm/context.py | 27 ++-- lib/sqlalchemy/orm/loading.py | 2 +- lib/sqlalchemy/orm/query.py | 2 +- lib/sqlalchemy/orm/session.py | 86 ++++++++--- test/ext/test_deprecations.py | 31 +++- test/ext/test_horizontal_shard.py | 107 ++++++++++++-- test/orm/test_events.py | 39 ++++- test/orm/test_query.py | 79 ++++++++-- test/orm/test_session.py | 33 ++++- 17 files changed, 677 insertions(+), 206 deletions(-) create mode 100644 doc/build/changelog/unreleased_20/7837.rst diff --git a/doc/build/changelog/unreleased_20/7837.rst b/doc/build/changelog/unreleased_20/7837.rst new file mode 100644 index 0000000000..1abb3e157d --- /dev/null +++ b/doc/build/changelog/unreleased_20/7837.rst @@ -0,0 +1,40 @@ +.. change:: + :tags: usecase, orm + :tickets: 7837 + + Adjustments to the :class:`_orm.Session` in terms of extensibility, + as well as updates to the :class:`.ShardedSession` extension: + + * :meth:`_orm.Session.get` now accepts + :paramref:`_orm.Session.get.bind_arguments`, which in particular may be + useful when using the horizontal sharding extension. + + * :meth:`_orm.Session.get_bind` accepts arbitrary kw arguments, which + assists in developing code that uses a :class:`_orm.Session` class which + overrides this method with additional arguments. + + * Added a new ORM execution option ``identity_token`` which may be used + to directly affect the "identity token" that will be associated with + newly loaded ORM objects. This token is how sharding approaches + (namely the :class:`.ShardedSession`, but can be used in other cases + as well) separate object identities across different "shards". + + .. seealso:: + + :ref:`queryguide_identity_token` + + * The :meth:`_orm.SessionEvents.do_orm_execute` event hook may now be used + to affect all ORM-related options, including ``autoflush``, + ``populate_existing``, and ``yield_per``; these options are re-consumed + subsequent to event hooks being invoked before they are acted upon. + Previously, options like ``autoflush`` would have been already evaluated + at this point. The new ``identity_token`` option is also supported in + this mode and is now used by the horizontal sharding extension. + + + * The :class:`.ShardedSession` class replaces the + :paramref:`.ShardedSession.id_chooser` hook with a new hook + :paramref:`.ShardedSession.identity_chooser`, which no longer relies upon + the legacy :class:`_orm.Query` object. + :paramref:`.ShardedSession.id_chooser` is still accepted in place of + :paramref:`.ShardedSession.identity_chooser` with a deprecation warning. diff --git a/doc/build/glossary.rst b/doc/build/glossary.rst index d0bc4f8148..70eb05e644 100644 --- a/doc/build/glossary.rst +++ b/doc/build/glossary.rst @@ -488,6 +488,19 @@ Glossary primary key identity within the database, as well as their unique identity within a :class:`_orm.Session` :term:`identity map`. + In SQLAlchemy, you can view the identity key for an ORM object + using the :func:`_sa.inspect` API to return the :class:`_orm.InstanceState` + tracking object, then looking at the :attr:`_orm.InstanceState.key` + attribute:: + + >>> from sqlalchemy import inspect + >>> inspect(some_object).key + (, (1,), None) + + .. seealso:: + + :term:`identity map` + identity map A mapping between Python objects and their database identities. The identity map is a collection that's associated with an @@ -505,6 +518,9 @@ Glossary `Identity Map (via Martin Fowler) `_ + :ref:`session_get` - how to look up an object in the identity map + by primary key + lazy initialization A tactic of delaying some initialization action, such as creating objects, populating data, or establishing connectivity to other services, until diff --git a/doc/build/orm/queryguide/api.rst b/doc/build/orm/queryguide/api.rst index 136b4b39bb..35259a3b38 100644 --- a/doc/build/orm/queryguide/api.rst +++ b/doc/build/orm/queryguide/api.rst @@ -280,6 +280,138 @@ will have the same result as that of the ``yield_per`` execution option. :ref:`engine_stream_results` +.. _queryguide_identity_token: + +Identity Token +^^^^^^^^^^^^^^ + +.. doctest-disable: + +.. deepalchemy:: This option is an advanced-use feature mostly intended + to be used with the :ref:`horizontal_sharding_toplevel` extension. For + typical cases of loading objects with identical primary keys from different + "shards" or partitions, consider using individual :class:`_orm.Session` + objects per shard first. + + +The "identity token" is an arbitrary value that can be associated within +the :term:`identity key` of newly loaded objects. This element exists +first and foremost to support extensions which perform per-row "sharding", +where objects may be loaded from any number of replicas of a particular +database table that nonetheless have overlapping primary key values. +The primary consumer of "identity token" is the +:ref:`horizontal_sharding_toplevel` extension, which supplies a general +framework for persisting objects among multiple "shards" of a particular +database table. + +The ``identity_token`` execution option may be used on a per-query basis +to directly affect this token. Using it directly, one can populate a +:class:`_orm.Session` with multiple instances of an object that have the +same primary key and source table, but different "identities". + +One such example is to populate a :class:`_orm.Session` with objects that +come from same-named tables in different schemas, using the +:ref:`schema_translating` feature which can affect the choice of schema +within the scope of queries. Given a mapping as: + +.. sourcecode:: python + + from sqlalchemy.orm import DeclarativeBase + from sqlalchemy.orm import Mapped + from sqlalchemy.orm import mapped_column + + + class Base(DeclarativeBase): + pass + + + class MyTable(Base): + __tablename__ = "my_table" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] + +The default "schema" name for the class above is ``None``, meaning, no +schema qualification will be written into SQL statements. However, +if we make use of :paramref:`_engine.Connection.execution_options.schema_translate_map`, +mapping ``None`` to an alternate schema, we can place instances of +``MyTable`` into two different schemas: + +.. sourcecode:: python + + engine = create_engine( + "postgresql+psycopg://scott:tiger@localhost/test", + ) + + with Session( + engine.execution_options(schema_translate_map={None: "test_schema"}) + ) as sess: + sess.add(MyTable(name="this is schema one")) + sess.commit() + + with Session( + engine.execution_options(schema_translate_map={None: "test_schema_2"}) + ) as sess: + sess.add(MyTable(name="this is schema two")) + sess.commit() + +The above two blocks create a :class:`_orm.Session` object linked to a different +schema translate map each time, and an instance of ``MyTable`` is persisted +into both ``test_schema.my_table`` as well as ``test_schema_2.my_table``. + +The :class:`_orm.Session` objects above are independent. If we wanted to +persist both objects in one transaction, we would need to use the +:ref:`horizontal_sharding_toplevel` extension to do this. + +However, we can illustrate querying for these objects in one session as follows: + +.. sourcecode:: python + + with Session(engine) as sess: + obj1 = sess.scalar( + select(MyTable) + .where(MyTable.id == 1) + .execution_options( + schema_translate_map={None: "test_schema"}, + identity_token="test_schema", + ) + ) + obj2 = sess.scalar( + select(MyTable) + .where(MyTable.id == 1) + .execution_options( + schema_translate_map={None: "test_schema_2"}, + identity_token="test_schema_2", + ) + ) + +Both ``obj1`` and ``obj2`` are distinct from each other. However, they both +refer to primary key id 1 for the ``MyTable`` class, yet are distinct. +This is how the ``identity_token`` comes into play, which we can see in the +inspection of each object, where we look at :attr:`_orm.InstanceState.key` +to view the two distinct identity tokens:: + + >>> from sqlalchemy import inspect + >>> inspect(obj1).key + (, (1,), 'test_schema') + >>> inspect(obj2).key + (, (1,), 'test_schema_2') + + +The above logic takes place automatically when using the +:ref:`horizontal_sharding_toplevel` extension. + +.. versionadded:: 2.0.0b5 - added the ``identity_token`` ORM level execution + option. + +.. seealso:: + + :ref:`examples_sharding` - in the :ref:`examples_toplevel` section. + See the script ``separate_schema_translates.py`` for a demonstration of + the above use case using the full sharding API. + + +.. doctest-enable: .. _queryguide_inspection: diff --git a/examples/sharding/separate_databases.py b/examples/sharding/separate_databases.py index a45182f42d..fe92fd3bac 100644 --- a/examples/sharding/separate_databases.py +++ b/examples/sharding/separate_databases.py @@ -135,8 +135,8 @@ def shard_chooser(mapper, instance, clause=None): return shard_chooser(mapper, instance.location) -def id_chooser(query, ident): - """id chooser. +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. given a primary key, returns a list of shards to search. here, we don't have any particular information from a @@ -145,11 +145,11 @@ def id_chooser(query, ident): distributed among DBs. """ - if query.lazy_loaded_from: + if lazy_loaded_from: # if we are in a lazy load, we can look at the parent object # and limit our search to that same shard, assuming that's how we've # set things up. - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] else: return ["north_america", "asia", "europe", "south_america"] @@ -237,7 +237,7 @@ def _get_select_comparisons(statement): # further configure create_session to use these functions Session.configure( shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) diff --git a/examples/sharding/separate_schema_translates.py b/examples/sharding/separate_schema_translates.py index 2d4c2a0464..f7bdc62500 100644 --- a/examples/sharding/separate_schema_translates.py +++ b/examples/sharding/separate_schema_translates.py @@ -130,21 +130,20 @@ def shard_chooser(mapper, instance, clause=None): return shard_chooser(mapper, instance.location) -def id_chooser(query, ident): - """id chooser. +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. - given a primary key identity and a legacy :class:`_orm.Query`, - return which shard we should look at. + given a primary key identity, return which shard we should look at. in this case, we only want to support this for lazy-loaded items; any primary query should have shard id set up front. """ - if query.lazy_loaded_from: + if lazy_loaded_from: # if we are in a lazy load, we can look at the parent object # and limit our search to that same shard, assuming that's how we've # set things up. - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] else: raise NotImplementedError() @@ -169,7 +168,7 @@ def execute_chooser(context): # configure shard chooser Session.configure( shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) diff --git a/examples/sharding/separate_tables.py b/examples/sharding/separate_tables.py index 8f39471e88..97c6a07f6a 100644 --- a/examples/sharding/separate_tables.py +++ b/examples/sharding/separate_tables.py @@ -149,8 +149,8 @@ def shard_chooser(mapper, instance, clause=None): return shard_chooser(mapper, instance.location) -def id_chooser(query, ident): - """id chooser. +def identity_chooser(mapper, primary_key, *, lazy_loaded_from, **kw): + """identity chooser. given a primary key, returns a list of shards to search. here, we don't have any particular information from a @@ -159,11 +159,11 @@ def id_chooser(query, ident): distributed among DBs. """ - if query.lazy_loaded_from: + if lazy_loaded_from: # if we are in a lazy load, we can look at the parent object # and limit our search to that same shard, assuming that's how we've # set things up. - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] else: return ["north_america", "asia", "europe", "south_america"] @@ -251,7 +251,7 @@ def _get_select_comparisons(statement): # further configure create_session to use these functions Session.configure( shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) diff --git a/lib/sqlalchemy/ext/horizontal_shard.py b/lib/sqlalchemy/ext/horizontal_shard.py index 69767ad6cb..fd53c60468 100644 --- a/lib/sqlalchemy/ext/horizontal_shard.py +++ b/lib/sqlalchemy/ext/horizontal_shard.py @@ -13,11 +13,14 @@ distribute queries and persistence operations across multiple databases. For a usage example, see the :ref:`examples_sharding` example included in the source distribution. -.. legacy:: The horizontal sharding API is not fully updated for the - SQLAlchemy 2.0 API, and still relies in part on the - legacy :class:`.Query` architecture, in particular as part of the - signature for the :paramref:`.ShardedSession.id_chooser` parameter. - This may change in a future release. +.. deepalchemy:: The horizontal sharding extension is an advanced feature, + involving a complex statement -> database interaction as well as + use of semi-public APIs for non-trivial cases. Simpler approaches to + refering to multiple database "shards", most commonly using a distinct + :class:`_orm.Session` per "shard", should always be considered first + before using this more complex and less-production-tested system. + + """ from __future__ import annotations @@ -38,8 +41,11 @@ from .. import exc from .. import inspect from .. import util from ..orm import PassiveFlag +from ..orm._typing import OrmExecuteOptionsParameter from ..orm.mapper import Mapper from ..orm.query import Query +from ..orm.session import _BindArguments +from ..orm.session import _PKIdentityArgument from ..orm.session import Session from ..util.typing import Protocol @@ -80,6 +86,20 @@ class ShardChooser(Protocol): ... +class IdentityChooser(Protocol): + def __call__( + self, + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + ... + + class ShardedQuery(Query[_T]): """Query class used with :class:`.ShardedSession`. @@ -94,8 +114,7 @@ class ShardedQuery(Query[_T]): super().__init__(*args, **kwargs) assert isinstance(self.session, ShardedSession) - self.id_chooser = self.session.id_chooser - self.query_chooser = self.session.query_chooser + self.identity_chooser = self.session.identity_chooser self.execute_chooser = self.session.execute_chooser self._shard_id = None @@ -119,19 +138,22 @@ class ShardedQuery(Query[_T]): class ShardedSession(Session): shard_chooser: ShardChooser - id_chooser: Callable[[Query[Any], Iterable[Any]], Iterable[Any]] + identity_chooser: IdentityChooser execute_chooser: Callable[[ORMExecuteState], Iterable[Any]] def __init__( self, shard_chooser: ShardChooser, - id_chooser: Callable[[Query[_T], Iterable[_T]], Iterable[Any]], + identity_chooser: Optional[IdentityChooser] = None, execute_chooser: Optional[ Callable[[ORMExecuteState], Iterable[Any]] ] = None, shards: Optional[Dict[str, Any]] = None, query_cls: Type[Query[_T]] = ShardedQuery, *, + id_chooser: Optional[ + Callable[[Query[_T], Iterable[_T]], Iterable[Any]] + ] = None, query_chooser: Optional[Callable[[Executable], Iterable[Any]]] = None, **kwargs: Any, ) -> None: @@ -171,12 +193,41 @@ class ShardedSession(Session): self, "do_orm_execute", execute_and_instances, retval=True ) self.shard_chooser = shard_chooser - self.id_chooser = id_chooser + + if id_chooser: + _id_chooser = id_chooser + util.warn_deprecated( + "The ``id_chooser`` parameter is deprecated; " + "please use ``identity_chooser``.", + "2.0", + ) + + def _legacy_identity_chooser( + mapper: Mapper[_T], + primary_key: _PKIdentityArgument, + *, + lazy_loaded_from: Optional[InstanceState[Any]], + execution_options: OrmExecuteOptionsParameter, + bind_arguments: _BindArguments, + **kw: Any, + ) -> Any: + q = self.query(mapper) + if lazy_loaded_from: + q = q._set_lazyload_from(lazy_loaded_from) + return _id_chooser(q, primary_key) + + self.identity_chooser = _legacy_identity_chooser + elif identity_chooser: + self.identity_chooser = identity_chooser + else: + raise exc.ArgumentError( + "identity_chooser or id_chooser is required" + ) if query_chooser: _query_chooser = query_chooser util.warn_deprecated( - "The ``query_choser`` parameter is deprecated; " + "The ``query_chooser`` parameter is deprecated; " "please use ``execute_chooser``.", "1.4", ) @@ -199,7 +250,6 @@ class ShardedSession(Session): "execute_chooser or query_chooser is required" ) self.execute_chooser = execute_chooser - self.query_chooser = query_chooser self.__shards: Dict[_ShardKey, _SessionBind] = {} if shards is not None: for k in shards: @@ -212,6 +262,8 @@ class ShardedSession(Session): identity_token: Optional[Any] = None, passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, **kw: Any, ) -> Union[Optional[_O], LoaderCallableStatus]: """override the default :meth:`.Session._identity_lookup` method so @@ -233,10 +285,13 @@ class ShardedSession(Session): return obj else: - q = self.query(mapper) - if lazy_loaded_from: - q = q._set_lazyload_from(lazy_loaded_from) - for shard_id in self.id_chooser(q, primary_key_identity): + for shard_id in self.identity_chooser( + mapper, + primary_key_identity, + lazy_loaded_from=lazy_loaded_from, + execution_options=execution_options, + bind_arguments=dict(bind_arguments) if bind_arguments else {}, + ): obj2 = super()._identity_lookup( mapper, primary_key_identity, @@ -325,11 +380,6 @@ class ShardedSession(Session): def execute_and_instances( orm_context: ORMExecuteState, ) -> Union[Result[_T], IteratorResult[_TP]]: - update_options: Union[ - None, - BulkUDCompileState.default_update_options, - Type[BulkUDCompileState.default_update_options], - ] active_options: Union[ None, QueryContext.default_load_options, @@ -337,58 +387,30 @@ def execute_and_instances( BulkUDCompileState.default_update_options, Type[BulkUDCompileState.default_update_options], ] - load_options: Union[ - None, - QueryContext.default_load_options, - Type[QueryContext.default_load_options], - ] if orm_context.is_select: - load_options = active_options = orm_context.load_options - update_options = None + active_options = orm_context.load_options elif orm_context.is_update or orm_context.is_delete: - load_options = None - update_options = active_options = orm_context.update_delete_options + active_options = orm_context.update_delete_options else: - load_options = update_options = active_options = None + active_options = None session = orm_context.session assert isinstance(session, ShardedSession) def iter_for_shard( shard_id: str, - load_options: Union[ - None, - QueryContext.default_load_options, - Type[QueryContext.default_load_options], - ], - update_options: Union[ - None, - BulkUDCompileState.default_update_options, - Type[BulkUDCompileState.default_update_options], - ], ) -> Union[Result[_T], IteratorResult[_TP]]: - execution_options = dict(orm_context.local_execution_options) bind_arguments = dict(orm_context.bind_arguments) bind_arguments["shard_id"] = shard_id - if orm_context.is_select: - assert load_options is not None - load_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_load_options"] = load_options - elif orm_context.is_update or orm_context.is_delete: - assert update_options is not None - update_options += {"_refresh_identity_token": shard_id} - execution_options["_sa_orm_update_options"] = update_options - - return orm_context.invoke_statement( - bind_arguments=bind_arguments, execution_options=execution_options - ) + orm_context.update_execution_options(identity_token=shard_id) + return orm_context.invoke_statement(bind_arguments=bind_arguments) - if active_options and active_options._refresh_identity_token is not None: - shard_id = active_options._refresh_identity_token + if active_options and active_options._identity_token is not None: + shard_id = active_options._identity_token elif "_sa_shard_id" in orm_context.execution_options: shard_id = orm_context.execution_options["_sa_shard_id"] elif "shard_id" in orm_context.bind_arguments: @@ -397,10 +419,10 @@ def execute_and_instances( shard_id = None if shard_id is not None: - return iter_for_shard(shard_id, load_options, update_options) + return iter_for_shard(shard_id) else: partial = [] for shard_id in session.execute_chooser(orm_context): - result_ = iter_for_shard(shard_id, load_options, update_options) + result_ = iter_for_shard(shard_id) partial.append(result_) return partial[0].merge(*partial[1:]) diff --git a/lib/sqlalchemy/orm/bulk_persistence.py b/lib/sqlalchemy/orm/bulk_persistence.py index 181dbd4a28..805bfdc65e 100644 --- a/lib/sqlalchemy/orm/bulk_persistence.py +++ b/lib/sqlalchemy/orm/bulk_persistence.py @@ -555,7 +555,7 @@ class BulkUDCompileState(ORMDMLState): _resolved_values = EMPTY_DICT _eval_condition = None _matched_rows = None - _refresh_identity_token = None + _identity_token = None @classmethod def can_use_returning( @@ -577,10 +577,8 @@ class BulkUDCompileState(ORMDMLState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): - if is_reentrant_invoke: - return statement, execution_options ( update_options, @@ -590,6 +588,7 @@ class BulkUDCompileState(ORMDMLState): { "synchronize_session", "autoflush", + "identity_token", "is_delete_using", "is_update_from", "dml_strategy", @@ -637,55 +636,56 @@ class BulkUDCompileState(ORMDMLState): "for 'bulk' ORM updates (i.e. multiple parameter sets)" ) - if update_options._autoflush: - session._autoflush() - - if update_options._dml_strategy == "orm": + if not is_pre_event: + if update_options._autoflush: + session._autoflush() - if update_options._synchronize_session == "auto": - update_options = cls._do_pre_synchronize_auto( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "evaluate": - update_options = cls._do_pre_synchronize_evaluate( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._synchronize_session == "fetch": - update_options = cls._do_pre_synchronize_fetch( - session, - statement, - params, - execution_options, - bind_arguments, - update_options, - ) - elif update_options._dml_strategy == "bulk": - if update_options._synchronize_session == "auto": - update_options += {"_synchronize_session": "evaluate"} + if update_options._dml_strategy == "orm": - # indicators from the "pre exec" step that are then - # added to the DML statement, which will also be part of the cache - # key. The compile level create_for_statement() method will then - # consume these at compiler time. - statement = statement._annotate( - { - "synchronize_session": update_options._synchronize_session, - "is_delete_using": update_options._is_delete_using, - "is_update_from": update_options._is_update_from, - "dml_strategy": update_options._dml_strategy, - "can_use_returning": update_options._can_use_returning, - } - ) + if update_options._synchronize_session == "auto": + update_options = cls._do_pre_synchronize_auto( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "evaluate": + update_options = cls._do_pre_synchronize_evaluate( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._synchronize_session == "fetch": + update_options = cls._do_pre_synchronize_fetch( + session, + statement, + params, + execution_options, + bind_arguments, + update_options, + ) + elif update_options._dml_strategy == "bulk": + if update_options._synchronize_session == "auto": + update_options += {"_synchronize_session": "evaluate"} + + # indicators from the "pre exec" step that are then + # added to the DML statement, which will also be part of the cache + # key. The compile level create_for_statement() method will then + # consume these at compiler time. + statement = statement._annotate( + { + "synchronize_session": update_options._synchronize_session, + "is_delete_using": update_options._is_delete_using, + "is_update_from": update_options._is_update_from, + "dml_strategy": update_options._dml_strategy, + "can_use_returning": update_options._can_use_returning, + } + ) return ( statement, @@ -836,7 +836,7 @@ class BulkUDCompileState(ORMDMLState): if state.mapper.isa(mapper) and not state.expired ] - identity_token = update_options._refresh_identity_token + identity_token = update_options._identity_token if identity_token is not None: raw_data = [ (obj, state, dict_) @@ -1091,7 +1091,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): ( @@ -1143,7 +1143,7 @@ class BulkORMInsert(ORMDMLState, InsertDMLState): context._orm_load_exec_options ) - if insert_options._autoflush: + if not is_pre_event and insert_options._autoflush: session._autoflush() statement = statement._annotate( @@ -1577,7 +1577,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): for param in params: identity_key = mapper.identity_key_from_primary_key( (param[key] for key in pk_keys), - update_options._refresh_identity_token, + update_options._identity_token, ) state = identity_map.fast_get_state(identity_key) if not state: @@ -1635,7 +1635,7 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): ) matched_rows = [ - tuple(row) + (update_options._refresh_identity_token,) + tuple(row) + (update_options._identity_token,) for row in pk_rows ] else: @@ -1651,8 +1651,8 @@ class BulkORMUpdate(BulkUDCompileState, UpdateDMLState): for primary_key, identity_token in [ (row[0:-1], row[-1]) for row in matched_rows ] - if update_options._refresh_identity_token is None - or identity_token == update_options._refresh_identity_token + if update_options._identity_token is None + or identity_token == update_options._identity_token ] if identity_key in session.identity_map ] @@ -1912,7 +1912,7 @@ class BulkORMDelete(BulkUDCompileState, DeleteDMLState): ) matched_rows = [ - tuple(row) + (update_options._refresh_identity_token,) + tuple(row) + (update_options._identity_token,) for row in pk_rows ] else: diff --git a/lib/sqlalchemy/orm/context.py b/lib/sqlalchemy/orm/context.py index 3bd8b02a71..b3478b83e1 100644 --- a/lib/sqlalchemy/orm/context.py +++ b/lib/sqlalchemy/orm/context.py @@ -135,7 +135,7 @@ class QueryContext: _version_check = False _invoke_all_eagers = True _autoflush = True - _refresh_identity_token = None + _identity_token = None _yield_per = None _refresh_state = None _lazy_loaded_from = None @@ -194,14 +194,14 @@ class QueryContext: self.version_check = load_options._version_check self.refresh_state = load_options._refresh_state self.yield_per = load_options._yield_per - self.identity_token = load_options._refresh_identity_token + self.identity_token = load_options._identity_token def _get_top_level_context(self) -> QueryContext: return self.top_level_context or self _orm_load_exec_options = util.immutabledict( - {"_result_disable_adapt_to_context": True, "future_result": True} + {"_result_disable_adapt_to_context": True} ) @@ -235,7 +235,7 @@ class AbstractORMCompileState(CompileState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): raise NotImplementedError() @@ -384,11 +384,11 @@ class ORMCompileState(AbstractORMCompileState): params, execution_options, bind_arguments, - is_reentrant_invoke, + is_pre_event, ): - if is_reentrant_invoke: - return statement, execution_options + # consume result-level load_options. These may have been set up + # in an ORMExecuteState hook ( load_options, execution_options, @@ -398,26 +398,24 @@ class ORMCompileState(AbstractORMCompileState): "populate_existing", "autoflush", "yield_per", + "identity_token", "sa_top_level_orm_context", }, execution_options, statement._execution_options, ) + # default execution options for ORM results: # 1. _result_disable_adapt_to_context=True # this will disable the ResultSetMetadata._adapt_to_context() # step which we don't need, as we have result processors cached # against the original SELECT statement before caching. - # 2. future_result=True. The ORM should **never** resolve columns - # in a result set based on names, only on Column objects that - # are correctly adapted to the context. W the legacy result - # it will still attempt name-based resolution and also emit a - # warning. if not execution_options: execution_options = _orm_load_exec_options else: execution_options = execution_options.union(_orm_load_exec_options) + # would have been placed here by legacy Query only if load_options._yield_per: execution_options = execution_options.union( {"yield_per": load_options._yield_per} @@ -457,7 +455,7 @@ class ORMCompileState(AbstractORMCompileState): if plugin_subject: bind_arguments["mapper"] = plugin_subject.mapper - if load_options._autoflush: + if not is_pre_event and load_options._autoflush: session._autoflush() return statement, execution_options @@ -483,6 +481,7 @@ class ORMCompileState(AbstractORMCompileState): load_options = execution_options.get( "_sa_orm_load_options", QueryContext.default_load_options ) + if compile_state.compile_options._is_star: return result @@ -3119,6 +3118,6 @@ class _IdentityTokenEntity(_ORMColumnEntity): def row_processor(self, context, result): def getter(row): - return context.load_options._refresh_identity_token + return context.load_options._identity_token return getter, self._label_name, self._extra_entities diff --git a/lib/sqlalchemy/orm/loading.py b/lib/sqlalchemy/orm/loading.py index 6e7695f861..f331cd63b0 100644 --- a/lib/sqlalchemy/orm/loading.py +++ b/lib/sqlalchemy/orm/loading.py @@ -701,7 +701,7 @@ def _set_get_options( if only_load_props: compile_options["_only_load_props"] = frozenset(only_load_props) if identity_token: - load_options["_refresh_identity_token"] = identity_token + load_options["_identity_token"] = identity_token if load_options: load_opt += load_options diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 01db08eb46..d2bd930ff4 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -470,7 +470,7 @@ class Query( if only_load_props: compile_options["_only_load_props"] = frozenset(only_load_props) if identity_token: - load_options["_refresh_identity_token"] = identity_token + load_options["_identity_token"] = identity_token if load_options: self.load_options += load_options diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index bf3df05990..8b5f7c88ab 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -267,6 +267,7 @@ class ORMExecuteState(util.MemoizedSlots): "execution_options", "local_execution_options", "bind_arguments", + "identity_token", "_compile_state_cls", "_starting_event_idx", "_events_todo", @@ -579,9 +580,8 @@ class ORMExecuteState(util.MemoizedSlots): def _is_crud(self) -> bool: return isinstance(self.statement, (dml.Update, dml.Delete)) - def update_execution_options(self, **opts: _ExecuteOptions) -> None: + def update_execution_options(self, **opts: Any) -> None: """Update the local execution options with new values.""" - # TODO: no coverage self.local_execution_options = self.local_execution_options.union(opts) def _orm_compile_options( @@ -1912,27 +1912,10 @@ class Session(_SessionClassMethods, EventTarget): ) else: compile_state_cls = None + bind_arguments.setdefault("clause", statement) execution_options = util.coerce_to_immutabledict(execution_options) - if compile_state_cls is not None: - ( - statement, - execution_options, - ) = compile_state_cls.orm_pre_session_exec( - self, - statement, - params, - execution_options, - bind_arguments, - _parent_execute_state is not None, - ) - else: - bind_arguments.setdefault("clause", statement) - execution_options = execution_options.union( - {"future_result": True} - ) - if _parent_execute_state: events_todo = _parent_execute_state._remaining_events() else: @@ -1941,6 +1924,25 @@ class Session(_SessionClassMethods, EventTarget): events_todo = list(events_todo) + [_add_event] if events_todo: + if compile_state_cls is not None: + # for event handlers, do the orm_pre_session_exec + # pass ahead of the event handlers, so that things like + # .load_options, .update_delete_options etc. are populated. + # is_pre_event=True allows the hook to hold off on things + # it doesn't want to do twice, including autoflush as well + # as "pre fetch" for DML, etc. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + True, + ) + orm_exec_state = ORMExecuteState( self, statement, @@ -1962,6 +1964,24 @@ class Session(_SessionClassMethods, EventTarget): statement = orm_exec_state.statement execution_options = orm_exec_state.local_execution_options + if compile_state_cls is not None: + # now run orm_pre_session_exec() "for real". if there were + # event hooks, this will re-run the steps that interpret + # new execution_options into load_options / update_delete_options, + # which we assume the event hook might have updated. + # autoflush will also be invoked in this step if enabled. + ( + statement, + execution_options, + ) = compile_state_cls.orm_pre_session_exec( + self, + statement, + params, + execution_options, + bind_arguments, + False, + ) + bind = self.get_bind(**bind_arguments) conn = self._connection_for_bind(bind) @@ -2379,6 +2399,7 @@ class Session(_SessionClassMethods, EventTarget): bind: Optional[_SessionBind] = None, _sa_skip_events: Optional[bool] = None, _sa_skip_for_implicit_returning: bool = False, + **kw: Any, ) -> Union[Engine, Connection]: """Return a "bind" to which this :class:`.Session` is bound. @@ -2653,6 +2674,8 @@ class Session(_SessionClassMethods, EventTarget): identity_token: Any = None, passive: PassiveFlag = PassiveFlag.PASSIVE_OFF, lazy_loaded_from: Optional[InstanceState[Any]] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Union[Optional[_O], LoaderCallableStatus]: """Locate an object in the identity map. @@ -3262,6 +3285,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update: Optional[ForUpdateArg] = None, identity_token: Optional[Any] = None, execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Optional[_O]: """Return an instance based on the given primary key identifier, or ``None`` if not found. @@ -3355,6 +3379,13 @@ class Session(_SessionClassMethods, EventTarget): :ref:`orm_queryguide_execution_options` - ORM-specific execution options + :param bind_arguments: dictionary of additional arguments to determine + the bind. May include "mapper", "bind", or other custom arguments. + Contents of this dictionary are passed to the + :meth:`.Session.get_bind` method. + + .. versionadded: 2.0.0b5 + :return: The object instance, or ``None``. """ @@ -3367,6 +3398,7 @@ class Session(_SessionClassMethods, EventTarget): with_for_update=with_for_update, identity_token=identity_token, execution_options=execution_options, + bind_arguments=bind_arguments, ) def _get_impl( @@ -3379,7 +3411,8 @@ class Session(_SessionClassMethods, EventTarget): populate_existing: bool = False, with_for_update: Optional[ForUpdateArg] = None, identity_token: Optional[Any] = None, - execution_options: Optional[OrmExecuteOptionsParameter] = None, + execution_options: OrmExecuteOptionsParameter = util.EMPTY_DICT, + bind_arguments: Optional[_BindArguments] = None, ) -> Optional[_O]: # convert composite types to individual args @@ -3453,7 +3486,11 @@ class Session(_SessionClassMethods, EventTarget): ): instance = self._identity_lookup( - mapper, primary_key_identity, identity_token=identity_token + mapper, + primary_key_identity, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, ) if instance is not None: @@ -3484,13 +3521,14 @@ class Session(_SessionClassMethods, EventTarget): if options: statement = statement.options(*options) - if execution_options: - statement = statement.execution_options(**execution_options) return db_load_fn( self, statement, primary_key_identity, load_options=load_options, + identity_token=identity_token, + execution_options=execution_options, + bind_arguments=bind_arguments, ) def merge( diff --git a/test/ext/test_deprecations.py b/test/ext/test_deprecations.py index 09f904487a..97c4172ba7 100644 --- a/test/ext/test_deprecations.py +++ b/test/ext/test_deprecations.py @@ -1,3 +1,5 @@ +from sqlalchemy import Column +from sqlalchemy import Integer from sqlalchemy import testing from sqlalchemy.ext.automap import automap_base from sqlalchemy.ext.horizontal_shard import ShardedSession @@ -68,7 +70,7 @@ class HorizontalShardTest(fixtures.TestBase): m1 = mock.Mock() with testing.expect_deprecated( - "The ``query_choser`` parameter is deprecated; please use" + "The ``query_chooser`` parameter is deprecated; please use" ): s = ShardedSession( shard_chooser=m1.shard_chooser, @@ -80,3 +82,30 @@ class HorizontalShardTest(fixtures.TestBase): s.execute_chooser(m2) eq_(m1.mock_calls, [mock.call.query_chooser(m2.statement)]) + + def test_id_chooser(self, decl_base): + class A(decl_base): + __tablename__ = "a" + id = Column(Integer, primary_key=True) + + m1 = mock.Mock() + + with testing.expect_deprecated( + "The ``id_chooser`` parameter is deprecated; please use" + ): + s = ShardedSession( + shard_chooser=m1.shard_chooser, + id_chooser=m1.id_chooser, + execute_chooser=m1.execute_chooser, + ) + + m2 = mock.Mock() + s.identity_chooser( + A.__mapper__, + m2.primary_key, + lazy_loaded_from=m2.lazy_loaded_from, + execution_options=m2.execution_options, + bind_arguments=m2.bind_arguments, + ) + + eq_(m1.mock_calls, [mock.call.id_chooser(mock.ANY, m2.primary_key)]) diff --git a/test/ext/test_horizontal_shard.py b/test/ext/test_horizontal_shard.py index ab4a24f71c..8e5d09cab0 100644 --- a/test/ext/test_horizontal_shard.py +++ b/test/ext/test_horizontal_shard.py @@ -28,6 +28,7 @@ from sqlalchemy.pool import SingletonThreadPool from sqlalchemy.sql import operators from sqlalchemy.sql import Select from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_deprecated from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import provision @@ -109,7 +110,15 @@ class ShardTest: else: return shard_chooser(mapper, instance.location) - def id_chooser(query, ident): + def identity_chooser( + mapper, + primary_key, + *, + lazy_loaded_from, + execution_options, + bind_arguments, + **kw, + ): return ["north_america", "asia", "europe", "south_america"] def execute_chooser(orm_context): @@ -144,7 +153,7 @@ class ShardTest: "south_america": db4, }, shard_chooser=shard_chooser, - id_chooser=id_chooser, + identity_chooser=identity_chooser, execute_chooser=execute_chooser, ) @@ -189,7 +198,7 @@ class ShardTest: tokyo.reports.append(Report(80.0, id_=1)) newyork.reports.append(Report(75, id_=1)) quito.reports.append(Report(85)) - sess = sharded_session(future=True) + sess = sharded_session() for c in [tokyo, newyork, toronto, london, dublin, brasilia, quito]: sess.add(c) sess.flush() @@ -589,6 +598,68 @@ class DistinctEngineShardTest(ShardTest, fixtures.MappedTest): ) +class LegacyAPIShardTest(DistinctEngineShardTest): + @classmethod + def setup_session(cls): + global sharded_session + shard_lookup = { + "North America": "north_america", + "Asia": "asia", + "Europe": "europe", + "South America": "south_america", + } + + def shard_chooser(mapper, instance, clause=None): + if isinstance(instance, WeatherLocation): + return shard_lookup[instance.continent] + else: + return shard_chooser(mapper, instance.location) + + def id_chooser(query, primary_key): + return ["north_america", "asia", "europe", "south_america"] + + def query_chooser(query): + ids = [] + + class FindContinent(sql.ClauseVisitor): + def visit_binary(self, binary): + if binary.left.shares_lineage( + weather_locations.c.continent + ): + if binary.operator == operators.eq: + ids.append(shard_lookup[binary.right.value]) + elif binary.operator == operators.in_op: + for value in binary.right.value: + ids.append(shard_lookup[value]) + + if isinstance(query, Select) and query.whereclause is not None: + FindContinent().traverse(query.whereclause) + if len(ids) == 0: + return ["north_america", "asia", "europe", "south_america"] + else: + return ids + + sm = sessionmaker(class_=ShardedSession, autoflush=True) + sm.configure( + shards={ + "north_america": db1, + "asia": db2, + "europe": db3, + "south_america": db4, + }, + shard_chooser=shard_chooser, + id_chooser=id_chooser, + query_chooser=query_chooser, + ) + + def sharded_session(): + with expect_deprecated( + "The ``id_chooser`` parameter is deprecated", + "The ``query_chooser`` parameter is deprecated", + ): + return sm() + + class AttachedFileShardTest(ShardTest, fixtures.MappedTest): """Use modern schema conventions along with SQLite ATTACH.""" @@ -723,7 +794,7 @@ class SelectinloadRegressionTest(fixtures.DeclarativeMappedTest): session = ShardedSession( shards={"test": testing.db}, shard_chooser=lambda *args: "test", - id_chooser=lambda *args: None, + identity_chooser=lambda *args: None, execute_chooser=lambda *args: ["test"], ) @@ -764,7 +835,7 @@ class RefreshDeferExpireTest(fixtures.DeclarativeMappedTest): return ShardedSession( shards={"main": testing.db}, shard_chooser=lambda *args: "main", - id_chooser=lambda *args: ["fake", "main"], + identity_chooser=lambda *args: ["fake", "main"], execute_chooser=lambda *args: ["fake", "main"], **kw, ) @@ -843,15 +914,23 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): else: assert False - def id_chooser(query, ident): - assert query.lazy_loaded_from - if isinstance(query.lazy_loaded_from.obj(), Book): - token = shard_for_book(query.lazy_loaded_from.obj()) - assert query.lazy_loaded_from.identity_token == token + def identity_chooser( + mapper, + primary_key, + *, + lazy_loaded_from, + execution_options, + bind_arguments, + **kw, + ): + assert lazy_loaded_from + if isinstance(lazy_loaded_from.obj(), Book): + token = shard_for_book(lazy_loaded_from.obj()) + assert lazy_loaded_from.identity_token == token - return [query.lazy_loaded_from.identity_token] + return [lazy_loaded_from.identity_token] - def no_query_chooser(orm_context): + def execute_chooser(orm_context): if ( orm_context.statement.column_descriptions[0]["type"] is Book and lazy_load_book @@ -878,8 +957,8 @@ class LazyLoadIdentityKeyTest(fixtures.DeclarativeMappedTest): session = ShardedSession( shards={"test": db1, "test2": db2}, shard_chooser=shard_chooser, - id_chooser=id_chooser, - execute_chooser=no_query_chooser, + identity_chooser=identity_chooser, + execute_chooser=execute_chooser, ) return session diff --git a/test/orm/test_events.py b/test/orm/test_events.py index 56d2815fa5..05d5d376de 100644 --- a/test/orm/test_events.py +++ b/test/orm/test_events.py @@ -1,6 +1,7 @@ from unittest.mock import ANY from unittest.mock import call from unittest.mock import Mock +from unittest.mock import patch import sqlalchemy as sa from sqlalchemy import bindparam @@ -375,7 +376,6 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): result.context.execution_options, { "four": True, - "future_result": True, "one": True, "three": True, "two": True, @@ -741,7 +741,6 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): { "statement_two": True, "statement_four": True, - "future_result": True, "one": True, "two": True, "three": True, @@ -751,6 +750,42 @@ class ORMExecuteTest(_RemoveListeners, _fixtures.FixtureTest): }, ) + @testing.variation("session_start", [True, False]) + @testing.variation("dest_autoflush", [True, False]) + @testing.variation("stmt_type", ["select", "bulk", "dml"]) + def test_autoflush_change(self, session_start, dest_autoflush, stmt_type): + User = self.classes.User + + sess = fixture_session(autoflush=session_start) + + @event.listens_for(sess, "do_orm_execute") + def do_orm_execute(ctx): + ctx.update_execution_options(autoflush=dest_autoflush) + + with patch.object(sess, "_autoflush") as m1: + if stmt_type.select: + sess.execute(select(User)) + elif stmt_type.bulk: + sess.execute( + insert(User), + [ + {"id": 1, "name": "n1"}, + {"id": 2, "name": "n2"}, + {"id": 3, "name": "n3"}, + ], + ) + elif stmt_type.dml: + sess.execute( + update(User).where(User.id == 2).values(name="nn") + ) + else: + stmt_type.fail() + + if dest_autoflush: + eq_(m1.mock_calls, [call()]) + else: + eq_(m1.mock_calls, []) + class MapperEventsTest(_RemoveListeners, _fixtures.FixtureTest): run_inserts = None diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 7966006cf5..9e303a778b 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -5507,6 +5507,7 @@ class YieldTest(_fixtures.FixtureTest): @event.listens_for(sess, "do_orm_execute") def check(ctx): eq_(ctx.load_options._yield_per, 15) + return eq_( { k: v @@ -5516,7 +5517,6 @@ class YieldTest(_fixtures.FixtureTest): { "yield_per": 15, "foo": "bar", - "future_result": True, }, ) @@ -5535,6 +5535,7 @@ class YieldTest(_fixtures.FixtureTest): @event.listens_for(sess, "do_orm_execute") def check(ctx): eq_(ctx.load_options._yield_per, 15) + eq_( { k: v @@ -5543,7 +5544,6 @@ class YieldTest(_fixtures.FixtureTest): }, { "yield_per": 15, - "future_result": True, }, ) @@ -5553,8 +5553,8 @@ class YieldTest(_fixtures.FixtureTest): assert isinstance( result.raw.cursor_strategy, _cursor.BufferedRowCursorFetchStrategy ) + eq_(result._yield_per, 15) eq_(result.raw.cursor_strategy._max_row_buffer, 15) - eq_(len(result.all()), 4) def test_no_joinedload_opt(self): @@ -7515,23 +7515,80 @@ class ExecutionOptionsTest(QueryTest): assert u.addresses[0].email_address == "jack@bean.com" assert u.orders[1].items[2].description == "item 5" - def test_option_transfer_future(self): + @testing.variation("source", ["statement", "do_orm_exec"]) + def test_execution_options_to_load_options(self, source): User = self.classes.User - stmt = select(User).execution_options( - populate_existing=True, autoflush=False, yield_per=10 - ) + + stmt = select(User) + + if source.statement: + stmt = stmt.execution_options( + populate_existing=True, + autoflush=False, + yield_per=10, + identity_token="some_token", + ) s = fixture_session() m1 = mock.Mock() - event.listen(s, "do_orm_execute", m1) + def do_orm_execute(ctx): + m1(ctx) + if source.do_orm_exec: + ctx.update_execution_options( + autoflush=False, + populate_existing=True, + yield_per=10, + identity_token="some_token", + ) + + event.listen(s, "do_orm_execute", do_orm_execute) + + from sqlalchemy.orm import loading + + with mock.patch.object(loading, "instances") as m2: + s.execute(stmt) + + if source.do_orm_exec: + # in do_orm_exec version, load options are empty, our new + # execution options have not yet been transferred. + eq_( + m1.mock_calls[0][1][0].load_options, + QueryContext.default_load_options, + ) + elif source.statement: + # in statement version, the incoming exc options have been + # transferred, because the fact that do_orm_exec is used + # means the options were set up up front for the benefit + # of the do_orm_exec hook itself. + eq_( + m1.mock_calls[0][1][0].load_options, + QueryContext.default_load_options( + _autoflush=False, + _populate_existing=True, + _yield_per=10, + _identity_token="some_token", + ), + ) + + # py37 mock does not have .args + call_args = m2.mock_calls[0][1] - s.execute(stmt) + cursor = call_args[0] + cursor.all() + # the orm_pre_session_exec() method + # was called unconditionally after the event handler + # in both cases (i.e. a second time) so options were transferred + # even if we set them up in the do_orm_exec hook only. + query_context = call_args[1] eq_( - m1.mock_calls[0][1][0].load_options, + query_context.load_options, QueryContext.default_load_options( - _autoflush=False, _populate_existing=True, _yield_per=10 + _autoflush=False, + _populate_existing=True, + _yield_per=10, + _identity_token="some_token", ), ) diff --git a/test/orm/test_session.py b/test/orm/test_session.py index 79ea5d1703..921c55f74a 100644 --- a/test/orm/test_session.py +++ b/test/orm/test_session.py @@ -1,5 +1,8 @@ +from __future__ import annotations + import inspect as _py_inspect import pickle +from typing import TYPE_CHECKING import sqlalchemy as sa from sqlalchemy import delete @@ -48,6 +51,9 @@ from sqlalchemy.testing.util import gc_collect from sqlalchemy.util.compat import inspect_getfullargspec from test.orm import _fixtures +if TYPE_CHECKING: + from sqlalchemy.orm import ORMExecuteState + class ExecutionTest(_fixtures.FixtureTest): run_inserts = None @@ -563,7 +569,10 @@ class SessionUtilTest(_fixtures.FixtureTest): u1, ) - def test_get_execution_option(self): + @testing.variation( + "arg", ["execution_options", "identity_token", "bind_arguments"] + ) + def test_get_arguments(self, arg: testing.Variation) -> None: users, User = self.tables.users, self.classes.User self.mapper_registry.map_imperatively(User, users) @@ -571,12 +580,28 @@ class SessionUtilTest(_fixtures.FixtureTest): called = False @event.listens_for(sess, "do_orm_execute") - def check(ctx): + def check(ctx: ORMExecuteState) -> None: nonlocal called called = True - eq_(ctx.execution_options["foo"], "bar") - sess.get(User, 42, execution_options={"foo": "bar"}) + if arg.execution_options: + eq_(ctx.execution_options["foo"], "bar") + elif arg.bind_arguments: + eq_(ctx.bind_arguments["foo"], "bar") + elif arg.identity_token: + eq_(ctx.load_options._identity_token, "foobar") + else: + arg.fail() + + if arg.execution_options: + sess.get(User, 42, execution_options={"foo": "bar"}) + elif arg.bind_arguments: + sess.get(User, 42, bind_arguments={"foo": "bar"}) + elif arg.identity_token: + sess.get(User, 42, identity_token="foobar") + else: + arg.fail() + sess.close() is_true(called) -- 2.47.2