From 7d4d4715b5ebbe01cc4e4a0da662116263d747b1 Mon Sep 17 00:00:00 2001 From: kimbo Date: Tue, 16 Feb 2021 15:21:13 -0700 Subject: [PATCH] make `name in zone` consistent with `zone[name]` specifically, allow name to be a str, and raise a KeyError if name cannot be converted into a dns.name.Name --- dns/zone.py | 5 +++-- tests/test_zone.py | 10 ++++++++++ 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/dns/zone.py b/dns/zone.py index c9c1c201..ac957638 100644 --- a/dns/zone.py +++ b/dns/zone.py @@ -162,8 +162,9 @@ class Zone(dns.transaction.TransactionManager): key = self._validate_name(key) return self.nodes.get(key) - def __contains__(self, other): - return other in self.nodes + def __contains__(self, key): + key = self._validate_name(key) + return key in self.nodes def find_node(self, name, create=False): """Find a node in the zone, possibly creating it. diff --git a/tests/test_zone.py b/tests/test_zone.py index 66f3ad5a..26adc878 100644 --- a/tests/test_zone.py +++ b/tests/test_zone.py @@ -872,5 +872,15 @@ class VersionedZoneTestCase(unittest.TestCase): rds = txn.get('example.', 'soa') self.assertEqual(rds[0].serial, 1) + def testNameInZoneWithStr(self): + z = dns.zone.from_text(example_text, 'example.', relativize=False) + self.assertTrue('ns1.example.' in z) + self.assertTrue('bar.foo.example.' in z) + + def testNameInZoneWhereNameIsNotValid(self): + z = dns.zone.from_text(example_text, 'example.', relativize=False) + with self.assertRaises(KeyError): + self.assertTrue(1 in z) + if __name__ == '__main__': unittest.main() -- 2.47.3