]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-111693: Propagate correct asyncio.CancelledError instance out of asyncio.Condition...
authorKristján Valur Jónsson <sweskman@gmail.com>
Mon, 8 Jan 2024 19:57:48 +0000 (19:57 +0000)
committerGitHub <noreply@github.com>
Mon, 8 Jan 2024 19:57:48 +0000 (11:57 -0800)
Also fix a race condition in `asyncio.Semaphore.acquire()` when cancelled.

Lib/asyncio/futures.py
Lib/asyncio/locks.py
Lib/test/test_asyncio/test_locks.py
Misc/NEWS.d/next/Library/2024-01-07-13-36-03.gh-issue-111693.xN2LuL.rst [new file with mode: 0644]

index 97fc4e3fcb60ee24eabcf9a32cb0140a2b81b2c4..d19e5d8c9194fdab300e4d9ae56de218a37f8b13 100644 (file)
@@ -138,9 +138,6 @@ class Future:
             exc = exceptions.CancelledError()
         else:
             exc = exceptions.CancelledError(self._cancel_message)
-        exc.__context__ = self._cancelled_exc
-        # Remove the reference since we don't need this anymore.
-        self._cancelled_exc = None
         return exc
 
     def cancel(self, msg=None):
index ce5d8d5bfb2e81eeaff5229411962f1d8629e011..04158e667a895fceaf22acc4bd3ee66806563451 100644 (file)
@@ -95,6 +95,8 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
         This method blocks until the lock is unlocked, then sets it to
         locked and returns True.
         """
+        # Implement fair scheduling, where thread always waits
+        # its turn. Jumping the queue if all are cancelled is an optimization.
         if (not self._locked and (self._waiters is None or
                 all(w.cancelled() for w in self._waiters))):
             self._locked = True
@@ -105,19 +107,22 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
         fut = self._get_loop().create_future()
         self._waiters.append(fut)
 
-        # Finally block should be called before the CancelledError
-        # handling as we don't want CancelledError to call
-        # _wake_up_first() and attempt to wake up itself.
         try:
             try:
                 await fut
             finally:
                 self._waiters.remove(fut)
         except exceptions.CancelledError:
+            # Currently the only exception designed be able to occur here.
+
+            # Ensure the lock invariant: If lock is not claimed (or about
+            # to be claimed by us) and there is a Task in waiters,
+            # ensure that the Task at the head will run.
             if not self._locked:
                 self._wake_up_first()
             raise
 
+        # assert self._locked is False
         self._locked = True
         return True
 
@@ -139,7 +144,7 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
             raise RuntimeError('Lock is not acquired.')
 
     def _wake_up_first(self):
-        """Wake up the first waiter if it isn't done."""
+        """Ensure that the first waiter will wake up."""
         if not self._waiters:
             return
         try:
@@ -147,9 +152,7 @@ class Lock(_ContextManagerMixin, mixins._LoopBoundMixin):
         except StopIteration:
             return
 
-        # .done() necessarily means that a waiter will wake up later on and
-        # either take the lock, or, if it was cancelled and lock wasn't
-        # taken already, will hit this again and wake up a new waiter.
+        # .done() means that the waiter is already set to wake up.
         if not fut.done():
             fut.set_result(True)
 
@@ -269,17 +272,22 @@ class Condition(_ContextManagerMixin, mixins._LoopBoundMixin):
                 self._waiters.remove(fut)
 
         finally:
-            # Must reacquire lock even if wait is cancelled
-            cancelled = False
+            # Must re-acquire lock even if wait is cancelled.
+            # We only catch CancelledError here, since we don't want any
+            # other (fatal) errors with the future to cause us to spin.
+            err = None
             while True:
                 try:
                     await self.acquire()
                     break
-                except exceptions.CancelledError:
-                    cancelled = True
+                except exceptions.CancelledError as e:
+                    err = e
 
-            if cancelled:
-                raise exceptions.CancelledError
+            if err:
+                try:
+                    raise err  # Re-raise most recent exception instance.
+                finally:
+                    err = None  # Break reference cycles.
 
     async def wait_for(self, predicate):
         """Wait until a predicate becomes true.
@@ -357,6 +365,7 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
 
     def locked(self):
         """Returns True if semaphore cannot be acquired immediately."""
+        # Due to state, or FIFO rules (must allow others to run first).
         return self._value == 0 or (
             any(not w.cancelled() for w in (self._waiters or ())))
 
@@ -370,6 +379,7 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
         True.
         """
         if not self.locked():
+            # Maintain FIFO, wait for others to start even if _value > 0.
             self._value -= 1
             return True
 
@@ -378,22 +388,27 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
         fut = self._get_loop().create_future()
         self._waiters.append(fut)
 
-        # Finally block should be called before the CancelledError
-        # handling as we don't want CancelledError to call
-        # _wake_up_first() and attempt to wake up itself.
         try:
             try:
                 await fut
             finally:
                 self._waiters.remove(fut)
         except exceptions.CancelledError:
-            if not fut.cancelled():
+            # Currently the only exception designed be able to occur here.
+            if fut.done() and not fut.cancelled():
+                # Our Future was successfully set to True via _wake_up_next(),
+                # but we are not about to successfully acquire(). Therefore we
+                # must undo the bookkeeping already done and attempt to wake
+                # up someone else.
                 self._value += 1
-                self._wake_up_next()
             raise
 
-        if self._value > 0:
-            self._wake_up_next()
+        finally:
+            # New waiters may have arrived but had to wait due to FIFO.
+            # Wake up as many as are allowed.
+            while self._value > 0:
+                if not self._wake_up_next():
+                    break  # There was no-one to wake up.
         return True
 
     def release(self):
@@ -408,13 +423,15 @@ class Semaphore(_ContextManagerMixin, mixins._LoopBoundMixin):
     def _wake_up_next(self):
         """Wake up the first waiter that isn't done."""
         if not self._waiters:
-            return
+            return False
 
         for fut in self._waiters:
             if not fut.done():
                 self._value -= 1
                 fut.set_result(True)
-                return
+                # `fut` is now `done()` and not `cancelled()`.
+                return True
+        return False
 
 
 class BoundedSemaphore(Semaphore):
index f6c6a282429a21f993e4bdb8a6ef43ced8d1b0a2..9029efd2355b46fc3fd8a26ea9ed6f89ec844777 100644 (file)
@@ -758,6 +758,63 @@ class ConditionTests(unittest.IsolatedAsyncioTestCase):
             with self.assertRaises(asyncio.TimeoutError):
                 await asyncio.wait_for(condition.wait(), timeout=0.5)
 
+    async def test_cancelled_error_wakeup(self):
+        # Test that a cancelled error, received when awaiting wakeup,
+        # will be re-raised un-modified.
+        wake = False
+        raised = None
+        cond = asyncio.Condition()
+
+        async def func():
+            nonlocal raised
+            async with cond:
+                with self.assertRaises(asyncio.CancelledError) as err:
+                    await cond.wait_for(lambda: wake)
+                raised = err.exception
+                raise raised
+
+        task = asyncio.create_task(func())
+        await asyncio.sleep(0)
+        # Task is waiting on the condition, cancel it there.
+        task.cancel(msg="foo")
+        with self.assertRaises(asyncio.CancelledError) as err:
+            await task
+        self.assertEqual(err.exception.args, ("foo",))
+        # We should have got the _same_ exception instance as the one
+        # originally raised.
+        self.assertIs(err.exception, raised)
+
+    async def test_cancelled_error_re_aquire(self):
+        # Test that a cancelled error, received when re-aquiring lock,
+        # will be re-raised un-modified.
+        wake = False
+        raised = None
+        cond = asyncio.Condition()
+
+        async def func():
+            nonlocal raised
+            async with cond:
+                with self.assertRaises(asyncio.CancelledError) as err:
+                    await cond.wait_for(lambda: wake)
+                raised = err.exception
+                raise raised
+
+        task = asyncio.create_task(func())
+        await asyncio.sleep(0)
+        # Task is waiting on the condition
+        await cond.acquire()
+        wake = True
+        cond.notify()
+        await asyncio.sleep(0)
+        # Task is now trying to re-acquire the lock, cancel it there.
+        task.cancel(msg="foo")
+        cond.release()
+        with self.assertRaises(asyncio.CancelledError) as err:
+            await task
+        self.assertEqual(err.exception.args, ("foo",))
+        # We should have got the _same_ exception instance as the one
+        # originally raised.
+        self.assertIs(err.exception, raised)
 
 class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
 
@@ -1044,6 +1101,62 @@ class SemaphoreTests(unittest.IsolatedAsyncioTestCase):
         await asyncio.gather(*tasks, return_exceptions=True)
         self.assertEqual([2, 3], result)
 
+    async def test_acquire_fifo_order_4(self):
+        # Test that a successfule `acquire()` will wake up multiple Tasks
+        # that were waiting in the Semaphore queue due to FIFO rules.
+        sem = asyncio.Semaphore(0)
+        result = []
+        count = 0
+
+        async def c1(result):
+            # First task immediatlly waits for semaphore.  It will be awoken by c2.
+            self.assertEqual(sem._value, 0)
+            await sem.acquire()
+            # We should have woken up all waiting tasks now.
+            self.assertEqual(sem._value, 0)
+            # Create a fourth task.  It should run after c3, not c2.
+            nonlocal t4
+            t4 = asyncio.create_task(c4(result))
+            result.append(1)
+            return True
+
+        async def c2(result):
+            # The second task begins by releasing semaphore three times,
+            # for c1, c2, and c3.
+            sem.release()
+            sem.release()
+            sem.release()
+            self.assertEqual(sem._value, 2)
+            # It is locked, because c1 hasn't woken up yet.
+            self.assertTrue(sem.locked())
+            await sem.acquire()
+            result.append(2)
+            return True
+
+        async def c3(result):
+            await sem.acquire()
+            self.assertTrue(sem.locked())
+            result.append(3)
+            return True
+
+        async def c4(result):
+            result.append(4)
+            return True
+
+        t1 = asyncio.create_task(c1(result))
+        t2 = asyncio.create_task(c2(result))
+        t3 = asyncio.create_task(c3(result))
+        t4 = None
+
+        await asyncio.sleep(0)
+        # Three tasks are in the queue, the first hasn't woken up yet.
+        self.assertEqual(sem._value, 2)
+        self.assertEqual(len(sem._waiters), 3)
+        await asyncio.sleep(0)
+
+        tasks = [t1, t2, t3, t4]
+        await asyncio.gather(*tasks)
+        self.assertEqual([1, 2, 3, 4], result)
 
 class BarrierTests(unittest.IsolatedAsyncioTestCase):
 
diff --git a/Misc/NEWS.d/next/Library/2024-01-07-13-36-03.gh-issue-111693.xN2LuL.rst b/Misc/NEWS.d/next/Library/2024-01-07-13-36-03.gh-issue-111693.xN2LuL.rst
new file mode 100644 (file)
index 0000000..2201f47
--- /dev/null
@@ -0,0 +1 @@
+:func:`asyncio.Condition.wait()` now re-raises the same :exc:`CancelledError` instance that may have caused it to be interrupted.  Fixed race condition in :func:`asyncio.Semaphore.aquire` when interrupted with a :exc:`CancelledError`.