might differ in whitespace or similar details.
"""
+ sentinel = object() # handle the possibility of a missing attribute/field
+
def _compare(a, b):
# Compare two fields on an AST object, which may themselves be
# AST objects, lists of AST objects, or primitive ASDL types
if a._fields != b._fields:
return False
for field in a._fields:
- a_field = getattr(a, field)
- b_field = getattr(b, field)
+ a_field = getattr(a, field, sentinel)
+ b_field = getattr(b, field, sentinel)
+ if a_field is sentinel and b_field is sentinel:
+ # both nodes are missing a field at runtime
+ continue
+ if a_field is sentinel or b_field is sentinel:
+ # one of the node is missing a field
+ return False
if not _compare(a_field, b_field):
return False
else:
return False
# Attributes are always ints.
for attr in a._attributes:
- a_attr = getattr(a, attr)
- b_attr = getattr(b, attr)
+ a_attr = getattr(a, attr, sentinel)
+ b_attr = getattr(b, attr, sentinel)
+ if a_attr is sentinel and b_attr is sentinel:
+ # both nodes are missing an attribute at runtime
+ continue
if a_attr != b_attr:
return False
else:
self.assertTrue(ast.compare(ast.Add(), ast.Add()))
self.assertFalse(ast.compare(ast.Sub(), ast.Add()))
+ # test that missing runtime fields is handled in ast.compare()
+ a1, a2 = ast.Name('a'), ast.Name('a')
+ self.assertTrue(ast.compare(a1, a2))
+ self.assertTrue(ast.compare(a1, a2))
+ del a1.id
+ self.assertFalse(ast.compare(a1, a2))
+ del a2.id
+ self.assertTrue(ast.compare(a1, a2))
+
def test_compare_modes(self):
for mode, sources in (
("exec", exec_tests),
self.assertTrue(ast.compare(a, b, compare_attributes=False))
self.assertFalse(ast.compare(a, b, compare_attributes=True))
+ def test_compare_attributes_option_missing_attribute(self):
+ # test that missing runtime attributes is handled in ast.compare()
+ a1, a2 = ast.Name('a', lineno=1), ast.Name('a', lineno=1)
+ self.assertTrue(ast.compare(a1, a2))
+ self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
+ del a1.lineno
+ self.assertFalse(ast.compare(a1, a2, compare_attributes=True))
+ del a2.lineno
+ self.assertTrue(ast.compare(a1, a2, compare_attributes=True))
+
def test_positional_only_feature_version(self):
ast.parse('def foo(x, /): ...', feature_version=(3, 8))
ast.parse('def bar(x=1, /): ...', feature_version=(3, 8))