]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-38415: Allow using @asynccontextmanager-made ctx managers as decorators (GH-16667)
authorJason Fried <fried@fb.com>
Thu, 23 Sep 2021 21:36:03 +0000 (14:36 -0700)
committerGitHub <noreply@github.com>
Thu, 23 Sep 2021 21:36:03 +0000 (23:36 +0200)
Lib/contextlib.py
Lib/test/test_contextlib_async.py
Misc/NEWS.d/next/Library/2019-10-08-14-08-59.bpo-38415.N1bUw6.rst [new file with mode: 0644]

index 8343d7e5196713ff44345039a7b0deab51d12a8e..1384d8903d17bf1a75eff6dc4b8b44c99a31a650 100644 (file)
@@ -191,6 +191,14 @@ class _AsyncGeneratorContextManager(
 ):
     """Helper for @asynccontextmanager decorator."""
 
+    def __call__(self, func):
+        @wraps(func)
+        async def inner(*args, **kwds):
+            async with self.__class__(self.func, self.args, self.kwds):
+                return await func(*args, **kwds)
+
+        return inner
+
     async def __aenter__(self):
         # do not keep args and kwds alive unnecessarily
         # they are only needed for recreation, which is not possible anymore
index 74fddef3f34ec55ba666d4d4a050d5b700665ff7..c738bf3c0bdfebcbc5220f1cc07f5a07b9ace3dd 100644 (file)
@@ -318,6 +318,82 @@ class AsyncContextManagerTestCase(unittest.TestCase):
         self.assertEqual(ncols, 10)
         self.assertEqual(depth, 0)
 
+    @_async_test
+    async def test_decorator(self):
+        entered = False
+
+        @asynccontextmanager
+        async def context():
+            nonlocal entered
+            entered = True
+            yield
+            entered = False
+
+        @context()
+        async def test():
+            self.assertTrue(entered)
+
+        self.assertFalse(entered)
+        await test()
+        self.assertFalse(entered)
+
+    @_async_test
+    async def test_decorator_with_exception(self):
+        entered = False
+
+        @asynccontextmanager
+        async def context():
+            nonlocal entered
+            try:
+                entered = True
+                yield
+            finally:
+                entered = False
+
+        @context()
+        async def test():
+            self.assertTrue(entered)
+            raise NameError('foo')
+
+        self.assertFalse(entered)
+        with self.assertRaisesRegex(NameError, 'foo'):
+            await test()
+        self.assertFalse(entered)
+
+    @_async_test
+    async def test_decorating_method(self):
+
+        @asynccontextmanager
+        async def context():
+            yield
+
+
+        class Test(object):
+
+            @context()
+            async def method(self, a, b, c=None):
+                self.a = a
+                self.b = b
+                self.c = c
+
+        # these tests are for argument passing when used as a decorator
+        test = Test()
+        await test.method(1, 2)
+        self.assertEqual(test.a, 1)
+        self.assertEqual(test.b, 2)
+        self.assertEqual(test.c, None)
+
+        test = Test()
+        await test.method('a', 'b', 'c')
+        self.assertEqual(test.a, 'a')
+        self.assertEqual(test.b, 'b')
+        self.assertEqual(test.c, 'c')
+
+        test = Test()
+        await test.method(a=1, b=2)
+        self.assertEqual(test.a, 1)
+        self.assertEqual(test.b, 2)
+
 
 class AclosingTestCase(unittest.TestCase):
 
diff --git a/Misc/NEWS.d/next/Library/2019-10-08-14-08-59.bpo-38415.N1bUw6.rst b/Misc/NEWS.d/next/Library/2019-10-08-14-08-59.bpo-38415.N1bUw6.rst
new file mode 100644 (file)
index 0000000..f99bf0d
--- /dev/null
@@ -0,0 +1,3 @@
+Added missing behavior to :func:`contextlib.asynccontextmanager` to match
+:func:`contextlib.contextmanager` so decorated functions can themselves be
+decorators.