From ab3e32c8110229b35a6a0cd33ef9e21c9435c246 Mon Sep 17 00:00:00 2001 From: Daniel Stone Date: Sun, 29 Aug 2021 16:31:29 +0100 Subject: [PATCH] Add loader options to session.merge (fixes #6955) --- lib/sqlalchemy/ext/asyncio/session.py | 4 ++-- lib/sqlalchemy/orm/session.py | 13 +++++++++-- test/ext/asyncio/test_session_py3k.py | 33 +++++++++++++++++++++++++++ test/orm/test_merge.py | 31 +++++++++++++++++++++++++ 4 files changed, 77 insertions(+), 4 deletions(-) diff --git a/lib/sqlalchemy/ext/asyncio/session.py b/lib/sqlalchemy/ext/asyncio/session.py index 5c6e7f5a7c..e8c7611002 100644 --- a/lib/sqlalchemy/ext/asyncio/session.py +++ b/lib/sqlalchemy/ext/asyncio/session.py @@ -242,13 +242,13 @@ class AsyncSession(ReversibleProxy): """ return await greenlet_spawn(self.sync_session.delete, instance) - async def merge(self, instance, load=True): + async def merge(self, instance, load=True, options=None): """Copy the state of a given instance into a corresponding instance within this :class:`_asyncio.AsyncSession`. """ return await greenlet_spawn( - self.sync_session.merge, instance, load=load + self.sync_session.merge, instance, load=load, options=options ) async def flush(self, objects=None): diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py index 0bdd5cc959..a32152c7d3 100644 --- a/lib/sqlalchemy/orm/session.py +++ b/lib/sqlalchemy/orm/session.py @@ -2843,7 +2843,7 @@ class Session(_SessionClassMethods): load_options=load_options, ) - def merge(self, instance, load=True): + def merge(self, instance, load=True, options=None): """Copy the state of a given instance into a corresponding instance within this :class:`.Session`. @@ -2889,6 +2889,8 @@ class Session(_SessionClassMethods): produced as "clean", so it is only appropriate that the given objects should be "clean" as well, else this suggests a mis-use of the method. + :param options: optional sequence of loader options which will be + applied to the query, if the instance is not found locally. .. seealso:: @@ -2916,6 +2918,7 @@ class Session(_SessionClassMethods): attributes.instance_state(instance), attributes.instance_dict(instance), load=load, + options=options, _recursive=_recursive, _resolve_conflict_map=_resolve_conflict_map, ) @@ -2927,6 +2930,7 @@ class Session(_SessionClassMethods): state, state_dict, load=True, + options=None, _recursive=None, _resolve_conflict_map=None, ): @@ -2990,7 +2994,12 @@ class Session(_SessionClassMethods): new_instance = True elif key_is_persistent: - merged = self.get(mapper.class_, key[1], identity_token=key[2]) + merged = self.get( + mapper.class_, + key[1], + identity_token=key[2], + options=options, + ) if merged is None: merged = mapper.class_manager.new_instance() diff --git a/test/ext/asyncio/test_session_py3k.py b/test/ext/asyncio/test_session_py3k.py index ebedfedbfb..fc223c3417 100644 --- a/test/ext/asyncio/test_session_py3k.py +++ b/test/ext/asyncio/test_session_py3k.py @@ -104,6 +104,17 @@ class AsyncSessionQueryTest(AsyncFixture): u3 = await async_session.get(User, 12) is_(u3, None) + @async_test + async def test_get_loader_options(self, async_session): + User = self.classes.User + + u = await async_session.get( + User, 7, options=[selectinload(User.addresses)] + ) + + eq_(u.name, "jack") + eq_(len(u.addresses), 1) + @async_test @testing.requires.independent_cursors @testing.combinations( @@ -333,6 +344,28 @@ class AsyncSessionTransactionTest(AsyncFixture): is_(new_u_merged, u1) eq_(u1.name, "new u1") + @async_test + async def test_merge_loader_options(self, async_session): + User = self.classes.User + Address = self.classes.Address + + async with async_session.begin(): + u1 = User(id=1, name="u1", addresses=[Address(email_address="e1")]) + + async_session.add(u1) + + await async_session.close() + + async with async_session.begin(): + new_u1 = User(id=1, name="new u1") + + new_u_merged = await async_session.merge( + new_u1, options=[selectinload(User.addresses)] + ) + + eq_(new_u_merged.name, "new u1") + eq_(len(new_u_merged.addresses), 1) + @async_test async def test_join_to_external_transaction(self, async_engine): User = self.classes.User diff --git a/test/orm/test_merge.py b/test/orm/test_merge.py index a0cb5426f3..cf4e81c98e 100644 --- a/test/orm/test_merge.py +++ b/test/orm/test_merge.py @@ -18,6 +18,7 @@ from sqlalchemy.orm import deferred from sqlalchemy.orm import foreign from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import synonym from sqlalchemy.orm.collections import attribute_mapped_collection @@ -52,6 +53,36 @@ class MergeTest(_fixtures.FixtureTest): return canary + def test_loader_options(self): + User, Address, addresses, users = ( + self.classes.User, + self.classes.Address, + self.tables.addresses, + self.tables.users, + ) + + mapper( + User, + users, + properties={"addresses": relationship(Address, backref="user")}, + ) + mapper(Address, addresses) + + s = fixture_session() + u = User( + id=7, + name="fred", + addresses=[Address(id=1, email_address="jack@bean.com")], + ) + s.add(u) + s.commit() + s.close() + + u = User(id=7, name="fred") + u2 = s.merge(u, options=[selectinload(User.addresses)]) + + eq_(len(u2.__dict__["addresses"]), 1) + def test_transient_to_pending(self): User, users = self.classes.User, self.tables.users -- 2.47.2