]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
auth, rec IXFR-in: Fix a case where an incomplete read caused by network error might... 11453/head auth-4.4.3
authorOtto Moerbeek <otto.moerbeek@open-xchange.com>
Tue, 8 Mar 2022 14:36:48 +0000 (15:36 +0100)
committerPeter van Dijk <peter.van.dijk@powerdns.com>
Wed, 16 Mar 2022 15:23:54 +0000 (16:23 +0100)
As we might break from the loop early, we need to check if the end SOA was seen after the loop.
Also make sure we detect end conditions for both AXFR and IXFR style properly, to avoid processing
data after the end marker.

pdns/ixfr.cc
regression-tests.auth-py/runtests
regression-tests.auth-py/test_XFRIncomplete.py [new file with mode: 0644]
regression-tests.recursor-dnssec/test_RPZIncomplete.py [new file with mode: 0644]

index bafead50e1ca9ecc66d5d095161ef747ec55d78f..95add2a2b0ab4cd9c6bdead61a50ce83e97e4e75 100644 (file)
@@ -174,13 +174,21 @@ vector<pair<vector<DNSRecord>, vector<DNSRecord> > > getIXFRDeltas(const ComboAd
   std::shared_ptr<SOARecordContent> masterSOA = nullptr;
   vector<DNSRecord> records;
   size_t receivedBytes = 0;
-  int8_t ixfrInProgress = -2;
   std::string reply;
 
+  enum transferStyle { Unknown, AXFR, IXFR } style = Unknown;
+  const unsigned int expectedSOAForAXFR = 2;
+  const unsigned int expectedSOAForIXFR = 3;
+  unsigned int masterSOACount = 0;
+
   for(;;) {
-    // IXFR end
-    if (ixfrInProgress >= 0)
+    // IXFR or AXFR style end reached? We don't want to process trailing data after the closing SOA
+    if (style == AXFR && masterSOACount == expectedSOAForAXFR) {
+      break;
+    }
+    else if (style == IXFR && masterSOACount == expectedSOAForIXFR) {
       break;
+    }
 
     if(s.read((char*)&len, sizeof(len)) != sizeof(len))
       break;
@@ -225,16 +233,31 @@ vector<pair<vector<DNSRecord>, vector<DNSRecord> > > getIXFRDeltas(const ComboAd
           return ret;
         }
         masterSOA = sr;
+        ++masterSOACount;
       } else if (r.first.d_type == QType::SOA) {
         auto sr = getRR<SOARecordContent>(r.first);
         if (!sr) {
           throw std::runtime_error("Error getting the content of SOA record of IXFR answer for zone '"+zone.toLogString()+"' from master '"+master.toStringWithPort()+"'");
         }
 
-        // we hit the last SOA record
-        // IXFR is considered to be done if we hit the last SOA record twice
+        // we hit a marker SOA record
         if (masterSOA->d_st.serial == sr->d_st.serial) {
-          ixfrInProgress++;
+          ++masterSOACount;
+        }
+      }
+      // When we see the 2nd record, we can decide what the style is
+      if (records.size() == 1 && style == Unknown) {
+        if (r.first.d_type != QType::SOA) {
+          // Non-empty AXFR style has a non-SOA record following the first SOA
+          style = AXFR;
+        }
+        else if (masterSOACount == expectedSOAForAXFR) {
+          // Empty zone AXFR style: start SOA is immediately followed by end marker SOA
+          style = AXFR;
+        }
+        else {
+          // IXFR has a 2nd SOA (with different serial) following the first
+          style = IXFR;
         }
       }
 
@@ -253,7 +276,21 @@ vector<pair<vector<DNSRecord>, vector<DNSRecord> > > getIXFRDeltas(const ComboAd
     }
   }
 
-  //  cout<<"Got "<<records.size()<<" records"<<endl;
+  switch (style) {
+  case IXFR:
+    if (masterSOACount != expectedSOAForIXFR) {
+      throw std::runtime_error("Incomplete IXFR transfer for '" + zone.toLogString() + "' from primary '" + master.toStringWithPort());
+    }
+    break;
+  case AXFR:
+    if (masterSOACount != expectedSOAForAXFR){
+      throw std::runtime_error("Incomplete AXFR style transfer for '" + zone.toLogString() + "' from primary '" + master.toStringWithPort());
+    }
+    break;
+  case Unknown:
+    throw std::runtime_error("Incomplete XFR for '" + zone.toLogString() + "' from primary '" + master.toStringWithPort());
+    break;
+  }
 
   return processIXFRRecords(master, zone, records, masterSOA);
 }
index 5a251623c306ef522a5540f57ea5783979ffb6a8..046a94b2d855db288c8a765b4260830d68c0b2f1 100755 (executable)
@@ -15,6 +15,7 @@ mkdir -p configs
 
 export PDNS=${PDNS:-${PWD}/../pdns/pdns_server}
 export PDNSUTIL=${PDNSUTIL:-${PWD}/../pdns/pdnsutil}
+export PDNSCONTROL=${PDNSCONTROL:-${PWD}/../pdns/pdns_control}
 
 export PREFIX=127.0.0
 
diff --git a/regression-tests.auth-py/test_XFRIncomplete.py b/regression-tests.auth-py/test_XFRIncomplete.py
new file mode 100644 (file)
index 0000000..5a9c10a
--- /dev/null
@@ -0,0 +1,197 @@
+import dns
+import json
+import os
+import requests
+import socket
+import struct
+import sys
+import threading
+import time
+
+from authtests import AuthTest
+
+class BadXFRServer(object):
+
+    def __init__(self, port):
+        self._currentSerial = 0
+        self._targetSerial = 1
+        self._serverPort = port
+        listener = threading.Thread(name='XFR Listener', target=self._listener, args=[])
+        listener.setDaemon(True)
+        listener.start()
+
+    def getCurrentSerial(self):
+        return self._currentSerial
+
+    def moveToSerial(self, newSerial):
+        if newSerial == self._currentSerial:
+            return False
+
+        #if newSerial != self._currentSerial + 1:
+        #    raise AssertionError("Asking the XFR server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
+        self._targetSerial = newSerial
+        return True
+
+    def _getAnswer(self, message):
+
+        response = dns.message.make_response(message)
+        records = []
+
+        if message.question[0].rdtype == dns.rdatatype.AXFR:
+            if self._currentSerial != 0:
+                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+                return (None, self._currentSerial)
+
+            newSerial = self._targetSerial
+            records = [
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+                ]
+
+        elif message.question[0].rdtype == dns.rdatatype.IXFR:
+            oldSerial = message.authority[0][0].serial
+
+            newSerial = self._targetSerial
+            if newSerial == 2:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+                    # no deletion
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+            elif newSerial == 3:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+
+        response.answer = records
+        return (newSerial, response)
+
+    def _connectionHandler(self, conn):
+        data = None
+        while True:
+            data = conn.recv(2)
+            if not data:
+                break
+            (datalen,) = struct.unpack("!H", data)
+            data = conn.recv(datalen)
+            if not data:
+                break
+
+            message = dns.message.from_wire(data)
+            if len(message.question) != 1:
+                print('Invalid query, qdcount is %d' % (len(message.question)), file=sys.stderr)
+                break
+            if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+                print('Invalid query, qtype is %d' % (message.question.rdtype), file=sys.stderr)
+                break
+            (serial, answer) = self._getAnswer(message)
+            if not answer:
+                print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr)
+                break
+
+            wire = answer.to_wire()
+            conn.send(struct.pack("!H", len(wire)))
+            conn.send(wire)
+            self._currentSerial = serial
+            break
+
+        conn.close()
+
+    def _listener(self):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", self._serverPort))
+        except socket.error as e:
+            print("Error binding in the IXFR listener: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            try:
+                (conn, _) = sock.accept()
+                thread = threading.Thread(name='IXFR Connection Handler',
+                                      target=self._connectionHandler,
+                                      args=[conn])
+                thread.setDaemon(True)
+                thread.start()
+
+            except socket.error as e:
+                print('Error in IXFR socket: %s' % str(e))
+                sock.close()
+
+badxfrServerPort = 4251
+badxfrServer = BadXFRServer(badxfrServerPort)
+
+class XFRIncompleteAuthTest(AuthTest):
+    """
+    This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
+    """
+
+    global badxfrServerPort
+    _config_template = """
+launch=gsqlite3 bind
+gsqlite3-database=configs/auth/powerdns.sqlite
+gsqlite3-dnssec
+slave
+cache-ttl=0
+query-cache-ttl=0
+domain-metadata-cache-ttl=0
+negquery-cache-ttl=0
+slave-cycle-interval=1
+"""
+
+    @classmethod
+    def setUpClass(cls):
+        super(XFRIncompleteAuthTest, cls).setUpClass()
+        os.system("$PDNSUTIL --config-dir=configs/auth create-slave-zone zone.rpz. 127.0.0.1:%s" % (badxfrServerPort,))
+        os.system("$PDNSUTIL --config-dir=configs/auth set-meta zone.rpz. IXFR 1")
+    
+    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+        global badxfrServer
+
+        badxfrServer.moveToSerial(serial)
+
+        attempts = 0
+        while attempts < timeout:
+            currentSerial = badxfrServer.getCurrentSerial()
+            if currentSerial > serial:
+                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+            if currentSerial == serial:
+                return
+
+            attempts = attempts + 1
+            time.sleep(1)
+
+        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+    def checkZone(self):
+        query = dns.message.make_query('zone.rpz.', 'SOA')
+        res = self.sendUDPQuery(query) # , count=len(expected))
+
+        expected = [dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. 1 3600 3600 3600 1')]
+        self.assertEqual(res.answer, expected)
+
+    def doRetrieve(self):
+        os.system("$PDNSCONTROL --socket-dir=configs/auth retrieve zone.rpz.")
+
+    def testXFR(self):
+        # self.waitForTCPSocket("127.0.0.1", self._wsPort)
+        # First zone
+        self.doRetrieve()
+        self.waitUntilCorrectSerialIsLoaded(1)
+        self.checkZone()
+
+        # second zone, should fail, incomplete IXFR
+        self.doRetrieve()
+        self.waitUntilCorrectSerialIsLoaded(2)
+        self.checkZone()
+
+        # third zone, should fail, incomplete AXFR
+        self.doRetrieve()
+        self.waitUntilCorrectSerialIsLoaded(3)
+        self.checkZone()
diff --git a/regression-tests.recursor-dnssec/test_RPZIncomplete.py b/regression-tests.recursor-dnssec/test_RPZIncomplete.py
new file mode 100644 (file)
index 0000000..b5a0e8a
--- /dev/null
@@ -0,0 +1,241 @@
+import dns
+import json
+import os
+import requests
+import socket
+import struct
+import sys
+import threading
+import time
+
+from recursortests import RecursorTest
+
+class BadRPZServer(object):
+
+    def __init__(self, port):
+        self._currentSerial = 0
+        self._targetSerial = 1
+        self._serverPort = port
+        listener = threading.Thread(name='RPZ Listener', target=self._listener, args=[])
+        listener.setDaemon(True)
+        listener.start()
+
+    def getCurrentSerial(self):
+        return self._currentSerial
+
+    def moveToSerial(self, newSerial):
+        if newSerial == self._currentSerial:
+            return False
+
+        #if newSerial != self._currentSerial + 1:
+        #    raise AssertionError("Asking the RPZ server to serve serial %d, already serving %d" % (newSerial, self._currentSerial))
+        self._targetSerial = newSerial
+        return True
+
+    def _getAnswer(self, message):
+
+        response = dns.message.make_response(message)
+        records = []
+
+        if message.question[0].rdtype == dns.rdatatype.AXFR:
+            if self._currentSerial != 0:
+                print('Received an AXFR query but IXFR expected because the current serial is %d' % (self._currentSerial))
+                return (None, self._currentSerial)
+
+            newSerial = self._targetSerial
+            records = [
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial)
+                ]
+
+        elif message.question[0].rdtype == dns.rdatatype.IXFR:
+            oldSerial = message.authority[0][0].serial
+
+            newSerial = self._targetSerial
+            if newSerial == 2:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % oldSerial),
+                    # no deletion
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('b.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+            elif newSerial == 3:
+                records = [
+                    dns.rrset.from_text('zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.SOA, 'ns.zone.rpz. hostmaster.zone.rpz. %d 3600 3600 3600 1' % newSerial),
+                    dns.rrset.from_text('a.example.zone.rpz.', 60, dns.rdataclass.IN, dns.rdatatype.A, '192.0.2.1'),
+                    ]
+
+        response.answer = records
+        return (newSerial, response)
+
+    def _connectionHandler(self, conn):
+        data = None
+        while True:
+            data = conn.recv(2)
+            if not data:
+                break
+            (datalen,) = struct.unpack("!H", data)
+            data = conn.recv(datalen)
+            if not data:
+                break
+
+            message = dns.message.from_wire(data)
+            if len(message.question) != 1:
+                print('Invalid RPZ query, qdcount is %d' % (len(message.question)), file=sys.stderr)
+                break
+            if not message.question[0].rdtype in [dns.rdatatype.AXFR, dns.rdatatype.IXFR]:
+                print('Invalid RPZ query, qtype is %d' % (message.question.rdtype), file=sys.stderr)
+                break
+            (serial, answer) = self._getAnswer(message)
+            if not answer:
+                print('Unable to get a response for %s %d' % (message.question[0].name, message.question[0].rdtype), file=sys.stderr)
+                break
+
+            wire = answer.to_wire()
+            conn.send(struct.pack("!H", len(wire)))
+            conn.send(wire)
+            self._currentSerial = serial
+            break
+
+        conn.close()
+
+    def _listener(self):
+        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
+        try:
+            sock.bind(("127.0.0.1", self._serverPort))
+        except socket.error as e:
+            print("Error binding in the RPZ listener: %s" % str(e))
+            sys.exit(1)
+
+        sock.listen(100)
+        while True:
+            try:
+                (conn, _) = sock.accept()
+                thread = threading.Thread(name='RPZ Connection Handler',
+                                      target=self._connectionHandler,
+                                      args=[conn])
+                thread.setDaemon(True)
+                thread.start()
+
+            except socket.error as e:
+                print('Error in RPZ socket: %s' % str(e))
+                sock.close()
+
+class RPZIncompleteRecursorTest(RecursorTest):
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _confdir = 'RPZIncomplete'
+    _auth_zones = {
+        '8': {'threads': 1,
+              'zones': ['ROOT']},
+        '10': {'threads': 1,
+               'zones': ['example']},
+    }
+
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+log-rpz-changes=yes
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    def checkRPZStats(self, serial, recordsCount, fullXFRCount, totalXFRCount, failedXFRCount):
+        headers = {'x-api-key': self._apiKey}
+        url = 'http://127.0.0.1:' + str(self._wsPort) + '/api/v1/servers/localhost/rpzstatistics'
+        r = requests.get(url, headers=headers, timeout=self._wsTimeout)
+        self.assertTrue(r)
+        self.assertEqual(r.status_code, 200)
+        self.assertTrue(r.json())
+        content = r.json()
+        self.assertIn('zone.rpz.', content)
+        zone = content['zone.rpz.']
+        for key in ['last_update', 'records', 'serial', 'transfers_failed', 'transfers_full', 'transfers_success']:
+            self.assertIn(key, zone)
+
+        self.assertEqual(zone['serial'], serial)
+        self.assertEqual(zone['records'], recordsCount)
+        self.assertEqual(zone['transfers_full'], fullXFRCount)
+        self.assertEqual(zone['transfers_success'], totalXFRCount)
+        self.assertEqual(zone['transfers_failed'], failedXFRCount)
+
+badrpzServerPort = 4251
+badrpzServer = BadRPZServer(badrpzServerPort)
+
+class RPZXFRIncompleteRecursorTest(RPZIncompleteRecursorTest):
+    """
+    This test makes sure that we correctly detect incomplete RPZ zones via AXFR then IXFR
+    """
+
+    global badrpzServerPort
+    _lua_config_file = """
+    -- The first server is a bogus one, to test that we correctly fail over to the second one
+    rpzMaster({'127.0.0.1:9999', '127.0.0.1:%d'}, 'zone.rpz.', { refresh=1 })
+    """ % (badrpzServerPort)
+    _confdir = 'RPZXFRIncomplete'
+    _wsPort = 8042
+    _wsTimeout = 2
+    _wsPassword = 'secretpassword'
+    _apiKey = 'secretapikey'
+    _config_template = """
+auth-zones=example=configs/%s/example.zone
+webserver=yes
+webserver-port=%d
+webserver-address=127.0.0.1
+webserver-password=%s
+api-key=%s
+""" % (_confdir, _wsPort, _wsPassword, _apiKey)
+
+    @classmethod
+    def generateRecursorConfig(cls, confdir):
+        authzonepath = os.path.join(confdir, 'example.zone')
+        with open(authzonepath, 'w') as authzone:
+            authzone.write("""$ORIGIN example.
+@ 3600 IN SOA {soa}
+a 3600 IN A 192.0.2.42
+b 3600 IN A 192.0.2.42
+c 3600 IN A 192.0.2.42
+d 3600 IN A 192.0.2.42
+e 3600 IN A 192.0.2.42
+""".format(soa=cls._SOA))
+        super(RPZIncompleteRecursorTest, cls).generateRecursorConfig(confdir)
+
+    def waitUntilCorrectSerialIsLoaded(self, serial, timeout=5):
+        global badrpzServer
+
+        badrpzServer.moveToSerial(serial)
+
+        attempts = 0
+        while attempts < timeout:
+            currentSerial = badrpzServer.getCurrentSerial()
+            if currentSerial > serial:
+                raise AssertionError("Expected serial %d, got %d" % (serial, currentSerial))
+            if currentSerial == serial:
+                return
+
+            attempts = attempts + 1
+            time.sleep(1)
+
+        raise AssertionError("Waited %d seconds for the serial to be updated to %d but the serial is still %d" % (timeout, serial, currentSerial))
+
+    def testRPZ(self):
+        self.waitForTCPSocket("127.0.0.1", self._wsPort)
+        # First zone
+        self.waitUntilCorrectSerialIsLoaded(1)
+        self.checkRPZStats(1, 1, 1, 1, 1) # failure count includes a port 9999 attempt
+
+        # second zone, should fail, incomplete IXFR
+        self.waitUntilCorrectSerialIsLoaded(2)
+        self.checkRPZStats(1, 1, 1, 1, 3)
+
+        # third zone, should fail, incomplete AXFR
+        self.waitUntilCorrectSerialIsLoaded(3)
+        self.checkRPZStats(1, 1, 1, 1, 5)
+