]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
Merge pull request #12127 from hlindqvist/ixfrdist-ixfr-multiple-changes
authorPeter van Dijk <peter.van.dijk@powerdns.com>
Mon, 19 Dec 2022 15:07:10 +0000 (16:07 +0100)
committerGitHub <noreply@github.com>
Mon, 19 Dec 2022 15:07:10 +0000 (16:07 +0100)
Fix multiple-version IXFR request handling in ixfrdist

pdns/ixfrdist.cc
regression-tests.ixfrdist/test_IXFR.py

index aea08ee066fbab3867b39866e814d314587aa5cc..24ac0e6a3a60eccb2d161e547a919de55e604429 100644 (file)
@@ -717,24 +717,34 @@ static bool handleIXFR(int fd, const ComboAddress& destination, const MOADNSPars
     return handleAXFR(fd, mdp);
   }
 
-  std::vector<std::vector<uint8_t>> packets;
-  for (const auto& diff : toSend) {
-    /* An IXFR packet's ANSWER section looks as follows:
-     * SOA new_serial
-     * SOA old_serial
-     * ... removed records ...
-     * SOA new_serial
-     * ... added records ...
-     * SOA new_serial
-     */
 
+  /* An IXFR packet's ANSWER section looks as follows:
+    * SOA latest_serial C
+
+    First set of changes:
+    * SOA requested_serial A
+    * ... removed records ...
+    * SOA intermediate_serial B
+    * ... added records ...
+
+    Next set of changes:
+    * SOA intermediate_serial B
+    * ... removed records ...
+    * SOA latest_serial C
+    * ... added records ...
+
+    * SOA latest_serial C
+    */
+
+  const auto latestSOAPacket = getSOAPacket(mdp, zoneInfo->soa, zoneInfo->soaTTL);
+  if (!sendPacketOverTCP(fd, latestSOAPacket)) {
+    return false;
+  }
+
+  for (const auto& diff : toSend) {
     const auto newSOAPacket = getSOAPacket(mdp, diff->newSOA, diff->newSOATTL);
     const auto oldSOAPacket = getSOAPacket(mdp, diff->oldSOA, diff->oldSOATTL);
 
-    if (!sendPacketOverTCP(fd, newSOAPacket)) {
-      return false;
-    }
-
     if (!sendPacketOverTCP(fd, oldSOAPacket)) {
       return false;
     }
@@ -750,10 +760,10 @@ static bool handleIXFR(int fd, const ComboAddress& destination, const MOADNSPars
     if (!sendRecordsOverTCP(fd, mdp, diff->additions)) {
       return false;
     }
+  }
 
-    if (!sendPacketOverTCP(fd, newSOAPacket)) {
-      return false;
-    }
+  if (!sendPacketOverTCP(fd, latestSOAPacket)) {
+    return false;
   }
 
   return true;
index 0d0f0b16e68b018a58b4ee58accaf33f97f6456c..35f0a30c8ecac6bb3d881807f3027c0a9de7cfb2 100644 (file)
@@ -1,5 +1,7 @@
 import dns
+import dns.serial
 import time
+import itertools
 
 from ixfrdisttests import IXFRDistTest
 from xfrserver.xfrserver import AXFRServer
@@ -21,6 +23,25 @@ $ORIGIN example.
 ns1.example.    4242    A       192.0.2.1
 ns2.example.    4242    A       192.0.2.2
 newrecord.example.        8484    A       192.0.2.42
+""",
+    3: """
+$ORIGIN example.
+@        86400   SOA    foo bar 3 2 3 4 5
+@        4242    NS     ns1.example.
+@        4242    NS     ns2.example.
+ns1.example.    4242    A       192.0.2.1
+ns2.example.    4242    A       192.0.2.2
+newrecord2.example.        8484    A       192.0.2.42
+""",
+    4: """
+$ORIGIN example.
+@        86400   SOA    foo bar 4 2 3 4 5
+@        4242    NS     ns1.example.
+@        4242    NS     ns2.example.
+ns1.example.    4242    A       192.0.2.1
+ns2.example.    4242    A       192.0.2.2
+newrecord2.example.        8484    A       192.0.2.42
+other.example.  1234    TXT     "foo"
 """
 }
 
@@ -39,6 +60,7 @@ class IXFRDistBasicTest(IXFRDistTest):
                         'example2': '127.0.0.1:1',                      # bogus port is intentional - zone is intentionally unloadable
                         # example3                                      # intentionally absent for 'unconfigured zone' testing
                         'example4': '127.0.0.1:' + str(xfrServerPort) } # for testing how ixfrdist deals with getting the wrong zone on XFR
+    _loaded_serials = []
 
     @classmethod
     def setUpClass(cls):
@@ -55,15 +77,26 @@ class IXFRDistBasicTest(IXFRDistTest):
 
         xfrServer.moveToSerial(serial)
 
+        def get_current_serial():
+            query = dns.message.make_query('example.', 'SOA')
+            response_message = self.sendUDPQuery(query)
+
+            if response_message.rcode() == dns.rcode.REFUSED:
+                return 0
+
+            soa_rrset = response_message.find_rrset(dns.message.ANSWER, dns.name.from_text("example."), dns.rdataclass.IN, dns.rdatatype.SOA)
+            return soa_rrset[0].serial
+
         attempts = 0
         while attempts < timeout:
             print('attempts=%s timeout=%s' % (attempts, timeout))
-            servedSerial = xfrServer.getServedSerial()
+            servedSerial = get_current_serial()
             print('servedSerial=%s' % servedSerial)
             if servedSerial > serial:
                 raise AssertionError("Expected serial %d, got %d" % (serial, servedSerial))
             if servedSerial == serial:
                 self._xfrDone = self._xfrDone + 1
+                self._loaded_serials.append(serial)
                 return
 
             attempts = attempts + 1
@@ -93,14 +126,46 @@ class IXFRDistBasicTest(IXFRDistTest):
     def checkIXFR(self, fromserial, toserial):
         global zones, xfrServer
 
-        ixfr = []
-        soa1 = xfrServer._getSOAForSerial(fromserial)
-        soa2 = xfrServer._getSOAForSerial(toserial)
-        newrecord = [r for r in xfrServer._getRecordsForSerial(toserial) if r.name==dns.name.from_text('newrecord.example.')]
+        soa_requested = xfrServer._getSOAForSerial(fromserial)
+        soa_latest = xfrServer._getSOAForSerial(self._loaded_serials[-1])
+
+        self.assertEqual(soa_latest[0].serial, toserial)
+
         query = dns.message.make_query('example.', 'IXFR')
-        query.authority = [soa1]
+        query.authority = [soa_requested]
+
+        expected = []
+        expected.append([soa_latest]) #latest SOA
+
+        def pairwise(iterable): # itertools.pairwise exists in 3.10, but until then...
+            # pairwise('ABCDEFG') --> AB BC CD DE EF FG
+            a, b = itertools.tee(iterable)
+            next(b, None)
+            return zip(a, b)
+
+        found_starting_version = False
+        for serial_pair in pairwise(self._loaded_serials):
+            if dns.serial.Serial(serial_pair[0]) < dns.serial.Serial(fromserial):
+                continue
+
+            if serial_pair[0] == fromserial:
+                found_starting_version = True
+
+            old_records = [r for r in xfrServer._getRecordsForSerial(serial_pair[0]) if r.rdtype != dns.rdatatype.SOA]
+            new_records = [r for r in xfrServer._getRecordsForSerial(serial_pair[1]) if r.rdtype != dns.rdatatype.SOA]
+            added = [r for r in new_records if r not in old_records]
+            removed = [r for r in old_records if r not in new_records]
+
+            expected.append([xfrServer._getSOAForSerial(serial_pair[0])]) # old SOA
+            if removed: expected.append(removed) # removed records from old SOA (sendTCPQueryMultiResponse skips if empty)
+            expected.append([xfrServer._getSOAForSerial(serial_pair[1])]) # new SOA
+            if added: expected.append(added) # added records in new SOA (sendTCPQueryMultiResponse skips if empty)
+
+        expected.append([soa_latest]) # latest SOA
+
+        if not found_starting_version:
+            raise AssertionError("Did not find zone version with requested serial {fromserial}, impossible to IXFR scenario?")
 
-        expected = [[soa2], [soa1], [soa2], newrecord, [soa2]]
         res = self.sendTCPQueryMultiResponse(query, count=len(expected)+1) # +1 for trailing data check
         answers = [r.answer for r in res]
 
@@ -115,6 +180,7 @@ class IXFRDistBasicTest(IXFRDistTest):
                 pos = pos + 1
             answerPos = answerPos + 1
 
+
     def test_a_XFR(self):
         self.waitUntilCorrectSerialIsLoaded(1)
         self.checkFullZone(1)
@@ -153,3 +219,15 @@ class IXFRDistBasicTest(IXFRDistTest):
 
         response = self.sendUDPQuery(query)
         self.assertEqual(expected, response)
+
+    def test_c_IXFR_multi(self):
+        self.waitUntilCorrectSerialIsLoaded(3)
+        self.checkFullZone(3)
+        self.checkIXFR(2,3)
+        self.checkIXFR(1,3)
+
+        self.waitUntilCorrectSerialIsLoaded(4)
+        self.checkFullZone(4)
+        self.checkIXFR(3,4)
+        self.checkIXFR(2,4)
+        self.checkIXFR(1,4)