std::shared_ptr<SOARecordContent> primarySOA = nullptr;
vector<DNSRecord> records;
size_t receivedBytes = 0;
- int8_t ixfrInProgress = -2;
std::string reply;
+ enum transferStyle { Unknown, AXFR, IXFR } style = Unknown;
+ const unsigned int expectedSOAForAXFR = 2;
+ const unsigned int expectedSOAForIXFR = 3;
+ unsigned int primarySOACount = 0;
+
for(;;) {
- // IXFR end
- if (ixfrInProgress >= 0)
+ // IXFR or AXFR style end reached? We don't want to process trailing data after the closing SOA
+ if (style == AXFR && primarySOACount == expectedSOAForAXFR) {
+ break;
+ }
+ else if (style == IXFR && primarySOACount == expectedSOAForIXFR) {
break;
+ }
if(s.read((char*)&len, sizeof(len)) != sizeof(len))
break;
return ret;
}
primarySOA = sr;
+ ++primarySOACount;
} else if (r.first.d_type == QType::SOA) {
auto sr = getRR<SOARecordContent>(r.first);
if (!sr) {
throw std::runtime_error("Error getting the content of SOA record of IXFR answer for zone '"+zone.toLogString()+"' from primary '"+primary.toStringWithPort()+"'");
}
- // we hit the last SOA record
- // IXFR is considered to be done if we hit the last SOA record twice
+ // we hit a marker SOA record
if (primarySOA->d_st.serial == sr->d_st.serial) {
- ixfrInProgress++;
+ ++primarySOACount;
+ }
+ }
+ // When we see the 2nd record, we can decide what the style is
+ if (records.size() == 1 && style == Unknown) {
+ if (r.first.d_type != QType::SOA) {
+ // Non-empty AXFR style has a non-SOA record following the first SOA
+ style = AXFR;
+ }
+ else if (primarySOACount == expectedSOAForAXFR) {
+ // Empty zone AXFR style: start SOA is immediately followed by end marker SOA
+ style = AXFR;
+ }
+ else {
+ // IXFR has a 2nd SOA (with different serial) following the first
+ style = IXFR;
}
}
}
}
- // cout<<"Got "<<records.size()<<" records"<<endl;
+ switch (style) {
+ case IXFR:
+ if (primarySOACount != expectedSOAForIXFR) {
+ throw std::runtime_error("Incomplete IXFR transfer for '" + zone.toLogString() + "' from primary '" + primary.toStringWithPort());
+ }
+ break;
+ case AXFR:
+ if (primarySOACount != expectedSOAForAXFR){
+ throw std::runtime_error("Incomplete AXFR style transfer for '" + zone.toLogString() + "' from primary '" + primary.toStringWithPort());
+ }
+ break;
+ case Unknown:
+ throw std::runtime_error("Incomplete XFR for '" + zone.toLogString() + "' from primary '" + primary.toStringWithPort());
+ break;
+ }
return processIXFRRecords(primary, zone, records, primarySOA);
}
--- /dev/null
+import dns
+import json
+import os
+import requests
+import socket
+import struct
+import sys
+import threading
+import time
+
+from authtests import AuthTest
+
+class BadXFRServer(object):
+
+ def __init__(self, port):
+ self._currentSerial = 0
+ self._targetSerial = 1
+ self._serverPort = port
+ listener = threading.Thread(name='XFR Listener', target=self._listener, args=[])
+ listener.setDaemon(True)
+ listener.start()
+
+ def getCurrentSerial(self):
+ return self._currentSerial
+
+ def moveToSerial(self, newSerial):
+ if newSerial == self._currentSerial:
+ return False
+
+ #if newSerial != self._currentSerial + 1:
+ # raise AssertionError("Asking the XFR server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
+ self._targetSerial = newSerial
+ return True
+
+ def _getAnswer(self, message):
+
+ response = dns.message.make_response(message)
+ records = []
+
+ if message.question[0].rdtype == dns.rdatatype.AXFR:
+ if self._currentSerial != 0:
+ print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+ return (None, self._currentSerial)
+
+ newSerial = self._targetSerial
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+ ]
+
+ elif message.question[0].rdtype == dns.rdatatype.IXFR:
+ oldSerial = message.authority[0][0].serial
+
+ newSerial = self._targetSerial
+ if newSerial == 2:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+ # no deletion
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ ]
+ elif newSerial == 3:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ ]
+
+ response.answer = records
+ return (newSerial, response)
+
+ def _connectionHandler(self, conn):
+ data = None
+ while True:
+ data = conn.recv(2)
+ if not data:
+ break
+ (datalen,) = struct.unpack("!H", data)
+ data = conn.recv(datalen)
+ if not data:
+ break
+
+ message = dns.message.from_wire(data)
+ if len(message.question) != 1:
+ print('Invalid query, qdcount is %d' % (len(message.question)), file=sys.stderr)
+ break
+ if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+ print('Invalid query, qtype is %d' % (message.question.rdtype), file=sys.stderr)
+ break
+ (serial, answer) = self._getAnswer(message)
+ if not answer:
+ print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr)
+ break
+
+ wire = answer.to_wire()
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ self._currentSerial = serial
+ break
+
+ conn.close()
+
+ def _listener(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ try:
+ sock.bind(("127.0.0.1", self._serverPort))
+ except socket.error as e:
+ print("Error binding in the IXFR listener: %s" % str(e))
+ sys.exit(1)
+
+ sock.listen(100)
+ while True:
+ try:
+ (conn, _) = sock.accept()
+ thread = threading.Thread(name='IXFR Connection Handler',
+ target=self._connectionHandler,
+ args=[conn])
+ thread.setDaemon(True)
+ thread.start()
+
+ except socket.error as e:
+ print('Error in IXFR socket: %s' % str(e))
+ sock.close()
+
+badxfrServerPort = 4251
+badxfrServer = BadXFRServer(badxfrServerPort)
+
+class XFRIncompleteAuthTest(AuthTest):
+ """
+ This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
+ """
+
+ global badxfrServerPort
+ _config_template = """
+launch=gsqlite3 bind
+gsqlite3-database=configs/auth/powerdns.sqlite
+gsqlite3-dnssec
+slave
+cache-ttl=0
+query-cache-ttl=0
+domain-metadata-cache-ttl=0
+negquery-cache-ttl=0
+slave-cycle-interval=1
+"""
+
+ @classmethod
+ def setUpClass(cls):
+ super(XFRIncompleteAuthTest, cls).setUpClass()
+ os.system("$PDNSUTIL --config-dir=configs/auth create-slave-zone zone.rpz. 127.0.0.1:%s" % (badxfrServerPort,))
+ os.system("$PDNSUTIL --config-dir=configs/auth set-meta zone.rpz. IXFR 1")
+
+ def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+ global badxfrServer
+
+ badxfrServer.moveToSerial(serial)
+
+ attempts = 0
+ while attempts < timeout:
+ currentSerial = badxfrServer.getCurrentSerial()
+ if currentSerial > serial:
+ raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+ if currentSerial == serial:
+ return
+
+ attempts = attempts + 1
+ time.sleep(1)
+
+ raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+ def checkZone(self):
+ query = dns.message.make_query('zone.rpz.', 'SOA')
+ res = self.sendUDPQuery(query) # , count=len(expected))
+
+ expected = [dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. 1 3600 3600 3600 1')]
+ self.assertEqual(res.answer, expected)
+
+ def doRetrieve(self):
+ os.system("$PDNSCONTROL --socket-dir=configs/auth retrieve zone.rpz.")
+
+ def testXFR(self):
+ # self.waitForTCPSocket("127.0.0.1", self._wsPort)
+ # First zone
+ self.doRetrieve()
+ self.waitUntilCorrectSerialIsLoaded(1)
+ self.checkZone()
+
+ # second zone, should fail, incomplete IXFR
+ self.doRetrieve()
+ self.waitUntilCorrectSerialIsLoaded(2)
+ self.checkZone()
+
+ # third zone, should fail, incomplete AXFR
+ self.doRetrieve()
+ self.waitUntilCorrectSerialIsLoaded(3)
+ self.checkZone()
--- /dev/null
+import dns
+import json
+import os
+import requests
+import socket
+import struct
+import sys
+import threading
+import time
+
+from recursortests import RecursorTest
+
+class BadRPZServer(object):
+
+ def __init__(self, port):
+ self._currentSerial = 0
+ self._targetSerial = 1
+ self._serverPort = port
+ listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
+ listener.setDaemon(True)
+ listener.start()
+
+ def getCurrentSerial(self):
+ return self._currentSerial
+
+ def moveToSerial(self, newSerial):
+ if newSerial == self._currentSerial:
+ return False
+
+ #if newSerial != self._currentSerial + 1:
+ # raise AssertionError("Asking the RPZ server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
+ self._targetSerial = newSerial
+ return True
+
+ def _getAnswer(self, message):
+
+ response = dns.message.make_response(message)
+ records = []
+
+ if message.question[0].rdtype == dns.rdatatype.AXFR:
+ if self._currentSerial != 0:
+ print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+ return (None, self._currentSerial)
+
+ newSerial = self._targetSerial
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+ ]
+
+ elif message.question[0].rdtype == dns.rdatatype.IXFR:
+ oldSerial = message.authority[0][0].serial
+
+ newSerial = self._targetSerial
+ if newSerial == 2:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+ # no deletion
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ ]
+ elif newSerial == 3:
+ records = [
+ dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+ dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+ ]
+
+ response.answer = records
+ return (newSerial, response)
+
+ def _connectionHandler(self, conn):
+ data = None
+ while True:
+ data = conn.recv(2)
+ if not data:
+ break
+ (datalen,) = struct.unpack("!H", data)
+ data = conn.recv(datalen)
+ if not data:
+ break
+
+ message = dns.message.from_wire(data)
+ if len(message.question) != 1:
+ print('Invalid RPZ query, qdcount is %d' % (len(message.question)), file=sys.stderr)
+ break
+ if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+ print('Invalid RPZ query, qtype is %d' % (message.question.rdtype), file=sys.stderr)
+ break
+ (serial, answer) = self._getAnswer(message)
+ if not answer:
+ print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr)
+ break
+
+ wire = answer.to_wire()
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ self._currentSerial = serial
+ break
+
+ conn.close()
+
+ def _listener(self):
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+ try:
+ sock.bind(("127.0.0.1", self._serverPort))
+ except socket.error as e:
+ print("Error binding in the RPZ listener: %s" % str(e))
+ sys.exit(1)
+
+ sock.listen(100)
+ while True:
+ try:
+ (conn, _) = sock.accept()
+ thread = threading.Thread(name='RPZ Connection Handler',
+ target=self._connectionHandler,
+ args=[conn])
+ thread.setDaemon(True)
+ thread.start()
+
+ except socket.error as e:
+ print('Error in RPZ socket: %s' % str(e))
+ sock.close()
+
+class RPZIncompleteRecursorTest(RecursorTest):
+ _wsPort = 8042
+ _wsTimeout = 2
+ _wsPassword = 'secretpassword'
+ _apiKey = 'secretapikey'
+ _confdir = 'RPZIncomplete'
+ _auth_zones = {
+ '8': {'threads': 1,
+ 'zones': ['ROOT']},
+ '10': {'threads': 1,
+ 'zones': ['example']},
+ }
+
+ _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+log-rpz-changes=yes
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+ def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount, failedXFRCount):
+ headers = {'x-api-key': self._apiKey}
+ url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
+ r = requests.get(url, headers=headers, timeout=self._wsTimeout)
+ self.assertTrue(r)
+ self.assertEqual(r.status_code, 200)
+ self.assertTrue(r.json())
+ content = r.json()
+ self.assertIn('zone.rpz.', content)
+ zone = content['zone.rpz.']
+ for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
+ self.assertIn(key, zone)
+
+ self.assertEqual(zone['serial'], serial)
+ self.assertEqual(zone['records'], recordsCount)
+ self.assertEqual(zone['transfers_full'], fullXFRCount)
+ self.assertEqual(zone['transfers_success'], totalXFRCount)
+ self.assertEqual(zone['transfers_failed'], failedXFRCount)
+
+badrpzServerPort = 4251
+badrpzServer = BadRPZServer(badrpzServerPort)
+
+class RPZXFRIncompleteRecursorTest(RPZIncompleteRecursorTest):
+ """
+ This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
+ """
+
+ global badrpzServerPort
+ _lua_config_file = """
+ -- The first server is a bogus one, to test that we correctly fail over to the second one
+ rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
+ """ % (badrpzServerPort)
+ _confdir = 'RPZXFRIncomplete'
+ _wsPort = 8042
+ _wsTimeout = 2
+ _wsPassword = 'secretpassword'
+ _apiKey = 'secretapikey'
+ _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+ @classmethod
+ def generateRecursorConfig(cls, confdir):
+ authzonepath = os.path.join(confdir, 'example.zone')
+ with open(authzonepath, 'w') as authzone:
+ authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+ super(RPZIncompleteRecursorTest, cls).generateRecursorConfig(confdir)
+
+ def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+ global badrpzServer
+
+ badrpzServer.moveToSerial(serial)
+
+ attempts = 0
+ while attempts < timeout:
+ currentSerial = badrpzServer.getCurrentSerial()
+ if currentSerial > serial:
+ raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+ if currentSerial == serial:
+ return
+
+ attempts = attempts + 1
+ time.sleep(1)
+
+ raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+ def testRPZ(self):
+ self.waitForTCPSocket("127.0.0.1", self._wsPort)
+ # First zone
+ self.waitUntilCorrectSerialIsLoaded(1)
+ self.checkRPZStats(1, 1, 1, 1, 1) # failure count includes a port 9999 attempt
+
+ # second zone, should fail, incomplete IXFR
+ self.waitUntilCorrectSerialIsLoaded(2)
+ self.checkRPZStats(1, 1, 1, 1, 3)
+
+ # third zone, should fail, incomplete AXFR
+ self.waitUntilCorrectSerialIsLoaded(3)
+ self.checkRPZStats(1, 1, 1, 1, 5)
+