From 8bed4b385814c10c927e970d5e3118385c2fa1b5 Mon Sep 17 00:00:00 2001 From: Otto Moerbeek Date: Tue, 8 Mar 2022 15:36:48 +0100 Subject: [PATCH] auth, rec IXFR-in: Fix a case where an incomplete read caused by network error might result in a truncated zone. As we might break from the loop early, we need to check if the end SOA was seen after the loop. Also make sure we detect end conditions for both AXFR and IXFR style properly, to avoid processing data after the end marker. --- pdns/ixfr.cc | 51 +++- regression-tests.auth-py/runtests | 1 + .../test_XFRIncomplete.py | 197 ++++++++++++++ .../test_RPZIncomplete.py | 241 ++++++++++++++++++ 4 files changed, 483 insertions(+), 7 deletions(-) create mode 100644 regression-tests.auth-py/test_XFRIncomplete.py create mode 100644 regression-tests.recursor-dnssec/test_RPZIncomplete.py diff --git a/pdns/ixfr.cc b/pdns/ixfr.cc index d299568488..1154eb0136 100644 --- a/pdns/ixfr.cc +++ b/pdns/ixfr.cc @@ -174,13 +174,21 @@ vector, vector > > getIXFRDeltas(const ComboAd std::shared_ptr primarySOA = nullptr; vector records; size_t receivedBytes = 0; - int8_t ixfrInProgress = -2; std::string reply; + enum transferStyle { Unknown, AXFR, IXFR } style = Unknown; + const unsigned int expectedSOAForAXFR = 2; + const unsigned int expectedSOAForIXFR = 3; + unsigned int primarySOACount = 0; + for(;;) { - // IXFR end - if (ixfrInProgress >= 0) + // IXFR or AXFR style end reached? We don't want to process trailing data after the closing SOA + if (style == AXFR && primarySOACount == expectedSOAForAXFR) { + break; + } + else if (style == IXFR && primarySOACount == expectedSOAForIXFR) { break; + } if(s.read((char*)&len, sizeof(len)) != sizeof(len)) break; @@ -225,16 +233,31 @@ vector, vector > > getIXFRDeltas(const ComboAd return ret; } primarySOA = sr; + ++primarySOACount; } else if (r.first.d_type == QType::SOA) { auto sr = getRR(r.first); if (!sr) { throw std::runtime_error("Error getting the content of SOA record of IXFR answer for zone '"+zone.toLogString()+"' from primary '"+primary.toStringWithPort()+"'"); } - // we hit the last SOA record - // IXFR is considered to be done if we hit the last SOA record twice + // we hit a marker SOA record if (primarySOA->d_st.serial == sr->d_st.serial) { - ixfrInProgress++; + ++primarySOACount; + } + } + // When we see the 2nd record, we can decide what the style is + if (records.size() == 1 && style == Unknown) { + if (r.first.d_type != QType::SOA) { + // Non-empty AXFR style has a non-SOA record following the first SOA + style = AXFR; + } + else if (primarySOACount == expectedSOAForAXFR) { + // Empty zone AXFR style: start SOA is immediately followed by end marker SOA + style = AXFR; + } + else { + // IXFR has a 2nd SOA (with different serial) following the first + style = IXFR; } } @@ -253,7 +276,21 @@ vector, vector > > getIXFRDeltas(const ComboAd } } - // cout<<"Got "< serial: + raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial)) + if currentSerial == serial: + return + + attempts = attempts + 1 + time.sleep(1) + + raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial)) + + def checkZone(self): + query = dns.message.make_query('zone.rpz.', 'SOA') + res = self.sendUDPQuery(query) # , count=len(expected)) + + expected = [dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. 1 3600 3600 3600 1')] + self.assertEqual(res.answer, expected) + + def doRetrieve(self): + os.system("$PDNSCONTROL --socket-dir=configs/auth retrieve zone.rpz.") + + def testXFR(self): + # self.waitForTCPSocket("127.0.0.1", self._wsPort) + # First zone + self.doRetrieve() + self.waitUntilCorrectSerialIsLoaded(1) + self.checkZone() + + # second zone, should fail, incomplete IXFR + self.doRetrieve() + self.waitUntilCorrectSerialIsLoaded(2) + self.checkZone() + + # third zone, should fail, incomplete AXFR + self.doRetrieve() + self.waitUntilCorrectSerialIsLoaded(3) + self.checkZone() diff --git a/regression-tests.recursor-dnssec/test_RPZIncomplete.py b/regression-tests.recursor-dnssec/test_RPZIncomplete.py new file mode 100644 index 0000000000..b5a0e8a985 --- /dev/null +++ b/regression-tests.recursor-dnssec/test_RPZIncomplete.py @@ -0,0 +1,241 @@ +import dns +import json +import os +import requests +import socket +import struct +import sys +import threading +import time + +from recursortests import RecursorTest + +class BadRPZServer(object): + + def __init__(self, port): + self._currentSerial = 0 + self._targetSerial = 1 + self._serverPort = port + listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[]) + listener.setDaemon(True) + listener.start() + + def getCurrentSerial(self): + return self._currentSerial + + def moveToSerial(self, newSerial): + if newSerial == self._currentSerial: + return False + + #if newSerial != self._currentSerial + 1: + # raise AssertionError("Asking the RPZ server to serve serial %d, already serving %d" % (newSerial, self._currentSerial)) + self._targetSerial = newSerial + return True + + def _getAnswer(self, message): + + response = dns.message.make_response(message) + records = [] + + if message.question[0].rdtype == dns.rdatatype.AXFR: + if self._currentSerial != 0: + print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial)) + return (None, self._currentSerial) + + newSerial = self._targetSerial + records = [ + dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial), + dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'), + dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial) + ] + + elif message.question[0].rdtype == dns.rdatatype.IXFR: + oldSerial = message.authority[0][0].serial + + newSerial = self._targetSerial + if newSerial == 2: + records = [ + dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial), + dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial), + # no deletion + dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial), + dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'), + ] + elif newSerial == 3: + records = [ + dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial), + dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'), + ] + + response.answer = records + return (newSerial, response) + + def _connectionHandler(self, conn): + data = None + while True: + data = conn.recv(2) + if not data: + break + (datalen,) = struct.unpack("!H", data) + data = conn.recv(datalen) + if not data: + break + + message = dns.message.from_wire(data) + if len(message.question) != 1: + print('Invalid RPZ query, qdcount is %d' % (len(message.question)), file=sys.stderr) + break + if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]: + print('Invalid RPZ query, qtype is %d' % (message.question.rdtype), file=sys.stderr) + break + (serial, answer) = self._getAnswer(message) + if not answer: + print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr) + break + + wire = answer.to_wire() + conn.send(struct.pack("!H", len(wire))) + conn.send(wire) + self._currentSerial = serial + break + + conn.close() + + def _listener(self): + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + try: + sock.bind(("127.0.0.1", self._serverPort)) + except socket.error as e: + print("Error binding in the RPZ listener: %s" % str(e)) + sys.exit(1) + + sock.listen(100) + while True: + try: + (conn, _) = sock.accept() + thread = threading.Thread(name='RPZ Connection Handler', + target=self._connectionHandler, + args=[conn]) + thread.setDaemon(True) + thread.start() + + except socket.error as e: + print('Error in RPZ socket: %s' % str(e)) + sock.close() + +class RPZIncompleteRecursorTest(RecursorTest): + _wsPort = 8042 + _wsTimeout = 2 + _wsPassword = 'secretpassword' + _apiKey = 'secretapikey' + _confdir = 'RPZIncomplete' + _auth_zones = { + '8': {'threads': 1, + 'zones': ['ROOT']}, + '10': {'threads': 1, + 'zones': ['example']}, + } + + _config_template = """ +auth-zones=example=configs/%s/example.zone +webserver=yes +webserver-port=%d +webserver-address=127.0.0.1 +webserver-password=%s +api-key=%s +log-rpz-changes=yes +""" % (_confdir, _wsPort, _wsPassword, _apiKey) + + def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount, failedXFRCount): + headers = {'x-api-key': self._apiKey} + url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics' + r = requests.get(url, headers=headers, timeout=self._wsTimeout) + self.assertTrue(r) + self.assertEqual(r.status_code, 200) + self.assertTrue(r.json()) + content = r.json() + self.assertIn('zone.rpz.', content) + zone = content['zone.rpz.'] + for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']: + self.assertIn(key, zone) + + self.assertEqual(zone['serial'], serial) + self.assertEqual(zone['records'], recordsCount) + self.assertEqual(zone['transfers_full'], fullXFRCount) + self.assertEqual(zone['transfers_success'], totalXFRCount) + self.assertEqual(zone['transfers_failed'], failedXFRCount) + +badrpzServerPort = 4251 +badrpzServer = BadRPZServer(badrpzServerPort) + +class RPZXFRIncompleteRecursorTest(RPZIncompleteRecursorTest): + """ + This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR + """ + + global badrpzServerPort + _lua_config_file = """ + -- The first server is a bogus one, to test that we correctly fail over to the second one + rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 }) + """ % (badrpzServerPort) + _confdir = 'RPZXFRIncomplete' + _wsPort = 8042 + _wsTimeout = 2 + _wsPassword = 'secretpassword' + _apiKey = 'secretapikey' + _config_template = """ +auth-zones=example=configs/%s/example.zone +webserver=yes +webserver-port=%d +webserver-address=127.0.0.1 +webserver-password=%s +api-key=%s +""" % (_confdir, _wsPort, _wsPassword, _apiKey) + + @classmethod + def generateRecursorConfig(cls, confdir): + authzonepath = os.path.join(confdir, 'example.zone') + with open(authzonepath, 'w') as authzone: + authzone.write("""$ORIGIN example. +@ 3600 IN SOA {soa} +a 3600 IN A 192.0.2.42 +b 3600 IN A 192.0.2.42 +c 3600 IN A 192.0.2.42 +d 3600 IN A 192.0.2.42 +e 3600 IN A 192.0.2.42 +""".format(soa=cls._SOA)) + super(RPZIncompleteRecursorTest, cls).generateRecursorConfig(confdir) + + def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5): + global badrpzServer + + badrpzServer.moveToSerial(serial) + + attempts = 0 + while attempts < timeout: + currentSerial = badrpzServer.getCurrentSerial() + if currentSerial > serial: + raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial)) + if currentSerial == serial: + return + + attempts = attempts + 1 + time.sleep(1) + + raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial)) + + def testRPZ(self): + self.waitForTCPSocket("127.0.0.1", self._wsPort) + # First zone + self.waitUntilCorrectSerialIsLoaded(1) + self.checkRPZStats(1, 1, 1, 1, 1) # failure count includes a port 9999 attempt + + # second zone, should fail, incomplete IXFR + self.waitUntilCorrectSerialIsLoaded(2) + self.checkRPZStats(1, 1, 1, 1, 3) + + # third zone, should fail, incomplete AXFR + self.waitUntilCorrectSerialIsLoaded(3) + self.checkRPZStats(1, 1, 1, 1, 5) + -- 2.47.2