def __iter__(self):
return iter(self.rdatasets)
+ def _append_rdataset(self, rdataset):
+ """Append rdataset to the node with special handling for CNAME and
+ other data conditions.
+
+ Specifically, if the rdataset being appended is a CNAME, then
+ all rdatasets other than NSEC, NSEC3, and their covering RRSIGs
+ are deleted. If the rdataset being appended is NOT a CNAME, then
+ CNAME and RRSIG(CNAME) are deleted.
+ """
+ # Make having just one rdataset at the node fast.
+ if len(self.rdatasets) > 0:
+ if rdataset.rdtype == dns.rdatatype.CNAME:
+ self.rdatasets = [rds for rds in self.rdatasets
+ if rds.ok_for_cname()]
+ else:
+ self.rdatasets = [rds for rds in self.rdatasets
+ if rds.ok_for_other_data()]
+ self.rdatasets.append(rdataset)
+
+
def find_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
create=False):
"""Find an rdataset matching the specified properties in the
if not create:
raise KeyError
rds = dns.rdataset.Rdataset(rdclass, rdtype, covers)
- self.rdatasets.append(rds)
+ self._append_rdataset(rds)
return rds
def get_rdataset(self, rdclass, rdtype, covers=dns.rdatatype.NONE,
replacement = replacement.to_rdataset()
self.delete_rdataset(replacement.rdclass, replacement.rdtype,
replacement.covers)
- self.rdatasets.append(replacement)
+ self._append_rdataset(replacement)
"""An attempt was made to add DNS RR data of an incompatible type."""
+_ok_for_cname = {
+ (dns.rdatatype.CNAME, dns.rdatatype.NONE),
+ (dns.rdatatype.RRSIG, dns.rdatatype.CNAME),
+ (dns.rdatatype.NSEC, dns.rdatatype.NONE),
+ (dns.rdatatype.RRSIG, dns.rdatatype.NSEC),
+ (dns.rdatatype.NSEC3, dns.rdatatype.NONE),
+ (dns.rdatatype.RRSIG, dns.rdatatype.NSEC3),
+}
+
+_delete_for_other_data = {
+ (dns.rdatatype.CNAME, dns.rdatatype.NONE),
+ (dns.rdatatype.RRSIG, dns.rdatatype.CNAME),
+}
+
class Rdataset(dns.set.Set):
"""A DNS rdataset."""
else:
return self[0]._processing_order(iter(self))
+ def ok_for_cname(self):
+ """Is this rdataset compatible with a CNAME node?"""
+ return (self.rdtype, self.covers) in _ok_for_cname
+
+ def ok_for_other_data(self):
+ """Is this rdataset compatible with an 'other data' (i.e. not CNAME)
+ node?"""
+ return (self.rdtype, self.covers) not in _delete_for_other_data
+
@dns.immutable.immutable
class ImmutableRdataset(Rdataset):
ns2 3600 IN A 10.0.0.2 ; comment2
"""
+
+example_cname = """$TTL 3600
+$ORIGIN example.
+@ soa foo bar (1 2 3 4 5)
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+www a 10.0.0.3
+web cname www
+ nsec @ CNAME RRSIG
+ rrsig NSEC 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+ rrsig CNAME 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+web2 cname www
+ nsec3 1 1 12 aabbccdd 2t7b4g4vsa5smi47k61mv5bv1a22bojr CNAME RRSIG
+ rrsig NSEC3 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+ rrsig CNAME 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+"""
+
+
+example_other_data = """$TTL 3600
+$ORIGIN example.
+@ soa foo bar (1 2 3 4 5)
+@ ns ns1
+@ ns ns2
+ns1 a 10.0.0.1
+ns2 a 10.0.0.2
+www a 10.0.0.3
+web a 10.0.0.4
+ nsec @ A RRSIG
+ rrsig A 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+ rrsig NSEC 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+ rrsig CNAME 1 3 3600 20200101000000 20030101000000 2143 foo MxFcby9k/yvedMfQgKzhH5er0Mu/vILz 45IkskceFGgiWCn/GxHhai6VAuHAoNUz 4YoU1tVfSCSqQYn6//11U6Nld80jEeC8 aTrO+KKmCaY=
+"""
+
+
_keep_output = True
def _rdata_sort(a):
self.assertTrue(rds is not rrs)
self.assertFalse(isinstance(rds, dns.rrset.RRset))
+ def testCnameAndOtherDataAddOther(self):
+ z = dns.zone.from_text(example_cname, 'example.', relativize=True)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1')
+ z.replace_rdataset('web', rds)
+ z.replace_rdataset('web2', rds.copy())
+ n = z.find_node('web')
+ self.assertEqual(len(n.rdatasets), 3)
+ self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A),
+ rds)
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.NSEC))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.NSEC))
+ n = z.find_node('web2')
+ self.assertEqual(len(n.rdatasets), 3)
+ self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A),
+ rds)
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.NSEC3))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.NSEC3))
+
+ def testCnameAndOtherDataAddCname(self):
+ z = dns.zone.from_text(example_other_data, 'example.', relativize=True)
+ rds = dns.rdataset.from_text('in', 'cname', 300, 'www')
+ z.replace_rdataset('web', rds)
+ n = z.find_node('web')
+ self.assertEqual(len(n.rdatasets), 4)
+ self.assertEqual(n.find_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.CNAME),
+ rds)
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.NSEC))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.NSEC))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.CNAME))
+
+ def testNameInZoneWithStr(self):
+ z = dns.zone.from_text(example_text, 'example.', relativize=False)
+ self.assertTrue('ns1.example.' in z)
+ self.assertTrue('bar.foo.example.' in z)
+
+ def testNameInZoneWhereNameIsNotValid(self):
+ z = dns.zone.from_text(example_text, 'example.', relativize=False)
+ with self.assertRaises(KeyError):
+ self.assertTrue(1 in z)
+
class VersionedZoneTestCase(unittest.TestCase):
def testUseTransaction(self):
rds = txn.get('example.', 'soa')
self.assertEqual(rds[0].serial, 1)
- def testNameInZoneWithStr(self):
- z = dns.zone.from_text(example_text, 'example.', relativize=False)
- self.assertTrue('ns1.example.' in z)
- self.assertTrue('bar.foo.example.' in z)
+ def testCnameAndOtherDataAddOther(self):
+ z = dns.zone.from_text(example_cname, 'example.', relativize=True,
+ zone_factory=dns.versioned.Zone)
+ rds = dns.rdataset.from_text('in', 'a', 300, '10.0.0.1')
+ with z.writer() as txn:
+ txn.replace('web', rds)
+ txn.replace('web2', rds.copy())
+ n = z.find_node('web')
+ self.assertEqual(len(n.rdatasets), 3)
+ self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A),
+ rds)
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.NSEC))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.NSEC))
+ n = z.find_node('web2')
+ self.assertEqual(len(n.rdatasets), 3)
+ self.assertEqual(n.find_rdataset(dns.rdataclass.IN, dns.rdatatype.A),
+ rds)
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.NSEC3))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.NSEC3))
+
+ def testCnameAndOtherDataAddCname(self):
+ z = dns.zone.from_text(example_other_data, 'example.', relativize=True,
+ zone_factory=dns.versioned.Zone)
+ rds = dns.rdataset.from_text('in', 'cname', 300, 'www')
+ with z.writer() as txn:
+ txn.replace('web', rds)
+ n = z.find_node('web')
+ self.assertEqual(len(n.rdatasets), 4)
+ self.assertEqual(n.find_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.CNAME),
+ rds)
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.NSEC))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.NSEC))
+ self.assertIsNotNone(n.get_rdataset(dns.rdataclass.IN,
+ dns.rdatatype.RRSIG,
+ dns.rdatatype.CNAME))
- def testNameInZoneWhereNameIsNotValid(self):
- z = dns.zone.from_text(example_text, 'example.', relativize=False)
- with self.assertRaises(KeyError):
- self.assertTrue(1 in z)
if __name__ == '__main__':
unittest.main()