return handleAXFR(fd, mdp);
}
- std::vector<std::vector<uint8_t>> packets;
- for (const auto& diff : toSend) {
- /* An IXFR packet's ANSWER section looks as follows:
- * SOA new_serial
- * SOA old_serial
- * ... removed records ...
- * SOA new_serial
- * ... added records ...
- * SOA new_serial
- */
+ /* An IXFR packet's ANSWER section looks as follows:
+ * SOA latest_serial C
+
+ First set of changes:
+ * SOA requested_serial A
+ * ... removed records ...
+ * SOA intermediate_serial B
+ * ... added records ...
+
+ Next set of changes:
+ * SOA intermediate_serial B
+ * ... removed records ...
+ * SOA latest_serial C
+ * ... added records ...
+
+ * SOA latest_serial C
+ */
+
+ const auto latestSOAPacket = getSOAPacket(mdp, zoneInfo->soa, zoneInfo->soaTTL);
+ if (!sendPacketOverTCP(fd, latestSOAPacket)) {
+ return false;
+ }
+
+ for (const auto& diff : toSend) {
const auto newSOAPacket = getSOAPacket(mdp, diff->newSOA, diff->newSOATTL);
const auto oldSOAPacket = getSOAPacket(mdp, diff->oldSOA, diff->oldSOATTL);
- if (!sendPacketOverTCP(fd, newSOAPacket)) {
- return false;
- }
-
if (!sendPacketOverTCP(fd, oldSOAPacket)) {
return false;
}
if (!sendRecordsOverTCP(fd, mdp, diff->additions)) {
return false;
}
+ }
- if (!sendPacketOverTCP(fd, newSOAPacket)) {
- return false;
- }
+ if (!sendPacketOverTCP(fd, latestSOAPacket)) {
+ return false;
}
return true;
import dns
+import dns.serial
import time
+import itertools
from ixfrdisttests import IXFRDistTest
from xfrserver.xfrserver import AXFRServer
ns1.example. 4242 A 192.0.2.1
ns2.example. 4242 A 192.0.2.2
newrecord.example. 8484 A 192.0.2.42
+""",
+ 3: """
+$ORIGIN example.
+@ 86400 SOA foo bar 3 2 3 4 5
+@ 4242 NS ns1.example.
+@ 4242 NS ns2.example.
+ns1.example. 4242 A 192.0.2.1
+ns2.example. 4242 A 192.0.2.2
+newrecord2.example. 8484 A 192.0.2.42
+""",
+ 4: """
+$ORIGIN example.
+@ 86400 SOA foo bar 4 2 3 4 5
+@ 4242 NS ns1.example.
+@ 4242 NS ns2.example.
+ns1.example. 4242 A 192.0.2.1
+ns2.example. 4242 A 192.0.2.2
+newrecord2.example. 8484 A 192.0.2.42
+other.example. 1234 TXT "foo"
"""
}
'example2': '127.0.0.1:1', # bogus port is intentional - zone is intentionally unloadable
# example3 # intentionally absent for 'unconfigured zone' testing
'example4': '127.0.0.1:' + str(xfrServerPort) } # for testing how ixfrdist deals with getting the wrong zone on XFR
+ _loaded_serials = []
@classmethod
def setUpClass(cls):
xfrServer.moveToSerial(serial)
+ def get_current_serial():
+ query = dns.message.make_query('example.', 'SOA')
+ response_message = self.sendUDPQuery(query)
+
+ if response_message.rcode() == dns.rcode.REFUSED:
+ return 0
+
+ soa_rrset = response_message.find_rrset(dns.message.ANSWER, dns.name.from_text("example."), dns.rdataclass.IN, dns.rdatatype.SOA)
+ return soa_rrset[0].serial
+
attempts = 0
while attempts < timeout:
print('attempts=%s timeout=%s' % (attempts, timeout))
- servedSerial = xfrServer.getServedSerial()
+ servedSerial = get_current_serial()
print('servedSerial=%s' % servedSerial)
if servedSerial > serial:
raise AssertionError("Expected serial %d, got %d" % (serial, servedSerial))
if servedSerial == serial:
self._xfrDone = self._xfrDone + 1
+ self._loaded_serials.append(serial)
return
attempts = attempts + 1
def checkIXFR(self, fromserial, toserial):
global zones, xfrServer
- ixfr = []
- soa1 = xfrServer._getSOAForSerial(fromserial)
- soa2 = xfrServer._getSOAForSerial(toserial)
- newrecord = [r for r in xfrServer._getRecordsForSerial(toserial) if r.name==dns.name.from_text('newrecord.example.')]
+ soa_requested = xfrServer._getSOAForSerial(fromserial)
+ soa_latest = xfrServer._getSOAForSerial(self._loaded_serials[-1])
+
+ self.assertEqual(soa_latest[0].serial, toserial)
+
query = dns.message.make_query('example.', 'IXFR')
- query.authority = [soa1]
+ query.authority = [soa_requested]
+
+ expected = []
+ expected.append([soa_latest]) #latest SOA
+
+ def pairwise(iterable): # itertools.pairwise exists in 3.10, but until then...
+ # pairwise('ABCDEFG') --> AB BC CD DE EF FG
+ a, b = itertools.tee(iterable)
+ next(b, None)
+ return zip(a, b)
+
+ found_starting_version = False
+ for serial_pair in pairwise(self._loaded_serials):
+ if dns.serial.Serial(serial_pair[0]) < dns.serial.Serial(fromserial):
+ continue
+
+ if serial_pair[0] == fromserial:
+ found_starting_version = True
+
+ old_records = [r for r in xfrServer._getRecordsForSerial(serial_pair[0]) if r.rdtype != dns.rdatatype.SOA]
+ new_records = [r for r in xfrServer._getRecordsForSerial(serial_pair[1]) if r.rdtype != dns.rdatatype.SOA]
+ added = [r for r in new_records if r not in old_records]
+ removed = [r for r in old_records if r not in new_records]
+
+ expected.append([xfrServer._getSOAForSerial(serial_pair[0])]) # old SOA
+ if removed: expected.append(removed) # removed records from old SOA (sendTCPQueryMultiResponse skips if empty)
+ expected.append([xfrServer._getSOAForSerial(serial_pair[1])]) # new SOA
+ if added: expected.append(added) # added records in new SOA (sendTCPQueryMultiResponse skips if empty)
+
+ expected.append([soa_latest]) # latest SOA
+
+ if not found_starting_version:
+ raise AssertionError("Did not find zone version with requested serial {fromserial}, impossible to IXFR scenario?")
- expected = [[soa2], [soa1], [soa2], newrecord, [soa2]]
res = self.sendTCPQueryMultiResponse(query, count=len(expected)+1) # +1 for trailing data check
answers = [r.answer for r in res]
pos = pos + 1
answerPos = answerPos + 1
+
def test_a_XFR(self):
self.waitUntilCorrectSerialIsLoaded(1)
self.checkFullZone(1)
response = self.sendUDPQuery(query)
self.assertEqual(expected, response)
+
+ def test_c_IXFR_multi(self):
+ self.waitUntilCorrectSerialIsLoaded(3)
+ self.checkFullZone(3)
+ self.checkIXFR(2,3)
+ self.checkIXFR(1,3)
+
+ self.waitUntilCorrectSerialIsLoaded(4)
+ self.checkFullZone(4)
+ self.checkIXFR(3,4)
+ self.checkIXFR(2,4)
+ self.checkIXFR(1,4)