]> git.ipfire.org Git - thirdparty/dnspython.git/commitdiff
Remove choose_relativity() from zone.from_xfr() 428/head
authorBrian Wellington <bwelling@xbill.org>
Wed, 18 Mar 2020 21:56:58 +0000 (14:56 -0700)
committerBrian Wellington <bwelling@xbill.org>
Wed, 18 Mar 2020 21:57:29 +0000 (14:57 -0700)
The comment states that relativize must be consistent between
dns.query.xfr() and dns.zone.from_xfr(), and the code fails if they're
not (if check_origin is True, at least).  This means that the rdata is
already correctly relativized (or not).

This also adds a test of creating zones from xfrs, both relativized and
not.

dns/zone.py
tests/test_zone.py

index 542ec7eaa638c0f8a76a763e0997f85e55aca4c8..0da4d89f4d325a4752ff29ae9a62db2e4e0ce18b 100644 (file)
@@ -1117,7 +1117,6 @@ def from_xfr(xfr, zone_factory=Zone, relativize=True, check_origin=True):
                                        rrset.covers, True)
             zrds.update_ttl(rrset.ttl)
             for rd in rrset:
-                rd.choose_relativity(z.origin, relativize)
                 zrds.add(rd)
     if check_origin:
         z.check_origin()
index 0fae7047be2f72488dcc08511e5db9e4944d6ef1..7e4a4d8fe92f607e4674d7cb5a30e825022f1d7d 100644 (file)
@@ -22,6 +22,7 @@ import unittest
 from typing import cast
 
 import dns.exception
+import dns.message
 import dns.rdata
 import dns.rdataset
 import dns.rdataclass
@@ -135,6 +136,29 @@ _keep_output = True
 def _rdata_sort(a):
     return (a[0], a[2].rdclass, a[2].to_text())
 
+def add_rdataset(msg, name, rds):
+    rrset = msg.get_rrset(msg.answer, name, rds.rdclass, rds.rdtype,
+                          create=True, force_unique=True)
+    for rd in rds:
+        rrset.add(rd, ttl=rds.ttl)
+
+def make_xfr(zone):
+    q = dns.message.make_query(zone.origin, 'AXFR')
+    msg = dns.message.make_response(q)
+    if zone.relativize:
+        msg.origin = zone.origin
+        soa_name = dns.name.empty
+    else:
+        soa_name = zone.origin
+    soa = zone.find_rdataset(soa_name, 'SOA')
+    add_rdataset(msg, soa_name, soa)
+    for (name, rds) in zone.iterate_rdatasets():
+        if rds.rdtype == dns.rdatatype.SOA:
+            continue
+        add_rdataset(msg, name, rds)
+    add_rdataset(msg, soa_name, soa)
+    return [msg]
+
 class ZoneTestCase(unittest.TestCase):
 
     def testFromFile1(self): # type: () -> None
@@ -544,5 +568,14 @@ class ZoneTestCase(unittest.TestCase):
     def testZoneOriginNone(self): # type: () -> None
         dns.zone.Zone(cast(str, None))
 
+    def testZoneFromXFR(self): # type: () -> None
+        z1_abs = dns.zone.from_text(example_text, 'example.', relativize=False)
+        z2_abs = dns.zone.from_xfr(make_xfr(z1_abs), relativize=False)
+        self.assertEqual(z1_abs, z2_abs)
+
+        z1_rel = dns.zone.from_text(example_text, 'example.', relativize=True)
+        z2_rel = dns.zone.from_xfr(make_xfr(z1_rel), relativize=True)
+        self.assertEqual(z1_rel, z2_rel)
+
 if __name__ == '__main__':
     unittest.main()