]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Add loader options to session.merge (fixes #6955)
authorDaniel Stone <me@danstone.uk>
Sun, 29 Aug 2021 15:31:29 +0000 (16:31 +0100)
committerDaniel Stone <me@danstone.uk>
Sun, 29 Aug 2021 15:31:29 +0000 (16:31 +0100)
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/orm/session.py
test/ext/asyncio/test_session_py3k.py
test/orm/test_merge.py

index 5c6e7f5a7c23cedfacc046862c62d2d554591811..e8c76110023b8e8218126fd5df7899a176c95631 100644 (file)
@@ -242,13 +242,13 @@ class AsyncSession(ReversibleProxy):
         """
         return await greenlet_spawn(self.sync_session.delete, instance)
 
-    async def merge(self, instance, load=True):
+    async def merge(self, instance, load=True, options=None):
         """Copy the state of a given instance into a corresponding instance
         within this :class:`_asyncio.AsyncSession`.
 
         """
         return await greenlet_spawn(
-            self.sync_session.merge, instance, load=load
+            self.sync_session.merge, instance, load=load, options=options
         )
 
     async def flush(self, objects=None):
index 0bdd5cc959d4ee568cc469001b2f9bba366351b9..a32152c7d33570c707a698238e5d02026582b6aa 100644 (file)
@@ -2843,7 +2843,7 @@ class Session(_SessionClassMethods):
             load_options=load_options,
         )
 
-    def merge(self, instance, load=True):
+    def merge(self, instance, load=True, options=None):
         """Copy the state of a given instance into a corresponding instance
         within this :class:`.Session`.
 
@@ -2889,6 +2889,8 @@ class Session(_SessionClassMethods):
          produced as "clean", so it is only appropriate that the given objects
          should be "clean" as well, else this suggests a mis-use of the
          method.
+        :param options: optional sequence of loader options which will be
+         applied to the query, if the instance is not found locally.
 
 
         .. seealso::
@@ -2916,6 +2918,7 @@ class Session(_SessionClassMethods):
                 attributes.instance_state(instance),
                 attributes.instance_dict(instance),
                 load=load,
+                options=options,
                 _recursive=_recursive,
                 _resolve_conflict_map=_resolve_conflict_map,
             )
@@ -2927,6 +2930,7 @@ class Session(_SessionClassMethods):
         state,
         state_dict,
         load=True,
+        options=None,
         _recursive=None,
         _resolve_conflict_map=None,
     ):
@@ -2990,7 +2994,12 @@ class Session(_SessionClassMethods):
                 new_instance = True
 
             elif key_is_persistent:
-                merged = self.get(mapper.class_, key[1], identity_token=key[2])
+                merged = self.get(
+                    mapper.class_,
+                    key[1],
+                    identity_token=key[2],
+                    options=options,
+                )
 
         if merged is None:
             merged = mapper.class_manager.new_instance()
index ebedfedbfba0926a0ab3607d54bc0669c0e50f4d..fc223c3417d65d532c4603ea51cc213fbc34223f 100644 (file)
@@ -104,6 +104,17 @@ class AsyncSessionQueryTest(AsyncFixture):
         u3 = await async_session.get(User, 12)
         is_(u3, None)
 
+    @async_test
+    async def test_get_loader_options(self, async_session):
+        User = self.classes.User
+
+        u = await async_session.get(
+            User, 7, options=[selectinload(User.addresses)]
+        )
+
+        eq_(u.name, "jack")
+        eq_(len(u.addresses), 1)
+
     @async_test
     @testing.requires.independent_cursors
     @testing.combinations(
@@ -333,6 +344,28 @@ class AsyncSessionTransactionTest(AsyncFixture):
             is_(new_u_merged, u1)
             eq_(u1.name, "new u1")
 
+    @async_test
+    async def test_merge_loader_options(self, async_session):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        async with async_session.begin():
+            u1 = User(id=1, name="u1", addresses=[Address(email_address="e1")])
+
+            async_session.add(u1)
+
+        await async_session.close()
+
+        async with async_session.begin():
+            new_u1 = User(id=1, name="new u1")
+
+            new_u_merged = await async_session.merge(
+                new_u1, options=[selectinload(User.addresses)]
+            )
+
+            eq_(new_u_merged.name, "new u1")
+            eq_(len(new_u_merged.addresses), 1)
+
     @async_test
     async def test_join_to_external_transaction(self, async_engine):
         User = self.classes.User
index a0cb5426f3adacd52735bf59fe5bc6ae8ee43c03..cf4e81c98e0e272fc5ea47f6b8d5598dfb927b64 100644 (file)
@@ -18,6 +18,7 @@ from sqlalchemy.orm import deferred
 from sqlalchemy.orm import foreign
 from sqlalchemy.orm import mapper
 from sqlalchemy.orm import relationship
+from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import Session
 from sqlalchemy.orm import synonym
 from sqlalchemy.orm.collections import attribute_mapped_collection
@@ -52,6 +53,36 @@ class MergeTest(_fixtures.FixtureTest):
 
         return canary
 
+    def test_loader_options(self):
+        User, Address, addresses, users = (
+            self.classes.User,
+            self.classes.Address,
+            self.tables.addresses,
+            self.tables.users,
+        )
+
+        mapper(
+            User,
+            users,
+            properties={"addresses": relationship(Address, backref="user")},
+        )
+        mapper(Address, addresses)
+
+        s = fixture_session()
+        u = User(
+            id=7,
+            name="fred",
+            addresses=[Address(id=1, email_address="jack@bean.com")],
+        )
+        s.add(u)
+        s.commit()
+        s.close()
+
+        u = User(id=7, name="fred")
+        u2 = s.merge(u, options=[selectinload(User.addresses)])
+
+        eq_(len(u2.__dict__["addresses"]), 1)
+
     def test_transient_to_pending(self):
         User, users = self.classes.User, self.tables.users