]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
convert AsyncSession.delete into awaitable
authorMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Mar 2021 16:16:49 +0000 (11:16 -0500)
committerMike Bayer <mike_mp@zzzcomputing.com>
Tue, 2 Mar 2021 16:16:49 +0000 (11:16 -0500)
The API for :meth:`_asyncio.AsyncSession.delete` is now an awaitable;
this method cascades along relationships which must be loaded in a
similar manner as the :meth:`_asyncio.AsyncSession.merge` method.

Fixes: #5998
Change-Id: Iae001efe99a1dcc47598b4a2491d17c4157fbbfa

doc/build/changelog/unreleased_14/5998.rst [new file with mode: 0644]
lib/sqlalchemy/ext/asyncio/session.py
lib/sqlalchemy/testing/fixtures.py
test/ext/asyncio/test_session_py3k.py

diff --git a/doc/build/changelog/unreleased_14/5998.rst b/doc/build/changelog/unreleased_14/5998.rst
new file mode 100644 (file)
index 0000000..8ff3659
--- /dev/null
@@ -0,0 +1,8 @@
+.. change::
+    :tags: bug, orm, asyncio
+    :tickets: 5998
+
+    The API for :meth:`_asyncio.AsyncSession.delete` is now an awaitable;
+    this method cascades along relationships which must be loaded in a
+    similar manner as the :meth:`_asyncio.AsyncSession.merge` method.
+
index faa279cf9c0929d087f9ebf184243589e4789426..93af178a35196cd260ae9c5eeb037b0c0088ed74 100644 (file)
@@ -28,7 +28,6 @@ T = TypeVar("T")
         "__iter__",
         "add",
         "add_all",
-        "delete",
         "expire",
         "expire_all",
         "expunge",
@@ -223,6 +222,18 @@ class AsyncSession:
         )
         return _result.AsyncResult(result)
 
+    async def delete(self, instance):
+        """Mark an instance as deleted.
+
+        The database delete operation occurs upon ``flush()``.
+
+        As this operation may need to cascade along unloaded relationships,
+        it is awaitable to allow for those queries to take place.
+
+
+        """
+        return await greenlet_spawn(self.sync_session.delete, instance)
+
     async def merge(self, instance, load=True):
         """Copy the state of a given instance into a corresponding instance
         within this :class:`_asyncio.AsyncSession`.
index 4b76e6d88f3b41074d28d98220a897f6f45391c5..95dce02a9d4b2687725e091c15695b8c54933355 100644 (file)
@@ -466,7 +466,7 @@ class MappedTest(TablesTest, assertions.AssertsExecutionResults):
 
     def _setup_each_mappers(self):
         if self.run_setup_mappers == "each":
-            self.mapper = self._generate_mapper()
+            self.__class__.mapper = self._generate_mapper()
             self._with_register_classes(self.setup_mappers)
 
     def _setup_each_classes(self):
index e56adec4d3a55b39dae51a23ebf7024a74977999..d308764fbcc11062180e589d21e6f4c0e7ab0f99 100644 (file)
@@ -5,6 +5,7 @@ from sqlalchemy import select
 from sqlalchemy import testing
 from sqlalchemy import update
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import relationship
 from sqlalchemy.orm import selectinload
 from sqlalchemy.orm import sessionmaker
 from sqlalchemy.testing import async_test
@@ -236,7 +237,7 @@ class AsyncSessionTransactionTest(AsyncFixture):
 
             eq_(await conn.scalar(select(func.count(User.id))), 1)
 
-            async_session.delete(u1)
+            await async_session.delete(u1)
 
             await async_session.flush()
 
@@ -404,6 +405,53 @@ class AsyncSessionTransactionTest(AsyncFixture):
             eq_(result.all(), [])
 
 
+class AsyncCascadesTest(AsyncFixture):
+    run_inserts = None
+
+    @classmethod
+    def setup_mappers(cls):
+        User, Address = cls.classes("User", "Address")
+        users, addresses = cls.tables("users", "addresses")
+
+        cls.mapper(
+            User,
+            users,
+            properties={
+                "addresses": relationship(
+                    Address, cascade="all, delete-orphan"
+                )
+            },
+        )
+        cls.mapper(
+            Address,
+            addresses,
+        )
+
+    @async_test
+    async def test_delete_w_cascade(self, async_session):
+        User = self.classes.User
+        Address = self.classes.Address
+
+        async with async_session.begin():
+            u1 = User(id=1, name="u1", addresses=[Address(email_address="e1")])
+
+            async_session.add(u1)
+
+        async with async_session.begin():
+            u1 = (await async_session.execute(select(User))).scalar_one()
+
+            await async_session.delete(u1)
+
+        eq_(
+            (
+                await async_session.execute(
+                    select(func.count()).select_from(Address)
+                )
+            ).scalar(),
+            0,
+        )
+
+
 class AsyncEventTest(AsyncFixture):
     """The engine events all run in their normal synchronous context.