From: Remi Gacogne Date: Wed, 4 May 2022 16:38:22 +0000 (+0200) Subject: dnsdist: Fix invalid proxy protocol payload on a DoH TC to TCP retry X-Git-Tag: auth-4.8.0-alpha0~103^2 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=1c9c001cbe327023e5d490e5bc044d67ecae9cf2;p=thirdparty%2Fpdns.git dnsdist: Fix invalid proxy protocol payload on a DoH TC to TCP retry dnsdist forwards incoming DoH queries to its backend over UDP, and retry over TCP if the response is truncated (TC=1). When the proxy protocol is used between dnsdist and its backend, the second query, over TCP, needs to take into account that the proxy protocol payload has already been handled. This was not properly done in that exact case because the proxy protocol payload length was not propagated to the code handling the TCP communication, leading to the query ID being edited at the wrong offset in the packet and thus to an invalid proxy protocol payload. --- diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index e0de9f35e6..a5b02e19a6 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -1204,6 +1204,7 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string()); prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); + query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; downstream->queueQuery(tqs, std::move(query)); } catch (...) { diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index 40c3905a12..1dbf7cc96c 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -169,19 +169,20 @@ static void prepareQueryForSending(TCPQuery& query, uint16_t id, QueryState quer if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) { query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end()); query.d_proxyProtocolPayloadAdded = true; + query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size(); } } else if (connectionState == ConnectionState::proxySent) { if (query.d_proxyProtocolPayloadAdded) { - if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) { + if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) { throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size())); } - query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size()); + query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize); query.d_proxyProtocolPayloadAdded = false; + query.d_proxyProtocolPayloadAddedSize = 0; } } - - editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true); + editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true); } IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr& conn) diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index 9154f2f650..1e896b473e 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -83,22 +83,8 @@ struct InternalQuery { } - InternalQuery(InternalQuery&& rhs) : - d_idstate(std::move(rhs.d_idstate)), d_proxyProtocolPayload(std::move(rhs.d_proxyProtocolPayload)), d_buffer(std::move(rhs.d_buffer)), d_xfrMasterSerial(rhs.d_xfrMasterSerial), d_xfrSerialCount(rhs.d_xfrSerialCount), d_downstreamFailures(rhs.d_downstreamFailures), d_xfrMasterSerialCount(rhs.d_xfrMasterSerialCount), d_proxyProtocolPayloadAdded(rhs.d_proxyProtocolPayloadAdded) - { - } - InternalQuery& operator=(InternalQuery&& rhs) - { - d_idstate = std::move(rhs.d_idstate); - d_buffer = std::move(rhs.d_buffer); - d_proxyProtocolPayload = std::move(rhs.d_proxyProtocolPayload); - d_xfrMasterSerial = rhs.d_xfrMasterSerial; - d_xfrSerialCount = rhs.d_xfrSerialCount; - d_downstreamFailures = rhs.d_downstreamFailures; - d_xfrMasterSerialCount = rhs.d_xfrMasterSerialCount; - d_proxyProtocolPayloadAdded = rhs.d_proxyProtocolPayloadAdded; - return *this; - } + InternalQuery(InternalQuery&& rhs) = default; + InternalQuery& operator=(InternalQuery&& rhs) = default; InternalQuery(const InternalQuery& rhs) = delete; InternalQuery& operator=(const InternalQuery& rhs) = delete; @@ -111,6 +97,7 @@ struct InternalQuery IDState d_idstate; std::string d_proxyProtocolPayload; PacketBuffer d_buffer; + uint32_t d_proxyProtocolPayloadAddedSize{0}; uint32_t d_xfrMasterSerial{0}; uint32_t d_xfrSerialCount{0}; uint32_t d_downstreamFailures{0}; diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index da1e8ef838..d74d52665f 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -699,7 +699,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& du) if (du->downstream->d_config.useProxyProtocol) { size_t payloadSize = 0; - if (addProxyProtocol(dq)) { + if (addProxyProtocol(dq, &payloadSize)) { du->proxyProtocolPayloadSize = payloadSize; } } diff --git a/regression-tests.dnsdist/dnsdistdohtests.py b/regression-tests.dnsdist/dnsdistdohtests.py new file mode 100644 index 0000000000..cd19d095a2 --- /dev/null +++ b/regression-tests.dnsdist/dnsdistdohtests.py @@ -0,0 +1,145 @@ +#!/usr/bin/env python +import base64 +import dns +import os +import unittest + +from dnsdisttests import DNSDistTest + +import pycurl +from io import BytesIO + +@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled') +class DNSDistDOHTest(DNSDistTest): + + @classmethod + def getDOHGetURL(cls, baseurl, query, rawQuery=False): + if rawQuery: + wire = query + else: + wire = query.to_wire() + param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') + return baseurl + "?dns=" + param + + @classmethod + def openDOHConnection(cls, port, caFile, timeout=2.0): + conn = pycurl.Curl() + conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2) + + conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message", + "Accept: application/dns-message"]) + return conn + + @classmethod + def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None): + url = cls.getDOHGetURL(baseurl, query, rawQuery) + conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) + response_headers = BytesIO() + #conn.setopt(pycurl.VERBOSE, True) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + if useHTTPS: + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + if caFile: + conn.setopt(pycurl.CAINFO, caFile) + + conn.setopt(pycurl.HTTPHEADER, customHeaders) + conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) + + if response: + if toQueue: + toQueue.put(response, True, timeout) + else: + cls._toResponderQueue.put(response, True, timeout) + + receivedQuery = None + message = None + cls._response_headers = '' + data = conn.perform_rb() + cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) + if cls._rcode == 200 and not rawResponse: + message = dns.message.from_wire(data) + elif rawResponse: + message = data + + if useQueue: + if fromQueue: + if not fromQueue.empty(): + receivedQuery = fromQueue.get(True, timeout) + else: + if not cls._fromResponderQueue.empty(): + receivedQuery = cls._fromResponderQueue.get(True, timeout) + + cls._response_headers = response_headers.getvalue() + return (receivedQuery, message) + + @classmethod + def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True): + url = baseurl + conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) + response_headers = BytesIO() + #conn.setopt(pycurl.VERBOSE, True) + conn.setopt(pycurl.URL, url) + conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) + if useHTTPS: + conn.setopt(pycurl.SSL_VERIFYPEER, 1) + conn.setopt(pycurl.SSL_VERIFYHOST, 2) + if caFile: + conn.setopt(pycurl.CAINFO, caFile) + + conn.setopt(pycurl.HTTPHEADER, customHeaders) + conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) + conn.setopt(pycurl.POST, True) + data = query + if not rawQuery: + data = data.to_wire() + + conn.setopt(pycurl.POSTFIELDS, data) + + if response: + cls._toResponderQueue.put(response, True, timeout) + + receivedQuery = None + message = None + cls._response_headers = '' + data = conn.perform_rb() + cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) + if cls._rcode == 200 and not rawResponse: + message = dns.message.from_wire(data) + elif rawResponse: + message = data + + if useQueue and not cls._fromResponderQueue.empty(): + receivedQuery = cls._fromResponderQueue.get(True, timeout) + + cls._response_headers = response_headers.getvalue() + return (receivedQuery, message) + + def getHeaderValue(self, name): + for header in self._response_headers.decode().splitlines(False): + values = header.split(':') + key = values[0] + if key.lower() == name.lower(): + return values[1].strip() + return None + + def checkHasHeader(self, name, value): + got = self.getHeaderValue(name) + self.assertEqual(got, value) + + def checkNoHeader(self, name): + self.checkHasHeader(name, None) + + @classmethod + def setUpClass(cls): + + # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python + if 'SKIP_DOH_TESTS' in os.environ: + raise unittest.SkipTest('DNS over HTTPS tests are disabled') + + cls.startResponders() + cls.startDNSDist() + cls.setUpSockets() + + print("Launching tests..") diff --git a/regression-tests.dnsdist/test_DOH.py b/regression-tests.dnsdist/test_DOH.py index ee99e21d56..4f7d1ec600 100644 --- a/regression-tests.dnsdist/test_DOH.py +++ b/regression-tests.dnsdist/test_DOH.py @@ -2,142 +2,15 @@ import base64 import dns import os -import re import time import unittest import clientsubnetoption -from dnsdisttests import DNSDistTest + +from dnsdistdohtests import DNSDistDOHTest import pycurl from io import BytesIO -@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled') -class DNSDistDOHTest(DNSDistTest): - - @classmethod - def getDOHGetURL(cls, baseurl, query, rawQuery=False): - if rawQuery: - wire = query - else: - wire = query.to_wire() - param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=') - return baseurl + "?dns=" + param - - @classmethod - def openDOHConnection(cls, port, caFile, timeout=2.0): - conn = pycurl.Curl() - conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2) - - conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message", - "Accept: application/dns-message"]) - return conn - - @classmethod - def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True): - url = cls.getDOHGetURL(baseurl, query, rawQuery) - conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) - response_headers = BytesIO() - #conn.setopt(pycurl.VERBOSE, True) - conn.setopt(pycurl.URL, url) - conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) - if useHTTPS: - conn.setopt(pycurl.SSL_VERIFYPEER, 1) - conn.setopt(pycurl.SSL_VERIFYHOST, 2) - if caFile: - conn.setopt(pycurl.CAINFO, caFile) - - conn.setopt(pycurl.HTTPHEADER, customHeaders) - conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) - - if response: - cls._toResponderQueue.put(response, True, timeout) - - receivedQuery = None - message = None - cls._response_headers = '' - data = conn.perform_rb() - cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) - if cls._rcode == 200 and not rawResponse: - message = dns.message.from_wire(data) - elif rawResponse: - message = data - - if useQueue and not cls._fromResponderQueue.empty(): - receivedQuery = cls._fromResponderQueue.get(True, timeout) - - cls._response_headers = response_headers.getvalue() - return (receivedQuery, message) - - @classmethod - def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True): - url = baseurl - conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout) - response_headers = BytesIO() - #conn.setopt(pycurl.VERBOSE, True) - conn.setopt(pycurl.URL, url) - conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)]) - if useHTTPS: - conn.setopt(pycurl.SSL_VERIFYPEER, 1) - conn.setopt(pycurl.SSL_VERIFYHOST, 2) - if caFile: - conn.setopt(pycurl.CAINFO, caFile) - - conn.setopt(pycurl.HTTPHEADER, customHeaders) - conn.setopt(pycurl.HEADERFUNCTION, response_headers.write) - conn.setopt(pycurl.POST, True) - data = query - if not rawQuery: - data = data.to_wire() - - conn.setopt(pycurl.POSTFIELDS, data) - - if response: - cls._toResponderQueue.put(response, True, timeout) - - receivedQuery = None - message = None - cls._response_headers = '' - data = conn.perform_rb() - cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE) - if cls._rcode == 200 and not rawResponse: - message = dns.message.from_wire(data) - elif rawResponse: - message = data - - if useQueue and not cls._fromResponderQueue.empty(): - receivedQuery = cls._fromResponderQueue.get(True, timeout) - - cls._response_headers = response_headers.getvalue() - return (receivedQuery, message) - - def getHeaderValue(self, name): - for header in self._response_headers.decode().splitlines(False): - values = header.split(':') - key = values[0] - if key.lower() == name.lower(): - return values[1].strip() - return None - - def checkHasHeader(self, name, value): - got = self.getHeaderValue(name) - self.assertEqual(got, value) - - def checkNoHeader(self, name): - self.checkHasHeader(name, None) - - @classmethod - def setUpClass(cls): - - # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python - if 'SKIP_DOH_TESTS' in os.environ: - raise unittest.SkipTest('DNS over HTTPS tests are disabled') - - cls.startResponders() - cls.startDNSDist() - cls.setUpSockets() - - print("Launching tests..") - class TestDOH(DNSDistDOHTest): _serverKey = 'server.key' diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index bf073f74ed..7a50781949 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -9,6 +9,7 @@ import threading from dnsdisttests import DNSDistTest from proxyprotocol import ProxyProtocol +from dnsdistdohtests import DNSDistDOHTest # Python2/3 compatibility hacks try: @@ -720,3 +721,72 @@ class TestProxyProtocolNotExpected(DNSDistTest): except socket.timeout: print('timeout') self.assertEqual(receivedResponse, None) + +class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest): + + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _dohServerPort = 8443 + _dohBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohServerPort)) + _proxyResponderPort = proxyResponderPort + _config_template = """ + newServer{address="127.0.0.1:%s", useProxyProtocol=true} + + addDOHLocal("127.0.0.1:%s", "%s", "%s") + """ + _config_params = ['_proxyResponderPort', '_dohServerPort', '_serverCert', '_serverKey'] + + def testTruncation(self): + """ + DOH: Truncation over UDP (with cache) + """ + # the query is first forwarded over UDP, leading to a TC=1 answer from the + # backend, then over TCP + name = 'truncated-udp.doh-with-cache.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + query.id = 42 + expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096) + expectedQuery.id = 42 + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '127.0.0.1') + response.answer.append(rrset) + + # first response is a TC=1 + tcResponse = dns.message.make_response(query) + tcResponse.flags |= dns.flags.TC + toProxyQueue.put(tcResponse, True, 2.0) + + ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue) + # first query, received by the responder over UDP + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + receivedQuery = dns.message.from_wire(receivedDNSData) + self.assertTrue(receivedQuery) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort) + + # check the response + self.assertTrue(receivedResponse) + self.assertEqual(response, receivedResponse) + + # check the second query, received by the responder over TCP + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedDNSData) + receivedQuery = dns.message.from_wire(receivedDNSData) + self.assertTrue(receivedQuery) + receivedQuery.id = expectedQuery.id + self.assertEqual(expectedQuery, receivedQuery) + self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort) + + # make sure we consumed everything + self.assertTrue(toProxyQueue.empty()) + self.assertTrue(fromProxyQueue.empty())