from typing import cast
import dns.exception
+import dns.message
import dns.rdata
import dns.rdataset
import dns.rdataclass
def _rdata_sort(a):
return (a[0], a[2].rdclass, a[2].to_text())
+def add_rdataset(msg, name, rds):
+ rrset = msg.get_rrset(msg.answer, name, rds.rdclass, rds.rdtype,
+ create=True, force_unique=True)
+ for rd in rds:
+ rrset.add(rd, ttl=rds.ttl)
+
+def make_xfr(zone):
+ q = dns.message.make_query(zone.origin, 'AXFR')
+ msg = dns.message.make_response(q)
+ if zone.relativize:
+ msg.origin = zone.origin
+ soa_name = dns.name.empty
+ else:
+ soa_name = zone.origin
+ soa = zone.find_rdataset(soa_name, 'SOA')
+ add_rdataset(msg, soa_name, soa)
+ for (name, rds) in zone.iterate_rdatasets():
+ if rds.rdtype == dns.rdatatype.SOA:
+ continue
+ add_rdataset(msg, name, rds)
+ add_rdataset(msg, soa_name, soa)
+ return [msg]
+
class ZoneTestCase(unittest.TestCase):
def testFromFile1(self): # type: () -> None
def testZoneOriginNone(self): # type: () -> None
dns.zone.Zone(cast(str, None))
+ def testZoneFromXFR(self): # type: () -> None
+ z1_abs = dns.zone.from_text(example_text, 'example.', relativize=False)
+ z2_abs = dns.zone.from_xfr(make_xfr(z1_abs), relativize=False)
+ self.assertEqual(z1_abs, z2_abs)
+
+ z1_rel = dns.zone.from_text(example_text, 'example.', relativize=True)
+ z2_rel = dns.zone.from_xfr(make_xfr(z1_rel), relativize=True)
+ self.assertEqual(z1_rel, z2_rel)
+
if __name__ == '__main__':
unittest.main()