cmp_fields = (field for field in field_list if field.compare)
terms = [f'self.{field.name}==other.{field.name}' for field in cmp_fields]
field_comparisons = ' and '.join(terms) or 'True'
- body = [f'if other.__class__ is self.__class__:',
+ body = [f'if self is other:',
+ f' return True',
+ f'if other.__class__ is self.__class__:',
f' return {field_comparisons}',
f'return NotImplemented']
func = _create_fn('__eq__',
class TestEq(unittest.TestCase):
+ def test_recursive_eq(self):
+ # Test a class with recursive child
+ @dataclass
+ class C:
+ recursive: object = ...
+ c = C()
+ c.recursive = c
+ self.assertEqual(c, c)
+
def test_no_eq(self):
# Test a class with no __eq__ and eq=False.
@dataclass(eq=False)