]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
GH-117881: fix athrow().throw()/asend().throw() concurrent access (GH-117882)
authorThomas Grainger <tagrain@gmail.com>
Wed, 1 May 2024 06:44:01 +0000 (07:44 +0100)
committerGitHub <noreply@github.com>
Wed, 1 May 2024 06:44:01 +0000 (08:44 +0200)
Lib/test/test_asyncgen.py
Misc/NEWS.d/next/Core and Builtins/2024-04-15-07-37-09.gh-issue-117881.07H0wI.rst [new file with mode: 0644]
Objects/genobject.c

index a1e9e1b89c6a2cf350cc74bf059cf423705a6431..1985ede656e7a05b4f2bfb57a70aecef322d1c35 100644 (file)
@@ -393,6 +393,151 @@ class AsyncGenTest(unittest.TestCase):
                 r'anext\(\): asynchronous generator is already running'):
             an.__next__()
 
+        with self.assertRaisesRegex(RuntimeError,
+                r"cannot reuse already awaited __anext__\(\)/asend\(\)"):
+            an.send(None)
+
+    def test_async_gen_asend_throw_concurrent_with_send(self):
+        import types
+
+        @types.coroutine
+        def _async_yield(v):
+            return (yield v)
+
+        class MyExc(Exception):
+            pass
+
+        async def agenfn():
+            while True:
+                try:
+                    await _async_yield(None)
+                except MyExc:
+                    pass
+            return
+            yield
+
+
+        agen = agenfn()
+        gen = agen.asend(None)
+        gen.send(None)
+        gen2 = agen.asend(None)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r'anext\(\): asynchronous generator is already running'):
+            gen2.throw(MyExc)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r"cannot reuse already awaited __anext__\(\)/asend\(\)"):
+            gen2.send(None)
+
+    def test_async_gen_athrow_throw_concurrent_with_send(self):
+        import types
+
+        @types.coroutine
+        def _async_yield(v):
+            return (yield v)
+
+        class MyExc(Exception):
+            pass
+
+        async def agenfn():
+            while True:
+                try:
+                    await _async_yield(None)
+                except MyExc:
+                    pass
+            return
+            yield
+
+
+        agen = agenfn()
+        gen = agen.asend(None)
+        gen.send(None)
+        gen2 = agen.athrow(MyExc)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r'athrow\(\): asynchronous generator is already running'):
+            gen2.throw(MyExc)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r"cannot reuse already awaited aclose\(\)/athrow\(\)"):
+            gen2.send(None)
+
+    def test_async_gen_asend_throw_concurrent_with_throw(self):
+        import types
+
+        @types.coroutine
+        def _async_yield(v):
+            return (yield v)
+
+        class MyExc(Exception):
+            pass
+
+        async def agenfn():
+            try:
+                yield
+            except MyExc:
+                pass
+            while True:
+                try:
+                    await _async_yield(None)
+                except MyExc:
+                    pass
+
+
+        agen = agenfn()
+        with self.assertRaises(StopIteration):
+            agen.asend(None).send(None)
+
+        gen = agen.athrow(MyExc)
+        gen.throw(MyExc)
+        gen2 = agen.asend(MyExc)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r'anext\(\): asynchronous generator is already running'):
+            gen2.throw(MyExc)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r"cannot reuse already awaited __anext__\(\)/asend\(\)"):
+            gen2.send(None)
+
+    def test_async_gen_athrow_throw_concurrent_with_throw(self):
+        import types
+
+        @types.coroutine
+        def _async_yield(v):
+            return (yield v)
+
+        class MyExc(Exception):
+            pass
+
+        async def agenfn():
+            try:
+                yield
+            except MyExc:
+                pass
+            while True:
+                try:
+                    await _async_yield(None)
+                except MyExc:
+                    pass
+
+        agen = agenfn()
+        with self.assertRaises(StopIteration):
+            agen.asend(None).send(None)
+
+        gen = agen.athrow(MyExc)
+        gen.throw(MyExc)
+        gen2 = agen.athrow(None)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r'athrow\(\): asynchronous generator is already running'):
+            gen2.throw(MyExc)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r"cannot reuse already awaited aclose\(\)/athrow\(\)"):
+            gen2.send(None)
+
     def test_async_gen_3_arg_deprecation_warning(self):
         async def gen():
             yield 123
@@ -1571,6 +1716,8 @@ class AsyncGenAsyncioTest(unittest.TestCase):
         self.assertIsInstance(message['exception'], ZeroDivisionError)
         self.assertIn('unhandled exception during asyncio.run() shutdown',
                       message['message'])
+        del message, messages
+        gc_collect()
 
     def test_async_gen_expression_01(self):
         async def arange(n):
@@ -1624,6 +1771,7 @@ class AsyncGenAsyncioTest(unittest.TestCase):
         asyncio.run(main())
 
         self.assertEqual([], messages)
+        gc_collect()
 
     def test_async_gen_await_same_anext_coro_twice(self):
         async def async_iterate():
@@ -1809,9 +1957,56 @@ class TestUnawaitedWarnings(unittest.TestCase):
         g = gen()
         with self.assertRaises(MyException):
             g.aclose().throw(MyException)
-            del g
-            gc_collect()
 
+        del g
+        gc_collect()  # does not warn unawaited
+
+    def test_asend_send_already_running(self):
+        @types.coroutine
+        def _async_yield(v):
+            return (yield v)
+
+        async def agenfn():
+            while True:
+                await _async_yield(1)
+            return
+            yield
+
+        agen = agenfn()
+        gen = agen.asend(None)
+        gen.send(None)
+        gen2 = agen.asend(None)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r'anext\(\): asynchronous generator is already running'):
+            gen2.send(None)
+
+        del gen2
+        gc_collect()  # does not warn unawaited
+
+
+    def test_athrow_send_already_running(self):
+        @types.coroutine
+        def _async_yield(v):
+            return (yield v)
+
+        async def agenfn():
+            while True:
+                await _async_yield(1)
+            return
+            yield
+
+        agen = agenfn()
+        gen = agen.asend(None)
+        gen.send(None)
+        gen2 = agen.athrow(Exception)
+
+        with self.assertRaisesRegex(RuntimeError,
+                r'athrow\(\): asynchronous generator is already running'):
+            gen2.send(None)
+
+        del gen2
+        gc_collect()  # does not warn unawaited
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/Misc/NEWS.d/next/Core and Builtins/2024-04-15-07-37-09.gh-issue-117881.07H0wI.rst b/Misc/NEWS.d/next/Core and Builtins/2024-04-15-07-37-09.gh-issue-117881.07H0wI.rst
new file mode 100644 (file)
index 0000000..75b3426
--- /dev/null
@@ -0,0 +1 @@
+prevent concurrent access to an async generator via athrow().throw() or asend().throw()
index a1ed1cbd2bfb5809b2454a4d667eb07419a0ba79..89bb21a8674b9f1f200cc00a8f4bde2392bdda6c 100644 (file)
@@ -1774,6 +1774,7 @@ async_gen_asend_send(PyAsyncGenASend *o, PyObject *arg)
 
     if (o->ags_state == AWAITABLE_STATE_INIT) {
         if (o->ags_gen->ag_running_async) {
+            o->ags_state = AWAITABLE_STATE_CLOSED;
             PyErr_SetString(
                 PyExc_RuntimeError,
                 "anext(): asynchronous generator is already running");
@@ -1817,10 +1818,24 @@ async_gen_asend_throw(PyAsyncGenASend *o, PyObject *const *args, Py_ssize_t narg
         return NULL;
     }
 
+    if (o->ags_state == AWAITABLE_STATE_INIT) {
+        if (o->ags_gen->ag_running_async) {
+            o->ags_state = AWAITABLE_STATE_CLOSED;
+            PyErr_SetString(
+                PyExc_RuntimeError,
+                "anext(): asynchronous generator is already running");
+            return NULL;
+        }
+
+        o->ags_state = AWAITABLE_STATE_ITER;
+        o->ags_gen->ag_running_async = 1;
+    }
+
     result = gen_throw((PyGenObject*)o->ags_gen, args, nargs);
     result = async_gen_unwrap_value(o->ags_gen, result);
 
     if (result == NULL) {
+        o->ags_gen->ag_running_async = 0;
         o->ags_state = AWAITABLE_STATE_CLOSED;
     }
 
@@ -2209,10 +2224,31 @@ async_gen_athrow_throw(PyAsyncGenAThrow *o, PyObject *const *args, Py_ssize_t na
         return NULL;
     }
 
+    if (o->agt_state == AWAITABLE_STATE_INIT) {
+        if (o->agt_gen->ag_running_async) {
+            o->agt_state = AWAITABLE_STATE_CLOSED;
+            if (o->agt_args == NULL) {
+                PyErr_SetString(
+                    PyExc_RuntimeError,
+                    "aclose(): asynchronous generator is already running");
+            }
+            else {
+                PyErr_SetString(
+                    PyExc_RuntimeError,
+                    "athrow(): asynchronous generator is already running");
+            }
+            return NULL;
+        }
+
+        o->agt_state = AWAITABLE_STATE_ITER;
+        o->agt_gen->ag_running_async = 1;
+    }
+
     retval = gen_throw((PyGenObject*)o->agt_gen, args, nargs);
     if (o->agt_args) {
         retval = async_gen_unwrap_value(o->agt_gen, retval);
         if (retval == NULL) {
+            o->agt_gen->ag_running_async = 0;
             o->agt_state = AWAITABLE_STATE_CLOSED;
         }
         return retval;
@@ -2226,6 +2262,7 @@ async_gen_athrow_throw(PyAsyncGenAThrow *o, PyObject *const *args, Py_ssize_t na
             return NULL;
         }
         if (retval == NULL) {
+            o->agt_gen->ag_running_async = 0;
             o->agt_state = AWAITABLE_STATE_CLOSED;
         }
         if (PyErr_ExceptionMatches(PyExc_StopAsyncIteration) ||