]> git.ipfire.org Git - thirdparty/Python/cpython.git/commitdiff
bpo-47131: Speedup AST comparisons in test_unparse by using node traversal (GH-32132)
authorJeremy Kloth <jeremy.kloth@gmail.com>
Sat, 2 Apr 2022 01:54:04 +0000 (19:54 -0600)
committerGitHub <noreply@github.com>
Sat, 2 Apr 2022 01:54:04 +0000 (02:54 +0100)
Lib/test/test_unparse.py

index e38b33574ccccfcfa7b3f59277090eba8f4d8c4e..f999ae8c16ceafc0de29a855110d59d5deb9adc3 100644 (file)
@@ -130,7 +130,43 @@ docstring_prefixes = (
 
 class ASTTestCase(unittest.TestCase):
     def assertASTEqual(self, ast1, ast2):
-        self.assertEqual(ast.dump(ast1), ast.dump(ast2))
+        # Ensure the comparisons start at an AST node
+        self.assertIsInstance(ast1, ast.AST)
+        self.assertIsInstance(ast2, ast.AST)
+
+        # An AST comparison routine modeled after ast.dump(), but
+        # instead of string building, it traverses the two trees
+        # in lock-step.
+        def traverse_compare(a, b, missing=object()):
+            if type(a) is not type(b):
+                self.fail(f"{type(a)!r} is not {type(b)!r}")
+            if isinstance(a, ast.AST):
+                for field in a._fields:
+                    value1 = getattr(a, field, missing)
+                    value2 = getattr(b, field, missing)
+                    # Singletons are equal by definition, so further
+                    # testing can be skipped.
+                    if value1 is not value2:
+                        traverse_compare(value1, value2)
+            elif isinstance(a, list):
+                try:
+                    for node1, node2 in zip(a, b, strict=True):
+                        traverse_compare(node1, node2)
+                except ValueError:
+                    # Attempt a "pretty" error ala assertSequenceEqual()
+                    len1 = len(a)
+                    len2 = len(b)
+                    if len1 > len2:
+                        what = "First"
+                        diff = len1 - len2
+                    else:
+                        what = "Second"
+                        diff = len2 - len1
+                    msg = f"{what} list contains {diff} additional elements."
+                    raise self.failureException(msg) from None
+            elif a != b:
+                self.fail(f"{a!r} != {b!r}")
+        traverse_compare(ast1, ast2)
 
     def check_ast_roundtrip(self, code1, **kwargs):
         with self.subTest(code1=code1, ast_parse_kwargs=kwargs):