]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix invalid proxy protocol payload on a DoH TC to TCP retry 11604/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 4 May 2022 16:38:22 +0000 (18:38 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 4 May 2022 16:38:22 +0000 (18:38 +0200)
dnsdist forwards incoming DoH queries to its backend over UDP, and
retry over TCP if the response is truncated (TC=1).
When the proxy protocol is used between dnsdist and its backend, the
second query, over TCP, needs to take into account that the proxy
protocol payload has already been handled. This was not properly done
in that exact case because the proxy protocol payload length was not
propagated to the code handling the TCP communication, leading to
the query ID being edited at the wrong offset in the packet and thus
to an invalid proxy protocol payload.

pdns/dnsdist-tcp.cc
pdns/dnsdistdist/dnsdist-tcp-downstream.cc
pdns/dnsdistdist/dnsdist-tcp.hh
pdns/dnsdistdist/doh.cc
regression-tests.dnsdist/dnsdistdohtests.py [new file with mode: 0644]
regression-tests.dnsdist/test_DOH.py
regression-tests.dnsdist/test_ProxyProtocol.py

index e0de9f35e6acd1c826518b0c04fbdbc34f2632cc..a5b02e19a654a2456ac7b01a67d370995f99f3bd 100644 (file)
@@ -1204,6 +1204,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par
       auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
 
       prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
+      query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
       downstream->queueQuery(tqs, std::move(query));
     }
     catch (...) {
index 40c3905a12eb42c34ed8f3369ff5247ebc326e8c..1dbf7cc96c9450316c1164f363c88cd322f2cd59 100644 (file)
@@ -169,19 +169,20 @@ static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState quer
     if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
       query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
       query.d_proxyProtocolPayloadAdded = true;
+      query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size();
     }
   }
   else if (connectionState == ConnectionState::proxySent) {
     if (query.d_proxyProtocolPayloadAdded) {
-      if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) {
+      if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) {
         throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
       }
-      query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size());
+      query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize);
       query.d_proxyProtocolPayloadAdded = false;
+      query.d_proxyProtocolPayloadAddedSize = 0;
     }
   }
-
-  editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
+  editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true);
 }
 
 IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
index 9154f2f650f2d237a9e6a4a8f3a2b8149b191242..1e896b473e073e8c002b24201b4d15e693fb9f14 100644 (file)
@@ -83,22 +83,8 @@ struct InternalQuery
   {
   }
 
-  InternalQuery(InternalQuery&& rhs) :
-    d_idstate(std::move(rhs.d_idstate)), d_proxyProtocolPayload(std::move(rhs.d_proxyProtocolPayload)), d_buffer(std::move(rhs.d_buffer)), d_xfrMasterSerial(rhs.d_xfrMasterSerial), d_xfrSerialCount(rhs.d_xfrSerialCount), d_downstreamFailures(rhs.d_downstreamFailures), d_xfrMasterSerialCount(rhs.d_xfrMasterSerialCount), d_proxyProtocolPayloadAdded(rhs.d_proxyProtocolPayloadAdded)
-  {
-  }
-  InternalQuery& operator=(InternalQuery&& rhs)
-  {
-    d_idstate = std::move(rhs.d_idstate);
-    d_buffer = std::move(rhs.d_buffer);
-    d_proxyProtocolPayload = std::move(rhs.d_proxyProtocolPayload);
-    d_xfrMasterSerial = rhs.d_xfrMasterSerial;
-    d_xfrSerialCount = rhs.d_xfrSerialCount;
-    d_downstreamFailures = rhs.d_downstreamFailures;
-    d_xfrMasterSerialCount = rhs.d_xfrMasterSerialCount;
-    d_proxyProtocolPayloadAdded = rhs.d_proxyProtocolPayloadAdded;
-    return *this;
-  }
+  InternalQuery(InternalQuery&& rhs) = default;
+  InternalQuery& operator=(InternalQuery&& rhs) = default;
 
   InternalQuery(const InternalQuery& rhs) = delete;
   InternalQuery& operator=(const InternalQuery& rhs) = delete;
@@ -111,6 +97,7 @@ struct InternalQuery
   IDState d_idstate;
   std::string d_proxyProtocolPayload;
   PacketBuffer d_buffer;
+  uint32_t d_proxyProtocolPayloadAddedSize{0};
   uint32_t d_xfrMasterSerial{0};
   uint32_t d_xfrSerialCount{0};
   uint32_t d_downstreamFailures{0};
index da1e8ef838cd99ccb1d129cb99e443b7690cb59f..d74d52665f1406a7098b0fdec031086e543616bd 100644 (file)
@@ -699,7 +699,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& du)
 
     if (du->downstream->d_config.useProxyProtocol) {
       size_t payloadSize = 0;
-      if (addProxyProtocol(dq)) {
+      if (addProxyProtocol(dq, &payloadSize)) {
         du->proxyProtocolPayloadSize = payloadSize;
       }
     }
diff --git a/regression-tests.dnsdist/dnsdistdohtests.py b/regression-tests.dnsdist/dnsdistdohtests.py
new file mode 100644 (file)
index 0000000..cd19d09
--- /dev/null
@@ -0,0 +1,145 @@
+#!/usr/bin/env python
+import base64
+import dns
+import os
+import unittest
+
+from dnsdisttests import DNSDistTest
+
+import pycurl
+from io import BytesIO
+
+@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
+class DNSDistDOHTest(DNSDistTest):
+
+    @classmethod
+    def getDOHGetURL(cls, baseurl, query, rawQuery=False):
+        if rawQuery:
+            wire = query
+        else:
+            wire = query.to_wire()
+        param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+        return baseurl + "?dns=" + param
+
+    @classmethod
+    def openDOHConnection(cls, port, caFile, timeout=2.0):
+        conn = pycurl.Curl()
+        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
+
+        conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
+                                         "Accept: application/dns-message"])
+        return conn
+
+    @classmethod
+    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+        url = cls.getDOHGetURL(baseurl, query, rawQuery)
+        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+        response_headers = BytesIO()
+        #conn.setopt(pycurl.VERBOSE, True)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        if useHTTPS:
+            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+            if caFile:
+                conn.setopt(pycurl.CAINFO, caFile)
+
+        conn.setopt(pycurl.HTTPHEADER, customHeaders)
+        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
+
+        if response:
+            if toQueue:
+                toQueue.put(response, True, timeout)
+            else:
+                cls._toResponderQueue.put(response, True, timeout)
+
+        receivedQuery = None
+        message = None
+        cls._response_headers = ''
+        data = conn.perform_rb()
+        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        if cls._rcode == 200 and not rawResponse:
+            message = dns.message.from_wire(data)
+        elif rawResponse:
+            message = data
+
+        if useQueue:
+            if fromQueue:
+                if not fromQueue.empty():
+                    receivedQuery = fromQueue.get(True, timeout)
+            else:
+                if not cls._fromResponderQueue.empty():
+                    receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        cls._response_headers = response_headers.getvalue()
+        return (receivedQuery, message)
+
+    @classmethod
+    def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
+        url = baseurl
+        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+        response_headers = BytesIO()
+        #conn.setopt(pycurl.VERBOSE, True)
+        conn.setopt(pycurl.URL, url)
+        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+        if useHTTPS:
+            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+            if caFile:
+                conn.setopt(pycurl.CAINFO, caFile)
+
+        conn.setopt(pycurl.HTTPHEADER, customHeaders)
+        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
+        conn.setopt(pycurl.POST, True)
+        data = query
+        if not rawQuery:
+            data = data.to_wire()
+
+        conn.setopt(pycurl.POSTFIELDS, data)
+
+        if response:
+            cls._toResponderQueue.put(response, True, timeout)
+
+        receivedQuery = None
+        message = None
+        cls._response_headers = ''
+        data = conn.perform_rb()
+        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+        if cls._rcode == 200 and not rawResponse:
+            message = dns.message.from_wire(data)
+        elif rawResponse:
+            message = data
+
+        if useQueue and not cls._fromResponderQueue.empty():
+            receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+        cls._response_headers = response_headers.getvalue()
+        return (receivedQuery, message)
+
+    def getHeaderValue(self, name):
+        for header in self._response_headers.decode().splitlines(False):
+            values = header.split(':')
+            key = values[0]
+            if key.lower() == name.lower():
+                return values[1].strip()
+        return None
+
+    def checkHasHeader(self, name, value):
+        got = self.getHeaderValue(name)
+        self.assertEqual(got, value)
+
+    def checkNoHeader(self, name):
+        self.checkHasHeader(name, None)
+
+    @classmethod
+    def setUpClass(cls):
+
+        # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
+        if 'SKIP_DOH_TESTS' in os.environ:
+            raise unittest.SkipTest('DNS over HTTPS tests are disabled')
+
+        cls.startResponders()
+        cls.startDNSDist()
+        cls.setUpSockets()
+
+        print("Launching tests..")
index ee99e21d5672f4302ed3aa396383a342882669ec..4f7d1ec60095f50c4760e4652f454ab1c308dd64 100644 (file)
 import base64
 import dns
 import os
-import re
 import time
 import unittest
 import clientsubnetoption
-from dnsdisttests import DNSDistTest
+
+from dnsdistdohtests import DNSDistDOHTest
 
 import pycurl
 from io import BytesIO
 
-@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
-class DNSDistDOHTest(DNSDistTest):
-
-    @classmethod
-    def getDOHGetURL(cls, baseurl, query, rawQuery=False):
-        if rawQuery:
-            wire = query
-        else:
-            wire = query.to_wire()
-        param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
-        return baseurl + "?dns=" + param
-
-    @classmethod
-    def openDOHConnection(cls, port, caFile, timeout=2.0):
-        conn = pycurl.Curl()
-        conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
-
-        conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
-                                         "Accept: application/dns-message"])
-        return conn
-
-    @classmethod
-    def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
-        url = cls.getDOHGetURL(baseurl, query, rawQuery)
-        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
-        response_headers = BytesIO()
-        #conn.setopt(pycurl.VERBOSE, True)
-        conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
-        if useHTTPS:
-            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
-            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
-            if caFile:
-                conn.setopt(pycurl.CAINFO, caFile)
-
-        conn.setopt(pycurl.HTTPHEADER, customHeaders)
-        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
-
-        if response:
-            cls._toResponderQueue.put(response, True, timeout)
-
-        receivedQuery = None
-        message = None
-        cls._response_headers = ''
-        data = conn.perform_rb()
-        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
-        if cls._rcode == 200 and not rawResponse:
-            message = dns.message.from_wire(data)
-        elif rawResponse:
-            message = data
-
-        if useQueue and not cls._fromResponderQueue.empty():
-            receivedQuery = cls._fromResponderQueue.get(True, timeout)
-
-        cls._response_headers = response_headers.getvalue()
-        return (receivedQuery, message)
-
-    @classmethod
-    def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
-        url = baseurl
-        conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
-        response_headers = BytesIO()
-        #conn.setopt(pycurl.VERBOSE, True)
-        conn.setopt(pycurl.URL, url)
-        conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
-        if useHTTPS:
-            conn.setopt(pycurl.SSL_VERIFYPEER, 1)
-            conn.setopt(pycurl.SSL_VERIFYHOST, 2)
-            if caFile:
-                conn.setopt(pycurl.CAINFO, caFile)
-
-        conn.setopt(pycurl.HTTPHEADER, customHeaders)
-        conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
-        conn.setopt(pycurl.POST, True)
-        data = query
-        if not rawQuery:
-            data = data.to_wire()
-
-        conn.setopt(pycurl.POSTFIELDS, data)
-
-        if response:
-            cls._toResponderQueue.put(response, True, timeout)
-
-        receivedQuery = None
-        message = None
-        cls._response_headers = ''
-        data = conn.perform_rb()
-        cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
-        if cls._rcode == 200 and not rawResponse:
-            message = dns.message.from_wire(data)
-        elif rawResponse:
-            message = data
-
-        if useQueue and not cls._fromResponderQueue.empty():
-            receivedQuery = cls._fromResponderQueue.get(True, timeout)
-
-        cls._response_headers = response_headers.getvalue()
-        return (receivedQuery, message)
-
-    def getHeaderValue(self, name):
-        for header in self._response_headers.decode().splitlines(False):
-            values = header.split(':')
-            key = values[0]
-            if key.lower() == name.lower():
-                return values[1].strip()
-        return None
-
-    def checkHasHeader(self, name, value):
-        got = self.getHeaderValue(name)
-        self.assertEqual(got, value)
-
-    def checkNoHeader(self, name):
-        self.checkHasHeader(name, None)
-
-    @classmethod
-    def setUpClass(cls):
-
-        # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
-        if 'SKIP_DOH_TESTS' in os.environ:
-            raise unittest.SkipTest('DNS over HTTPS tests are disabled')
-
-        cls.startResponders()
-        cls.startDNSDist()
-        cls.setUpSockets()
-
-        print("Launching tests..")
-
 class TestDOH(DNSDistDOHTest):
 
     _serverKey = 'server.key'
index bf073f74ed639de2465a6c678705baaa251e8148..7a507819493aca4216951548fac9c8abac316801 100644 (file)
@@ -9,6 +9,7 @@ import threading
 
 from dnsdisttests import DNSDistTest
 from proxyprotocol import ProxyProtocol
+from dnsdistdohtests import DNSDistDOHTest
 
 # Python2/3 compatibility hacks
 try:
@@ -720,3 +721,72 @@ class TestProxyProtocolNotExpected(DNSDistTest):
         except socket.timeout:
           print('timeout')
         self.assertEqual(receivedResponse, None)
+
+class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
+
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _dohServerPort = 8443
+    _dohBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohServerPort))
+    _proxyResponderPort = proxyResponderPort
+    _config_template = """
+    newServer{address="127.0.0.1:%s", useProxyProtocol=true}
+
+    addDOHLocal("127.0.0.1:%s", "%s", "%s")
+    """
+    _config_params = ['_proxyResponderPort', '_dohServerPort', '_serverCert', '_serverKey']
+
+    def testTruncation(self):
+        """
+        DOH: Truncation over UDP (with cache)
+        """
+        # the query is first forwarded over UDP, leading to a TC=1 answer from the
+        # backend, then over TCP
+        name = 'truncated-udp.doh-with-cache.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        query.id = 42
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.id = 42
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        # first response is a TC=1
+        tcResponse = dns.message.make_response(query)
+        tcResponse.flags |= dns.flags.TC
+        toProxyQueue.put(tcResponse, True, 2.0)
+
+        ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
+        # first query, received by the responder over UDP
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        self.assertTrue(receivedQuery)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
+
+        # check the response
+        self.assertTrue(receivedResponse)
+        self.assertEqual(response, receivedResponse)
+
+        # check the second query, received by the responder over TCP
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedDNSData)
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        self.assertTrue(receivedQuery)
+        receivedQuery.id = expectedQuery.id
+        self.assertEqual(expectedQuery, receivedQuery)
+        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
+
+        # make sure we consumed everything
+        self.assertTrue(toProxyQueue.empty())
+        self.assertTrue(fromProxyQueue.empty())