yield node
+def compare(
+ a,
+ b,
+ /,
+ *,
+ compare_attributes=False,
+):
+ """Recursively compares two ASTs.
+
+ compare_attributes affects whether AST attributes are considered
+ in the comparison. If compare_attributes is False (default), then
+ attributes are ignored. Otherwise they must all be equal. This
+ option is useful to check whether the ASTs are structurally equal but
+ might differ in whitespace or similar details.
+ """
+
+ def _compare(a, b):
+ # Compare two fields on an AST object, which may themselves be
+ # AST objects, lists of AST objects, or primitive ASDL types
+ # like identifiers and constants.
+ if isinstance(a, AST):
+ return compare(
+ a,
+ b,
+ compare_attributes=compare_attributes,
+ )
+ elif isinstance(a, list):
+ # If a field is repeated, then both objects will represent
+ # the value as a list.
+ if len(a) != len(b):
+ return False
+ for a_item, b_item in zip(a, b):
+ if not _compare(a_item, b_item):
+ return False
+ else:
+ return True
+ else:
+ return type(a) is type(b) and a == b
+
+ def _compare_fields(a, b):
+ if a._fields != b._fields:
+ return False
+ for field in a._fields:
+ a_field = getattr(a, field)
+ b_field = getattr(b, field)
+ if not _compare(a_field, b_field):
+ return False
+ else:
+ return True
+
+ def _compare_attributes(a, b):
+ if a._attributes != b._attributes:
+ return False
+ # Attributes are always ints.
+ for attr in a._attributes:
+ a_attr = getattr(a, attr)
+ b_attr = getattr(b, attr)
+ if a_attr != b_attr:
+ return False
+ else:
+ return True
+
+ if type(a) is not type(b):
+ return False
+ if not _compare_fields(a, b):
+ return False
+ if compare_attributes and not _compare_attributes(a, b):
+ return False
+ return True
+
+
class NodeVisitor(object):
"""
A node visitor base class that walks the abstract syntax tree and calls a
result.append(to_tuple(getattr(t, f)))
return tuple(result)
+STDLIB = os.path.dirname(ast.__file__)
+STDLIB_FILES = [fn for fn in os.listdir(STDLIB) if fn.endswith(".py")]
+STDLIB_FILES.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
# These tests are compiled through "exec"
# There should be at least one test per statement
expressions[0] = f"expr = {ast.expr.__subclasses__()[0].__doc__}"
self.assertCountEqual(ast.expr.__doc__.split("\n"), expressions)
+ def test_compare_basics(self):
+ self.assertTrue(ast.compare(ast.parse("x = 10"), ast.parse("x = 10")))
+ self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("")))
+ self.assertFalse(ast.compare(ast.parse("x = 10"), ast.parse("x")))
+ self.assertFalse(
+ ast.compare(ast.parse("x = 10;y = 20"), ast.parse("class C:pass"))
+ )
+
+ def test_compare_modified_ast(self):
+ # The ast API is a bit underspecified. The objects are mutable,
+ # and even _fields and _attributes are mutable. The compare() does
+ # some simple things to accommodate mutability.
+ a = ast.parse("m * x + b", mode="eval")
+ b = ast.parse("m * x + b", mode="eval")
+ self.assertTrue(ast.compare(a, b))
+
+ a._fields = a._fields + ("spam",)
+ a.spam = "Spam"
+ self.assertNotEqual(a._fields, b._fields)
+ self.assertFalse(ast.compare(a, b))
+ self.assertFalse(ast.compare(b, a))
+
+ b._fields = a._fields
+ b.spam = a.spam
+ self.assertTrue(ast.compare(a, b))
+ self.assertTrue(ast.compare(b, a))
+
+ b._attributes = b._attributes + ("eggs",)
+ b.eggs = "eggs"
+ self.assertNotEqual(a._attributes, b._attributes)
+ self.assertFalse(ast.compare(a, b, compare_attributes=True))
+ self.assertFalse(ast.compare(b, a, compare_attributes=True))
+
+ a._attributes = b._attributes
+ a.eggs = b.eggs
+ self.assertTrue(ast.compare(a, b, compare_attributes=True))
+ self.assertTrue(ast.compare(b, a, compare_attributes=True))
+
+ def test_compare_literals(self):
+ constants = (
+ -20,
+ 20,
+ 20.0,
+ 1,
+ 1.0,
+ True,
+ 0,
+ False,
+ frozenset(),
+ tuple(),
+ "ABCD",
+ "abcd",
+ "中文字",
+ 1e1000,
+ -1e1000,
+ )
+ for next_index, constant in enumerate(constants[:-1], 1):
+ next_constant = constants[next_index]
+ with self.subTest(literal=constant, next_literal=next_constant):
+ self.assertTrue(
+ ast.compare(ast.Constant(constant), ast.Constant(constant))
+ )
+ self.assertFalse(
+ ast.compare(
+ ast.Constant(constant), ast.Constant(next_constant)
+ )
+ )
+
+ same_looking_literal_cases = [
+ {1, 1.0, True, 1 + 0j},
+ {0, 0.0, False, 0 + 0j},
+ ]
+ for same_looking_literals in same_looking_literal_cases:
+ for literal in same_looking_literals:
+ for same_looking_literal in same_looking_literals - {literal}:
+ self.assertFalse(
+ ast.compare(
+ ast.Constant(literal),
+ ast.Constant(same_looking_literal),
+ )
+ )
+
+ def test_compare_fieldless(self):
+ self.assertTrue(ast.compare(ast.Add(), ast.Add()))
+ self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
+
+ def test_compare_modes(self):
+ for mode, sources in (
+ ("exec", exec_tests),
+ ("eval", eval_tests),
+ ("single", single_tests),
+ ):
+ for source in sources:
+ a = ast.parse(source, mode=mode)
+ b = ast.parse(source, mode=mode)
+ self.assertTrue(
+ ast.compare(a, b), f"{ast.dump(a)} != {ast.dump(b)}"
+ )
+
+ def test_compare_attributes_option(self):
+ def parse(a, b):
+ return ast.parse(a), ast.parse(b)
+
+ a, b = parse("2 + 2", "2+2")
+ self.assertTrue(ast.compare(a, b))
+ self.assertTrue(ast.compare(a, b, compare_attributes=False))
+ self.assertFalse(ast.compare(a, b, compare_attributes=True))
+
def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))
for node, attr, source in tests:
self.assert_none_check(node, attr, source)
+
class ASTHelpers_Test(unittest.TestCase):
maxDiff = None
@support.requires_resource('cpu')
def test_stdlib_validates(self):
- stdlib = os.path.dirname(ast.__file__)
- tests = [fn for fn in os.listdir(stdlib) if fn.endswith(".py")]
- tests.extend(["test/test_grammar.py", "test/test_unpack_ex.py"])
- for module in tests:
+ for module in STDLIB_FILES:
with self.subTest(module):
- fn = os.path.join(stdlib, module)
+ fn = os.path.join(STDLIB, module)
with open(fn, "r", encoding="utf-8") as fp:
source = fp.read()
mod = ast.parse(source, fn)
compile(mod, fn, "exec")
+ mod2 = ast.parse(source, fn)
+ self.assertTrue(ast.compare(mod, mod2))
constant_1 = ast.Constant(1)
pattern_1 = ast.MatchValue(constant_1)