b = dis.Bytecode.from_traceback(tb)
self.assertEqual(b.dis(), dis_traceback)
+
+class TestDisTraceback(unittest.TestCase):
+ def setUp(self) -> None:
+ try: # We need to clean up existing tracebacks
+ del sys.last_traceback
+ except AttributeError:
+ pass
+ return super().setUp()
+
+ def get_disassembly(self, tb):
+ output = io.StringIO()
+ with contextlib.redirect_stdout(output):
+ dis.distb(tb)
+ return output.getvalue()
+
+ def test_distb_empty(self):
+ with self.assertRaises(RuntimeError):
+ dis.distb()
+
+ def test_distb_last_traceback(self):
+ # We need to have an existing last traceback in `sys`:
+ tb = get_tb()
+ sys.last_traceback = tb
+
+ self.assertEqual(self.get_disassembly(None), dis_traceback)
+
+ def test_distb_explicit_arg(self):
+ tb = get_tb()
+
+ self.assertEqual(self.get_disassembly(tb), dis_traceback)
+
+
+class TestDisTracebackWithFile(TestDisTraceback):
+ # Run the `distb` tests again, using the file arg instead of print
+ def get_disassembly(self, tb):
+ output = io.StringIO()
+ with contextlib.redirect_stdout(output):
+ dis.distb(tb, file=output)
+ return output.getvalue()
+
+
if __name__ == "__main__":
unittest.main()