]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-96471: Add asyncio queue shutdown (#104228)
authorLaurie O <laurie_opperman@hotmail.com>
Sat, 6 Apr 2024 14:27:13 +0000 (00:27 +1000)
committerGitHub <noreply@github.com>
Sat, 6 Apr 2024 14:27:13 +0000 (07:27 -0700)
Co-authored-by: Duprat <yduprat@gmail.com>
Doc/library/asyncio-queue.rst
Doc/whatsnew/3.13.rst
Lib/asyncio/queues.py
Lib/test/test_asyncio/test_queues.py
Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst [new file with mode: 0644]

index d86fbc21351e2dc273fb3f6e177a49d241b91a18..030d4310942d7af9d334519589a89feb70c7e32e 100644 (file)
@@ -62,6 +62,9 @@ Queue
       Remove and return an item from the queue. If queue is empty,
       wait until an item is available.
 
+      Raises :exc:`QueueShutDown` if the queue has been shut down and
+      is empty, or if the queue has been shut down immediately.
+
    .. method:: get_nowait()
 
       Return an item if one is immediately available, else raise
@@ -82,6 +85,8 @@ Queue
       Put an item into the queue. If the queue is full, wait until a
       free slot is available before adding the item.
 
+      Raises :exc:`QueueShutDown` if the queue has been shut down.
+
    .. method:: put_nowait(item)
 
       Put an item into the queue without blocking.
@@ -92,6 +97,21 @@ Queue
 
       Return the number of items in the queue.
 
+   .. method:: shutdown(immediate=False)
+
+      Shut down the queue, making :meth:`~Queue.get` and :meth:`~Queue.put`
+      raise :exc:`QueueShutDown`.
+
+      By default, :meth:`~Queue.get` on a shut down queue will only
+      raise once the queue is empty. Set *immediate* to true to make
+      :meth:`~Queue.get` raise immediately instead.
+
+      All blocked callers of :meth:`~Queue.put` will be unblocked. If
+      *immediate* is true, also unblock callers of :meth:`~Queue.get`
+      and :meth:`~Queue.join`.
+
+      .. versionadded:: 3.13
+
    .. method:: task_done()
 
       Indicate that a formerly enqueued task is complete.
@@ -105,6 +125,9 @@ Queue
       call was received for every item that had been :meth:`~Queue.put`
       into the queue).
 
+      ``shutdown(immediate=True)`` calls :meth:`task_done` for each
+      remaining item in the queue.
+
       Raises :exc:`ValueError` if called more times than there were
       items placed in the queue.
 
@@ -145,6 +168,14 @@ Exceptions
    on a queue that has reached its *maxsize*.
 
 
+.. exception:: QueueShutDown
+
+   Exception raised when :meth:`~Queue.put` or :meth:`~Queue.get` is
+   called on a queue which has been shut down.
+
+   .. versionadded:: 3.13
+
+
 Examples
 ========
 
index e31f0c52d4c5f54a6f5c0d8bda691d1c08293f1c..c785d4cfa8fdc33d6a0d15af5a8805e1ed8b93f8 100644 (file)
@@ -296,6 +296,10 @@ asyncio
   with the tasks being completed.
   (Contributed by Justin Arthur in :gh:`77714`.)
 
+* Add :meth:`asyncio.Queue.shutdown` (along with
+  :exc:`asyncio.QueueShutDown`) for queue termination.
+  (Contributed by Laurie Opperman in :gh:`104228`.)
+
 base64
 ------
 
index a9656a6df561ba8c3084ac19ce212afb46154ce1..b8156704b8fc23622f5fafb8bd79a16ff96921c5 100644 (file)
@@ -1,4 +1,11 @@
-__all__ = ('Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty')
+__all__ = (
+    'Queue',
+    'PriorityQueue',
+    'LifoQueue',
+    'QueueFull',
+    'QueueEmpty',
+    'QueueShutDown',
+)
 
 import collections
 import heapq
@@ -18,6 +25,11 @@ class QueueFull(Exception):
     pass
 
 
+class QueueShutDown(Exception):
+    """Raised when putting on to or getting from a shut-down Queue."""
+    pass
+
+
 class Queue(mixins._LoopBoundMixin):
     """A queue, useful for coordinating producer and consumer coroutines.
 
@@ -41,6 +53,7 @@ class Queue(mixins._LoopBoundMixin):
         self._finished = locks.Event()
         self._finished.set()
         self._init(maxsize)
+        self._is_shutdown = False
 
     # These three are overridable in subclasses.
 
@@ -81,6 +94,8 @@ class Queue(mixins._LoopBoundMixin):
             result += f' _putters[{len(self._putters)}]'
         if self._unfinished_tasks:
             result += f' tasks={self._unfinished_tasks}'
+        if self._is_shutdown:
+            result += ' shutdown'
         return result
 
     def qsize(self):
@@ -112,8 +127,12 @@ class Queue(mixins._LoopBoundMixin):
 
         Put an item into the queue. If the queue is full, wait until a free
         slot is available before adding item.
+
+        Raises QueueShutDown if the queue has been shut down.
         """
         while self.full():
+            if self._is_shutdown:
+                raise QueueShutDown
             putter = self._get_loop().create_future()
             self._putters.append(putter)
             try:
@@ -125,7 +144,7 @@ class Queue(mixins._LoopBoundMixin):
                     self._putters.remove(putter)
                 except ValueError:
                     # The putter could be removed from self._putters by a
-                    # previous get_nowait call.
+                    # previous get_nowait call or a shutdown call.
                     pass
                 if not self.full() and not putter.cancelled():
                     # We were woken up by get_nowait(), but can't take
@@ -138,7 +157,11 @@ class Queue(mixins._LoopBoundMixin):
         """Put an item into the queue without blocking.
 
         If no free slot is immediately available, raise QueueFull.
+
+        Raises QueueShutDown if the queue has been shut down.
         """
+        if self._is_shutdown:
+            raise QueueShutDown
         if self.full():
             raise QueueFull
         self._put(item)
@@ -150,8 +173,13 @@ class Queue(mixins._LoopBoundMixin):
         """Remove and return an item from the queue.
 
         If queue is empty, wait until an item is available.
+
+        Raises QueueShutDown if the queue has been shut down and is empty, or
+        if the queue has been shut down immediately.
         """
         while self.empty():
+            if self._is_shutdown and self.empty():
+                raise QueueShutDown
             getter = self._get_loop().create_future()
             self._getters.append(getter)
             try:
@@ -163,7 +191,7 @@ class Queue(mixins._LoopBoundMixin):
                     self._getters.remove(getter)
                 except ValueError:
                     # The getter could be removed from self._getters by a
-                    # previous put_nowait call.
+                    # previous put_nowait call, or a shutdown call.
                     pass
                 if not self.empty() and not getter.cancelled():
                     # We were woken up by put_nowait(), but can't take
@@ -176,8 +204,13 @@ class Queue(mixins._LoopBoundMixin):
         """Remove and return an item from the queue.
 
         Return an item if one is immediately available, else raise QueueEmpty.
+
+        Raises QueueShutDown if the queue has been shut down and is empty, or
+        if the queue has been shut down immediately.
         """
         if self.empty():
+            if self._is_shutdown:
+                raise QueueShutDown
             raise QueueEmpty
         item = self._get()
         self._wakeup_next(self._putters)
@@ -194,6 +227,9 @@ class Queue(mixins._LoopBoundMixin):
         been processed (meaning that a task_done() call was received for every
         item that had been put() into the queue).
 
+        shutdown(immediate=True) calls task_done() for each remaining item in
+        the queue.
+
         Raises ValueError if called more times than there were items placed in
         the queue.
         """
@@ -214,6 +250,32 @@ class Queue(mixins._LoopBoundMixin):
         if self._unfinished_tasks > 0:
             await self._finished.wait()
 
+    def shutdown(self, immediate=False):
+        """Shut-down the queue, making queue gets and puts raise QueueShutDown.
+
+        By default, gets will only raise once the queue is empty. Set
+        'immediate' to True to make gets raise immediately instead.
+
+        All blocked callers of put() will be unblocked, and also get()
+        and join() if 'immediate'.
+        """
+        self._is_shutdown = True
+        if immediate:
+            while not self.empty():
+                self._get()
+                if self._unfinished_tasks > 0:
+                    self._unfinished_tasks -= 1
+            if self._unfinished_tasks == 0:
+                self._finished.set()
+        while self._getters:
+            getter = self._getters.popleft()
+            if not getter.done():
+                getter.set_result(None)
+        while self._putters:
+            putter = self._putters.popleft()
+            if not putter.done():
+                putter.set_result(None)
+
 
 class PriorityQueue(Queue):
     """A subclass of Queue; retrieves entries in priority order (lowest first).
index 2d058ccf6a8c7293a3b0e9b85212ef2663db7ec6..5019e9a293525d15aeccc8cec492324c89bf5519 100644 (file)
@@ -522,5 +522,204 @@ class PriorityQueueJoinTests(_QueueJoinTestMixin, unittest.IsolatedAsyncioTestCa
     q_class = asyncio.PriorityQueue
 
 
+class _QueueShutdownTestMixin:
+    q_class = None
+
+    def assertRaisesShutdown(self, msg="Didn't appear to shut-down queue"):
+        return self.assertRaises(asyncio.QueueShutDown, msg=msg)
+
+    async def test_format(self):
+        q = self.q_class()
+        q.shutdown()
+        self.assertEqual(q._format(), 'maxsize=0 shutdown')
+
+    async def test_shutdown_empty(self):
+        # Test shutting down an empty queue
+
+        # Setup empty queue, and join() and get() tasks
+        q = self.q_class()
+        loop = asyncio.get_running_loop()
+        get_task = loop.create_task(q.get())
+        await asyncio.sleep(0)  # want get task pending before shutdown
+
+        # Perform shut-down
+        q.shutdown(immediate=False)  # unfinished tasks: 0 -> 0
+
+        self.assertEqual(q.qsize(), 0)
+
+        # Ensure join() task successfully finishes
+        await q.join()
+
+        # Ensure get() task is finished, and raised ShutDown
+        await asyncio.sleep(0)
+        self.assertTrue(get_task.done())
+        with self.assertRaisesShutdown():
+            await get_task
+
+        # Ensure put() and get() raise ShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+    async def test_shutdown_nonempty(self):
+        # Test shutting down a non-empty queue
+
+        # Setup full queue with 1 item, and join() and put() tasks
+        q = self.q_class(maxsize=1)
+        loop = asyncio.get_running_loop()
+
+        q.put_nowait("data")
+        join_task = loop.create_task(q.join())
+        put_task = loop.create_task(q.put("data2"))
+
+        # Ensure put() task is not finished
+        await asyncio.sleep(0)
+        self.assertFalse(put_task.done())
+
+        # Perform shut-down
+        q.shutdown(immediate=False)  # unfinished tasks: 1 -> 1
+
+        self.assertEqual(q.qsize(), 1)
+
+        # Ensure put() task is finished, and raised ShutDown
+        await asyncio.sleep(0)
+        self.assertTrue(put_task.done())
+        with self.assertRaisesShutdown():
+            await put_task
+
+        # Ensure get() succeeds on enqueued item
+        self.assertEqual(await q.get(), "data")
+
+        # Ensure join() task is not finished
+        await asyncio.sleep(0)
+        self.assertFalse(join_task.done())
+
+        # Ensure put() and get() raise ShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+        # Ensure there is 1 unfinished task, and join() task succeeds
+        q.task_done()
+
+        await asyncio.sleep(0)
+        self.assertTrue(join_task.done())
+        await join_task
+
+        with self.assertRaises(
+            ValueError, msg="Didn't appear to mark all tasks done"
+        ):
+            q.task_done()
+
+    async def test_shutdown_immediate(self):
+        # Test immediately shutting down a queue
+
+        # Setup queue with 1 item, and a join() task
+        q = self.q_class()
+        loop = asyncio.get_running_loop()
+        q.put_nowait("data")
+        join_task = loop.create_task(q.join())
+
+        # Perform shut-down
+        q.shutdown(immediate=True)  # unfinished tasks: 1 -> 0
+
+        self.assertEqual(q.qsize(), 0)
+
+        # Ensure join() task has successfully finished
+        await asyncio.sleep(0)
+        self.assertTrue(join_task.done())
+        await join_task
+
+        # Ensure put() and get() raise ShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+        # Ensure there are no unfinished tasks
+        with self.assertRaises(
+            ValueError, msg="Didn't appear to mark all tasks done"
+        ):
+            q.task_done()
+
+    async def test_shutdown_immediate_with_unfinished(self):
+        # Test immediately shutting down a queue with unfinished tasks
+
+        # Setup queue with 2 items (1 retrieved), and a join() task
+        q = self.q_class()
+        loop = asyncio.get_running_loop()
+        q.put_nowait("data")
+        q.put_nowait("data")
+        join_task = loop.create_task(q.join())
+        self.assertEqual(await q.get(), "data")
+
+        # Perform shut-down
+        q.shutdown(immediate=True)  # unfinished tasks: 2 -> 1
+
+        self.assertEqual(q.qsize(), 0)
+
+        # Ensure join() task is not finished
+        await asyncio.sleep(0)
+        self.assertFalse(join_task.done())
+
+        # Ensure put() and get() raise ShutDown
+        with self.assertRaisesShutdown():
+            await q.put("data")
+        with self.assertRaisesShutdown():
+            q.put_nowait("data")
+
+        with self.assertRaisesShutdown():
+            await q.get()
+        with self.assertRaisesShutdown():
+            q.get_nowait()
+
+        # Ensure there is 1 unfinished task
+        q.task_done()
+        with self.assertRaises(
+            ValueError, msg="Didn't appear to mark all tasks done"
+        ):
+            q.task_done()
+
+        # Ensure join() task has successfully finished
+        await asyncio.sleep(0)
+        self.assertTrue(join_task.done())
+        await join_task
+
+
+class QueueShutdownTests(
+    _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+    q_class = asyncio.Queue
+
+
+class LifoQueueShutdownTests(
+    _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+    q_class = asyncio.LifoQueue
+
+
+class PriorityQueueShutdownTests(
+    _QueueShutdownTestMixin, unittest.IsolatedAsyncioTestCase
+):
+    q_class = asyncio.PriorityQueue
+
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst b/Misc/NEWS.d/next/Library/2023-05-06-05-00-42.gh-issue-96471.S3X5I-.rst
new file mode 100644 (file)
index 0000000..128a85d
--- /dev/null
@@ -0,0 +1,2 @@
+Add :py:class:`asyncio.Queue` termination with
+:py:meth:`~asyncio.Queue.shutdown` method.