]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
auth, rec IXFR-in: Fix a case where an incomplete read caused by network error might... 11466/head
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 8 Mar 2022 14:36:48 +0000 (15:36 +0100)
committerPeter van Dijk <peter.van.dijk@powerdns.com>
Mon, 28 Mar 2022 18:28:15 +0000 (20:28 +0200)
As we might break from the loop early, we need to check if the end SOA was seen after the loop.
Also make sure we detect end conditions for both AXFR and IXFR style properly, to avoid processing
data after the end marker.

pdns/ixfr.cc
regression-tests.auth-py/runtests
regression-tests.auth-py/test_XFRIncomplete.py [new file with mode: 0644]
regression-tests.recursor-dnssec/test_RPZIncomplete.py [new file with mode: 0644]

index d299568488d0afaa1e000b41fc4d462253af3015..1154eb01366c47155808ecd643deff2208578f19 100644 (file)
@@ -174,13 +174,21 @@ vector<pair<vector<DNSRecord>, vector<DNSRecord> > > getIXFRDeltas(const ComboAd
   std::shared_ptr<SOARecordContent> primarySOA = nullptr;
   vector<DNSRecord> records;
   size_t receivedBytes = 0;
-  int8_t ixfrInProgress = -2;
   std::string reply;
 
+  enum transferStyle { Unknown, AXFR, IXFR } style = Unknown;
+  const unsigned int expectedSOAForAXFR = 2;
+  const unsigned int expectedSOAForIXFR = 3;
+  unsigned int primarySOACount = 0;
+
   for(;;) {
-    // IXFR end
-    if (ixfrInProgress >= 0)
+    // IXFR or AXFR style end reached? We don't want to process trailing data after the closing SOA
+    if (style == AXFR && primarySOACount == expectedSOAForAXFR) {
+      break;
+    }
+    else if (style == IXFR && primarySOACount == expectedSOAForIXFR) {
       break;
+    }
 
     if(s.read((char*)&len, sizeof(len)) != sizeof(len))
       break;
@@ -225,16 +233,31 @@ vector<pair<vector<DNSRecord>, vector<DNSRecord> > > getIXFRDeltas(const ComboAd
           return ret;
         }
         primarySOA = sr;
+        ++primarySOACount;
       } else if (r.first.d_type == QType::SOA) {
         auto sr = getRR<SOARecordContent>(r.first);
         if (!sr) {
           throw std::runtime_error("Error getting the content of SOA record of IXFR answer for zone '"+zone.toLogString()+"' from primary '"+primary.toStringWithPort()+"'");
         }
 
-        // we hit the last SOA record
-        // IXFR is considered to be done if we hit the last SOA record twice
+        // we hit a marker SOA record
         if (primarySOA->d_st.serial == sr->d_st.serial) {
-          ixfrInProgress++;
+          ++primarySOACount;
+        }
+      }
+      // When we see the 2nd record, we can decide what the style is
+      if (records.size() == 1 && style == Unknown) {
+        if (r.first.d_type != QType::SOA) {
+          // Non-empty AXFR style has a non-SOA record following the first SOA
+          style = AXFR;
+        }
+        else if (primarySOACount == expectedSOAForAXFR) {
+          // Empty zone AXFR style: start SOA is immediately followed by end marker SOA
+          style = AXFR;
+        }
+        else {
+          // IXFR has a 2nd SOA (with different serial) following the first
+          style = IXFR;
         }
       }
 
@@ -253,7 +276,21 @@ vector<pair<vector<DNSRecord>, vector<DNSRecord> > > getIXFRDeltas(const ComboAd
     }
   }
 
-  //  cout<<"Got "<<records.size()<<" records"<<endl;
+  switch (style) {
+  case IXFR:
+    if (primarySOACount != expectedSOAForIXFR) {
+      throw std::runtime_error("Incomplete IXFR transfer for '" + zone.toLogString() + "' from primary '" + primary.toStringWithPort());
+    }
+    break;
+  case AXFR:
+    if (primarySOACount != expectedSOAForAXFR){
+      throw std::runtime_error("Incomplete AXFR style transfer for '" + zone.toLogString() + "' from primary '" + primary.toStringWithPort());
+    }
+    break;
+  case Unknown:
+    throw std::runtime_error("Incomplete XFR for '" + zone.toLogString() + "' from primary '" + primary.toStringWithPort());
+    break;
+  }
 
   return processIXFRRecords(primary, zone, records, primarySOA);
 }
index 496652c485077109ecc01fe0a8837851980b00de..064d0bc89fd9f68ec77a15bcaceb7ca5c80fdd74 100755 (executable)
@@ -16,6 +16,7 @@ mkdir -p configs
 
 export PDNS=${PDNS:-${PWD}/../pdns/pdns_server}
 export PDNSUTIL=${PDNSUTIL:-${PWD}/../pdns/pdnsutil}
+export PDNSCONTROL=${PDNSCONTROL:-${PWD}/../pdns/pdns_control}
 
 export PREFIX=127.0.0
 
diff --git a/regression-tests.auth-py/test_XFRIncomplete.py b/regression-tests.auth-py/test_XFRIncomplete.py
new file mode 100644 (file)
index 0000000..5a9c10a
--- /dev/null
@@ -0,0 +1,197 @@
+import dns
+import json
+import os
+import requests
+import socket
+import struct
+import sys
+import threading
+import time
+
+from authtests import AuthTest
+
+class BadXFRServer(object):
+
+    def __init__(self, port):
+        self._currentSerial = 0
+        self._targetSerial = 1
+        self._serverPort = port
+        listener = threading.Thread(name='XFR Listener', target=self._listener, args=[])
+        listener.setDaemon(True)
+        listener.start()
+
+    def getCurrentSerial(self):
+        return self._currentSerial
+
+    def moveToSerial(self, newSerial):
+        if newSerial == self._currentSerial:
+            return False
+
+        #if newSerial != self._currentSerial + 1:
+        #    raise AssertionError("Asking the XFR server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
+        self._targetSerial = newSerial
+        return True
+
+    def _getAnswer(self, message):
+
+        response = dns.message.make_response(message)
+        records = []
+
+        if message.question[0].rdtype == dns.rdatatype.AXFR:
+            if self._currentSerial != 0:
+                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+                return (None, self._currentSerial)
+
+            newSerial = self._targetSerial
+            records = [
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+                ]
+
+        elif message.question[0].rdtype == dns.rdatatype.IXFR:
+            oldSerial = message.authority[0][0].serial
+
+            newSerial = self._targetSerial
+            if newSerial == 2:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+                    # no deletion
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+            elif newSerial == 3:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+
+        response.answer = records
+        return (newSerial, response)
+
+    def _connectionHandler(self, conn):
+        data = None
+        while True:
+            data = conn.recv(2)
+            if not data:
+                break
+            (datalen,) = struct.unpack("!H", data)
+            data = conn.recv(datalen)
+            if not data:
+                break
+
+            message = dns.message.from_wire(data)
+            if len(message.question) != 1:
+                print('Invalid query, qdcount is %d' % (len(message.question)), file=sys.stderr)
+                break
+            if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+                print('Invalid query, qtype is %d' % (message.question.rdtype), file=sys.stderr)
+                break
+            (serial, answer) = self._getAnswer(message)
+            if not answer:
+                print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr)
+                break
+
+            wire = answer.to_wire()
+            conn.send(struct.pack("!H", len(wire)))
+            conn.send(wire)
+            self._currentSerial = serial
+            break
+
+        conn.close()
+
+    def _listener(self):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", self._serverPort))
+        except socket.error as e:
+            print("Error binding in the IXFR listener: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            try:
+                (conn, _) = sock.accept()
+                thread = threading.Thread(name='IXFR Connection Handler',
+                                      target=self._connectionHandler,
+                                      args=[conn])
+                thread.setDaemon(True)
+                thread.start()
+
+            except socket.error as e:
+                print('Error in IXFR socket: %s' % str(e))
+                sock.close()
+
+badxfrServerPort = 4251
+badxfrServer = BadXFRServer(badxfrServerPort)
+
+class XFRIncompleteAuthTest(AuthTest):
+    """
+    This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
+    """
+
+    global badxfrServerPort
+    _config_template = """
+launch=gsqlite3 bind
+gsqlite3-database=configs/auth/powerdns.sqlite
+gsqlite3-dnssec
+slave
+cache-ttl=0
+query-cache-ttl=0
+domain-metadata-cache-ttl=0
+negquery-cache-ttl=0
+slave-cycle-interval=1
+"""
+
+    @classmethod
+    def setUpClass(cls):
+        super(XFRIncompleteAuthTest, cls).setUpClass()
+        os.system("$PDNSUTIL --config-dir=configs/auth create-slave-zone zone.rpz. 127.0.0.1:%s" % (badxfrServerPort,))
+        os.system("$PDNSUTIL --config-dir=configs/auth set-meta zone.rpz. IXFR 1")
+    
+    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+        global badxfrServer
+
+        badxfrServer.moveToSerial(serial)
+
+        attempts = 0
+        while attempts < timeout:
+            currentSerial = badxfrServer.getCurrentSerial()
+            if currentSerial > serial:
+                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+            if currentSerial == serial:
+                return
+
+            attempts = attempts + 1
+            time.sleep(1)
+
+        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+    def checkZone(self):
+        query = dns.message.make_query('zone.rpz.', 'SOA')
+        res = self.sendUDPQuery(query) # , count=len(expected))
+
+        expected = [dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. 1 3600 3600 3600 1')]
+        self.assertEqual(res.answer, expected)
+
+    def doRetrieve(self):
+        os.system("$PDNSCONTROL --socket-dir=configs/auth retrieve zone.rpz.")
+
+    def testXFR(self):
+        # self.waitForTCPSocket("127.0.0.1", self._wsPort)
+        # First zone
+        self.doRetrieve()
+        self.waitUntilCorrectSerialIsLoaded(1)
+        self.checkZone()
+
+        # second zone, should fail, incomplete IXFR
+        self.doRetrieve()
+        self.waitUntilCorrectSerialIsLoaded(2)
+        self.checkZone()
+
+        # third zone, should fail, incomplete AXFR
+        self.doRetrieve()
+        self.waitUntilCorrectSerialIsLoaded(3)
+        self.checkZone()
diff --git a/regression-tests.recursor-dnssec/test_RPZIncomplete.py b/regression-tests.recursor-dnssec/test_RPZIncomplete.py
new file mode 100644 (file)
index 0000000..b5a0e8a
--- /dev/null
@@ -0,0 +1,241 @@
+import dns
+import json
+import os
+import requests
+import socket
+import struct
+import sys
+import threading
+import time
+
+from recursortests import RecursorTest
+
+class BadRPZServer(object):
+
+    def __init__(self, port):
+        self._currentSerial = 0
+        self._targetSerial = 1
+        self._serverPort = port
+        listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
+        listener.setDaemon(True)
+        listener.start()
+
+    def getCurrentSerial(self):
+        return self._currentSerial
+
+    def moveToSerial(self, newSerial):
+        if newSerial == self._currentSerial:
+            return False
+
+        #if newSerial != self._currentSerial + 1:
+        #    raise AssertionError("Asking the RPZ server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
+        self._targetSerial = newSerial
+        return True
+
+    def _getAnswer(self, message):
+
+        response = dns.message.make_response(message)
+        records = []
+
+        if message.question[0].rdtype == dns.rdatatype.AXFR:
+            if self._currentSerial != 0:
+                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+                return (None, self._currentSerial)
+
+            newSerial = self._targetSerial
+            records = [
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+                ]
+
+        elif message.question[0].rdtype == dns.rdatatype.IXFR:
+            oldSerial = message.authority[0][0].serial
+
+            newSerial = self._targetSerial
+            if newSerial == 2:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+                    # no deletion
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+            elif newSerial == 3:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+
+        response.answer = records
+        return (newSerial, response)
+
+    def _connectionHandler(self, conn):
+        data = None
+        while True:
+            data = conn.recv(2)
+            if not data:
+                break
+            (datalen,) = struct.unpack("!H", data)
+            data = conn.recv(datalen)
+            if not data:
+                break
+
+            message = dns.message.from_wire(data)
+            if len(message.question) != 1:
+                print('Invalid RPZ query, qdcount is %d' % (len(message.question)), file=sys.stderr)
+                break
+            if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+                print('Invalid RPZ query, qtype is %d' % (message.question.rdtype), file=sys.stderr)
+                break
+            (serial, answer) = self._getAnswer(message)
+            if not answer:
+                print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr)
+                break
+
+            wire = answer.to_wire()
+            conn.send(struct.pack("!H", len(wire)))
+            conn.send(wire)
+            self._currentSerial = serial
+            break
+
+        conn.close()
+
+    def _listener(self):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", self._serverPort))
+        except socket.error as e:
+            print("Error binding in the RPZ listener: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            try:
+                (conn, _) = sock.accept()
+                thread = threading.Thread(name='RPZ Connection Handler',
+                                      target=self._connectionHandler,
+                                      args=[conn])
+                thread.setDaemon(True)
+                thread.start()
+
+            except socket.error as e:
+                print('Error in RPZ socket: %s' % str(e))
+                sock.close()
+
+class RPZIncompleteRecursorTest(RecursorTest):
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _confdir = 'RPZIncomplete'
+    _auth_zones = {
+        '8': {'threads': 1,
+              'zones': ['ROOT']},
+        '10': {'threads': 1,
+               'zones': ['example']},
+    }
+
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+log-rpz-changes=yes
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount, failedXFRCount):
+        headers = {'x-api-key': self._apiKey}
+        url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
+        r = requests.get(url, headers=headers, timeout=self._wsTimeout)
+        self.assertTrue(r)
+        self.assertEqual(r.status_code, 200)
+        self.assertTrue(r.json())
+        content = r.json()
+        self.assertIn('zone.rpz.', content)
+        zone = content['zone.rpz.']
+        for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
+            self.assertIn(key, zone)
+
+        self.assertEqual(zone['serial'], serial)
+        self.assertEqual(zone['records'], recordsCount)
+        self.assertEqual(zone['transfers_full'], fullXFRCount)
+        self.assertEqual(zone['transfers_success'], totalXFRCount)
+        self.assertEqual(zone['transfers_failed'], failedXFRCount)
+
+badrpzServerPort = 4251
+badrpzServer = BadRPZServer(badrpzServerPort)
+
+class RPZXFRIncompleteRecursorTest(RPZIncompleteRecursorTest):
+    """
+    This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
+    """
+
+    global badrpzServerPort
+    _lua_config_file = """
+    -- The first server is a bogus one, to test that we correctly fail over to the second one
+    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
+    """ % (badrpzServerPort)
+    _confdir = 'RPZXFRIncomplete'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+        super(RPZIncompleteRecursorTest, cls).generateRecursorConfig(confdir)
+
+    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+        global badrpzServer
+
+        badrpzServer.moveToSerial(serial)
+
+        attempts = 0
+        while attempts < timeout:
+            currentSerial = badrpzServer.getCurrentSerial()
+            if currentSerial > serial:
+                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+            if currentSerial == serial:
+                return
+
+            attempts = attempts + 1
+            time.sleep(1)
+
+        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+    def testRPZ(self):
+        self.waitForTCPSocket("127.0.0.1", self._wsPort)
+        # First zone
+        self.waitUntilCorrectSerialIsLoaded(1)
+        self.checkRPZStats(1, 1, 1, 1, 1) # failure count includes a port 9999 attempt
+
+        # second zone, should fail, incomplete IXFR
+        self.waitUntilCorrectSerialIsLoaded(2)
+        self.checkRPZStats(1, 1, 1, 1, 3)
+
+        # third zone, should fail, incomplete AXFR
+        self.waitUntilCorrectSerialIsLoaded(3)
+        self.checkRPZStats(1, 1, 1, 1, 5)
+