]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
gh-125862: Keep ContextDecorator open across generator/coroutine execution (GH-136212)
authorAlex Grönholm <alex.gronholm@nextday.fi>
Tue, 28 Apr 2026 05:26:38 +0000 (08:26 +0300)
committerGitHub <noreply@github.com>
Tue, 28 Apr 2026 05:26:38 +0000 (05:26 +0000)
ContextDecorator and AsyncContextDecorator (and therefore @contextmanager
and @asynccontextmanager used as decorators) now detect generator,
coroutine, and asynchronous generator functions and emit a wrapper of the
matching kind, so the context manager spans iteration or await rather than
just the call that constructs the lazy object.  Wrapped generators are
explicitly closed when iteration ends.

For asynchronous generator wrappers, values passed via asend() and
exceptions via athrow() are not forwarded to the wrapped generator.

AsyncContextDecorator now also accepts synchronous functions and
generators, returning an asynchronous wrapper; ContextDecorator remains
the recommended choice for those.

inspect.isgeneratorfunction(), iscoroutinefunction(), and
isasyncgenfunction() now return True for the decorated result when the
input is of that kind.

---------

Co-authored-by: Gregory P. Smith <greg@krypto.org>
Doc/library/contextlib.rst
Doc/whatsnew/3.15.rst
Lib/contextlib.py
Lib/test/test_contextlib.py
Lib/test/test_contextlib_async.py
Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst [new file with mode: 0644]

index 5c6403879ab505671e53544b22aa6e65ae3521fa..77bac8dcc3afbdb9f852169729d0374f98faf927 100644 (file)
@@ -467,12 +467,40 @@ Functions and classes provided:
       statements. If this is not the case, then the original construct with the
       explicit :keyword:`!with` statement inside the function should be used.
 
+   When the decorated callable is a generator function, coroutine function, or
+   asynchronous generator function, the returned wrapper is of the same kind
+   and keeps the context manager open for the lifetime of the iteration or
+   await rather than only for the call that creates the generator or coroutine
+   object.  Wrapped generators and asynchronous generators are explicitly
+   closed when iteration ends, as if by :func:`closing` or :func:`aclosing`.
+
+   .. note::
+      For asynchronous generators the wrapper re-yields each value with
+      ``async for``; values sent with :meth:`~agen.asend` and exceptions
+      thrown with :meth:`~agen.athrow` are not forwarded to the wrapped
+      generator.
+
    .. versionadded:: 3.2
 
+   .. versionchanged:: next
+      Decorating a generator function, coroutine function, or asynchronous
+      generator function now keeps the context manager open across iteration
+      or await.  Previously the context manager exited as soon as the
+      generator or coroutine object was created.
+
 
 .. class:: AsyncContextDecorator
 
-   Similar to :class:`ContextDecorator` but only for asynchronous functions.
+   Similar to :class:`ContextDecorator`, but the context manager is entered
+   and exited with :keyword:`async with`.  Decorate coroutine functions and
+   asynchronous generator functions with this class; the returned wrapper is
+   of the same kind.
+
+   .. note::
+      Synchronous functions and generators are accepted, but the wrapper is
+      always asynchronous, so the decorated callable must then be awaited or
+      iterated with ``async for``.  If that change of calling convention is
+      not intended, use :class:`ContextDecorator` instead.
 
    Example of ``AsyncContextDecorator``::
 
@@ -510,6 +538,13 @@ Functions and classes provided:
 
    .. versionadded:: 3.10
 
+   .. versionchanged:: next
+      Decorating an asynchronous generator function now keeps the context
+      manager open across iteration.  Previously the context manager exited
+      as soon as the generator object was created.  Synchronous functions
+      and synchronous generator functions are also now accepted, with an
+      asynchronous wrapper returned.
+
 
 .. class:: ExitStack()
 
index 65965d504c09762027b084c83648e83d19736d8b..ee49d043de264195be2cfc188afd82ef267fd281 100644 (file)
@@ -846,6 +846,15 @@ contextlib
   consistency with the :keyword:`with` and :keyword:`async with` statements.
   (Contributed by Serhiy Storchaka in :gh:`144386`.)
 
+* :class:`~contextlib.ContextDecorator` and
+  :class:`~contextlib.AsyncContextDecorator` (and therefore
+  :func:`~contextlib.contextmanager` and :func:`~contextlib.asynccontextmanager`
+  used as decorators) now detect generator functions, coroutine functions, and
+  asynchronous generator functions and keep the context manager open across
+  iteration or await.  Previously the context manager exited as soon as the
+  generator or coroutine object was created.
+  (Contributed by Alex Grönholm & Gregory P. Smith in :gh:`125862`.)
+
 
 dataclasses
 -----------
index cac3e39eba8b520568335c476919c687feb7d088..efc02bfa9243da6cb42e95d3f6b4dd0adc6782eb 100644 (file)
@@ -1,10 +1,16 @@
 """Utilities for with-statement contexts.  See PEP 343."""
+
 import abc
 import os
 import sys
 import _collections_abc
 from collections import deque
 from functools import wraps
+lazy from inspect import (
+    isasyncgenfunction as _isasyncgenfunction,
+    iscoroutinefunction as _iscoroutinefunction,
+    isgeneratorfunction as _isgeneratorfunction,
+)
 from types import GenericAlias
 
 __all__ = ["asynccontextmanager", "contextmanager", "closing", "nullcontext",
@@ -79,11 +85,37 @@ class ContextDecorator(object):
         return self
 
     def __call__(self, func):
-        @wraps(func)
-        def inner(*args, **kwds):
-            with self._recreate_cm():
-                return func(*args, **kwds)
-        return inner
+        wrapper = wraps(func)
+        if _isasyncgenfunction(func):
+
+            async def asyncgen_inner(*args, **kwds):
+                with self._recreate_cm():
+                    async with aclosing(func(*args, **kwds)) as gen:
+                        async for value in gen:
+                            yield value
+
+            return wrapper(asyncgen_inner)
+        elif _iscoroutinefunction(func):
+
+            async def async_inner(*args, **kwds):
+                with self._recreate_cm():
+                    return await func(*args, **kwds)
+
+            return wrapper(async_inner)
+        elif _isgeneratorfunction(func):
+
+            def gen_inner(*args, **kwds):
+                with self._recreate_cm(), closing(func(*args, **kwds)) as gen:
+                    return (yield from gen)
+
+            return wrapper(gen_inner)
+        else:
+
+            def inner(*args, **kwds):
+                with self._recreate_cm():
+                    return func(*args, **kwds)
+
+            return wrapper(inner)
 
 
 class AsyncContextDecorator(object):
@@ -95,11 +127,41 @@ class AsyncContextDecorator(object):
         return self
 
     def __call__(self, func):
-        @wraps(func)
-        async def inner(*args, **kwds):
-            async with self._recreate_cm():
-                return await func(*args, **kwds)
-        return inner
+        wrapper = wraps(func)
+        if _isasyncgenfunction(func):
+
+            async def asyncgen_inner(*args, **kwds):
+                async with (
+                    self._recreate_cm(),
+                    aclosing(func(*args, **kwds)) as gen
+                ):
+                    async for value in gen:
+                        yield value
+
+            return wrapper(asyncgen_inner)
+        elif _iscoroutinefunction(func):
+
+            async def async_inner(*args, **kwds):
+                async with self._recreate_cm():
+                    return await func(*args, **kwds)
+
+            return wrapper(async_inner)
+        elif _isgeneratorfunction(func):
+
+            async def gen_inner(*args, **kwds):
+                async with self._recreate_cm():
+                    with closing(func(*args, **kwds)) as gen:
+                        for value in gen:
+                            yield value
+
+            return wrapper(gen_inner)
+        else:
+
+            async def inner(*args, **kwds):
+                async with self._recreate_cm():
+                    return func(*args, **kwds)
+
+            return wrapper(inner)
 
 
 class _GeneratorContextManagerBase:
index 1fd8b3cb18c2d44254995c51e8accf4c9a932c7b..e291f814edbd9303fe2411d2f15d03acedc6dd4a 100644 (file)
@@ -680,6 +680,154 @@ class TestContextDecorator(unittest.TestCase):
         self.assertEqual(state, [1, 'something else', 999])
 
 
+    def test_contextmanager_decorate_generator_function(self):
+        @contextmanager
+        def woohoo(y):
+            state.append(y)
+            yield
+            state.append(999)
+
+        state = []
+        @woohoo(1)
+        def test(x):
+            self.assertEqual(state, [1])
+            state.append(x)
+            yield
+            state.append("second item")
+            return "result"
+
+        gen = test("something")
+        for _ in gen:
+            self.assertEqual(state, [1, "something"])
+        self.assertEqual(state, [1, "something", "second item", 999])
+
+        # The wrapped generator's return value is preserved.
+        state = []
+        gen = test("something")
+        with self.assertRaises(StopIteration) as cm:
+            while True:
+                next(gen)
+        self.assertEqual(cm.exception.value, "result")
+
+
+    def test_contextmanager_decorate_generator_function_exception(self):
+        @contextmanager
+        def woohoo():
+            state.append("enter")
+            try:
+                yield
+            finally:
+                state.append("exit")
+
+        state = []
+        @woohoo()
+        def test():
+            state.append("body")
+            yield
+            raise ZeroDivisionError
+
+        with self.assertRaises(ZeroDivisionError):
+            for _ in test():
+                pass
+        self.assertEqual(state, ["enter", "body", "exit"])
+
+
+    def test_contextmanager_decorate_generator_function_early_stop(self):
+        @contextmanager
+        def woohoo():
+            state.append("enter")
+            try:
+                yield
+            finally:
+                state.append("exit")
+
+        state = []
+        @woohoo()
+        def test():
+            try:
+                yield 1
+                yield 2
+            finally:
+                state.append("inner closed")
+
+        gen = test()
+        self.assertEqual(next(gen), 1)
+        gen.close()
+        # The inner generator is closed before the context manager exits.
+        self.assertEqual(state, ["enter", "inner closed", "exit"])
+
+
+    def test_contextmanager_decorate_generator_function_send_throw(self):
+        @contextmanager
+        def woohoo():
+            yield
+
+        @woohoo()
+        def test():
+            received = yield "first"
+            state.append(("received", received))
+            try:
+                yield "second"
+            except ValueError as exc:
+                state.append(("caught", type(exc)))
+                yield "after throw"
+
+        # .send() and .throw() are forwarded to the wrapped generator.
+        state = []
+        gen = test()
+        self.assertEqual(next(gen), "first")
+        self.assertEqual(gen.send("VALUE"), "second")
+        self.assertEqual(gen.throw(ValueError), "after throw")
+        gen.close()
+        self.assertEqual(
+            state, [("received", "VALUE"), ("caught", ValueError)]
+        )
+
+
+    def test_contextmanager_decorate_coroutine_function(self):
+        @contextmanager
+        def woohoo(y):
+            state.append(y)
+            yield
+            state.append(999)
+
+        state = []
+        @woohoo(1)
+        async def test(x):
+            self.assertEqual(state, [1])
+            state.append(x)
+
+        coro = test("something")
+        with self.assertRaises(StopIteration):
+            coro.send(None)
+
+        self.assertEqual(state, [1, "something", 999])
+
+
+    def test_contextmanager_decorate_asyncgen_function(self):
+        @contextmanager
+        def woohoo(y):
+            state.append(y)
+            yield
+            state.append(999)
+
+        state = []
+        @woohoo(1)
+        async def test(x):
+            self.assertEqual(state, [1])
+            state.append(x)
+            yield
+            state.append("second item")
+
+        agen = test("something")
+        with self.assertRaises(StopIteration):
+            agen.asend(None).send(None)
+        with self.assertRaises(StopAsyncIteration):
+            agen.asend(None).send(None)
+
+        self.assertEqual(state, [1, "something", "second item", 999])
+
+
 class TestBaseExitStack:
     exit_stack = None
 
index 248d32d615225d6073f1077459c993d71cee6afa..95bdfdb3d9d4a661632059b1069c101a9332cd8f 100644 (file)
@@ -402,6 +402,144 @@ class AsyncContextManagerTestCase(unittest.TestCase):
         await test()
         self.assertFalse(entered)
 
+    @_async_test
+    async def test_decorator_decorate_sync_function(self):
+        @asynccontextmanager
+        async def context():
+            state.append(1)
+            yield
+            state.append(999)
+
+        state = []
+        @context()
+        def test(x):
+            self.assertEqual(state, [1])
+            state.append(x)
+
+        await test("something")
+        self.assertEqual(state, [1, "something", 999])
+
+    @_async_test
+    async def test_decorator_decorate_generator_function(self):
+        @asynccontextmanager
+        async def context():
+            state.append(1)
+            yield
+            state.append(999)
+
+        state = []
+        @context()
+        def test(x):
+            self.assertEqual(state, [1])
+            state.append(x)
+            yield
+            state.append("second item")
+
+        async for _ in test("something"):
+            self.assertEqual(state, [1, "something"])
+        self.assertEqual(state, [1, "something", "second item", 999])
+
+    @_async_test
+    async def test_decorator_decorate_asyncgen_function(self):
+        @asynccontextmanager
+        async def context():
+            state.append(1)
+            yield
+            state.append(999)
+
+        state = []
+        @context()
+        async def test(x):
+            self.assertEqual(state, [1])
+            state.append(x)
+            yield
+            state.append("second item")
+
+        async for _ in test("something"):
+            self.assertEqual(state, [1, "something"])
+        self.assertEqual(state, [1, "something", "second item", 999])
+
+    @_async_test
+    async def test_decorator_decorate_asyncgen_function_exception(self):
+        @asynccontextmanager
+        async def context():
+            state.append("enter")
+            try:
+                yield
+            finally:
+                state.append("exit")
+
+        state = []
+        @context()
+        async def test():
+            state.append("body")
+            yield
+            raise ZeroDivisionError
+
+        with self.assertRaises(ZeroDivisionError):
+            async for _ in test():
+                pass
+        self.assertEqual(state, ["enter", "body", "exit"])
+
+    @_async_test
+    async def test_decorator_decorate_asyncgen_function_early_stop(self):
+        @asynccontextmanager
+        async def context():
+            state.append("enter")
+            try:
+                yield
+            finally:
+                state.append("exit")
+
+        state = []
+        @context()
+        async def test():
+            try:
+                yield 1
+                yield 2
+            finally:
+                state.append("inner closed")
+
+        agen = test()
+        async for value in agen:
+            self.assertEqual(value, 1)
+            break
+        await agen.aclose()
+        # The inner async generator is closed before the context
+        # manager exits.
+        self.assertEqual(state, ["enter", "inner closed", "exit"])
+
+    @_async_test
+    async def test_decorator_decorate_asyncgen_function_asend_athrow(self):
+        @asynccontextmanager
+        async def context():
+            yield
+
+        @context()
+        async def test():
+            try:
+                received = yield "first"
+                state.append(("received", received))
+                yield "second"
+            except ValueError:
+                state.append("inner saw ValueError")
+                raise
+            finally:
+                state.append("inner closed")
+
+        # asend() values and athrow() exceptions are not forwarded to the
+        # wrapped generator (a documented limitation).
+        state = []
+        agen = test()
+        self.assertEqual(await agen.__anext__(), "first")
+        self.assertEqual(await agen.asend("VALUE"), "second")
+        # The inner generator received None, not "VALUE".
+        self.assertEqual(state, [("received", None)])
+        with self.assertRaises(ValueError):
+            await agen.athrow(ValueError)
+        # The inner generator was closed, not thrown into.
+        self.assertEqual(state, [("received", None), "inner closed"])
+
     @_async_test
     async def test_decorator_with_exception(self):
         entered = False
diff --git a/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst b/Misc/NEWS.d/next/Library/2025-07-02-17-01-17.gh-issue-125862.WgFYj3.rst
new file mode 100644 (file)
index 0000000..1ccc91d
--- /dev/null
@@ -0,0 +1,4 @@
+The :func:`contextlib.contextmanager` and
+:func:`contextlib.asynccontextmanager` decorators now work correctly with
+generators, coroutine functions, and async generators when the wrapped
+callables are used as decorators.