]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
[3.10] gh-95882: Add tests for traceback from contextlib context managers (GH-95883...
authorIrit Katriel <1055913+iritkatriel@users.noreply.github.com>
Tue, 3 Jan 2023 22:24:19 +0000 (22:24 +0000)
committerGitHub <noreply@github.com>
Tue, 3 Jan 2023 22:24:19 +0000 (22:24 +0000)
Lib/test/test_contextlib.py
Lib/test/test_contextlib_async.py

index 68bd45d97232c7a17ed68017ec02c3dab4242b09..bcedf17fc8061f51c0bf2c591d5cd3f5e32e83d8 100644 (file)
@@ -87,6 +87,56 @@ class ContextManagerTestCase(unittest.TestCase):
                 raise ZeroDivisionError()
         self.assertEqual(state, [1, 42, 999])
 
+    def test_contextmanager_traceback(self):
+        @contextmanager
+        def f():
+            yield
+
+        try:
+            with f():
+                1/0
+        except ZeroDivisionError as e:
+            frames = traceback.extract_tb(e.__traceback__)
+
+        self.assertEqual(len(frames), 1)
+        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+        self.assertEqual(frames[0].line, '1/0')
+
+        # Repeat with RuntimeError (which goes through a different code path)
+        class RuntimeErrorSubclass(RuntimeError):
+            pass
+
+        try:
+            with f():
+                raise RuntimeErrorSubclass(42)
+        except RuntimeErrorSubclass as e:
+            frames = traceback.extract_tb(e.__traceback__)
+
+        self.assertEqual(len(frames), 1)
+        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
+
+        class StopIterationSubclass(StopIteration):
+            pass
+
+        for stop_exc in (
+            StopIteration('spam'),
+            StopIterationSubclass('spam'),
+        ):
+            with self.subTest(type=type(stop_exc)):
+                try:
+                    with f():
+                        raise stop_exc
+                except type(stop_exc) as e:
+                    self.assertIs(e, stop_exc)
+                    frames = traceback.extract_tb(e.__traceback__)
+                else:
+                    self.fail(f'{stop_exc} was suppressed')
+
+                self.assertEqual(len(frames), 1)
+                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+                self.assertEqual(frames[0].line, 'raise stop_exc')
+
     def test_contextmanager_no_reraise(self):
         @contextmanager
         def whee():
index d44d36209c7035bebc3b695ca616f4aae60e68c1..3d1079c205aca101e63ea2a410441620bcd1d717 100644 (file)
@@ -4,6 +4,7 @@ from contextlib import (
     AsyncExitStack, nullcontext, aclosing, contextmanager)
 import functools
 from test import support
+import traceback
 import unittest
 
 from test.test_contextlib import TestBaseExitStack
@@ -124,6 +125,63 @@ class AsyncContextManagerTestCase(unittest.TestCase):
                 raise ZeroDivisionError()
         self.assertEqual(state, [1, 42, 999])
 
+    @_async_test
+    async def test_contextmanager_traceback(self):
+        @asynccontextmanager
+        async def f():
+            yield
+
+        try:
+            async with f():
+                1/0
+        except ZeroDivisionError as e:
+            frames = traceback.extract_tb(e.__traceback__)
+
+        self.assertEqual(len(frames), 1)
+        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+        self.assertEqual(frames[0].line, '1/0')
+
+        # Repeat with RuntimeError (which goes through a different code path)
+        class RuntimeErrorSubclass(RuntimeError):
+            pass
+
+        try:
+            async with f():
+                raise RuntimeErrorSubclass(42)
+        except RuntimeErrorSubclass as e:
+            frames = traceback.extract_tb(e.__traceback__)
+
+        self.assertEqual(len(frames), 1)
+        self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+        self.assertEqual(frames[0].line, 'raise RuntimeErrorSubclass(42)')
+
+        class StopIterationSubclass(StopIteration):
+            pass
+
+        class StopAsyncIterationSubclass(StopAsyncIteration):
+            pass
+
+        for stop_exc in (
+            StopIteration('spam'),
+            StopAsyncIteration('ham'),
+            StopIterationSubclass('spam'),
+            StopAsyncIterationSubclass('spam')
+        ):
+            with self.subTest(type=type(stop_exc)):
+                try:
+                    async with f():
+                        raise stop_exc
+                except type(stop_exc) as e:
+                    self.assertIs(e, stop_exc)
+                    frames = traceback.extract_tb(e.__traceback__)
+                else:
+                    self.fail(f'{stop_exc} was suppressed')
+
+                self.assertEqual(len(frames), 1)
+                self.assertEqual(frames[0].name, 'test_contextmanager_traceback')
+                self.assertEqual(frames[0].line, 'raise stop_exc')
+
+
     @_async_test
     async def test_contextmanager_no_reraise(self):
         @asynccontextmanager