auto downstream = t_downstreamTCPConnectionsManager.getConnectionToDownstream(threadData->mplexer, downstreamServer, now, std::string());
prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize);
+ query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize;
downstream->queueQuery(tqs, std::move(query));
}
catch (...) {
if (query.d_proxyProtocolPayload.size() > 0 && !query.d_proxyProtocolPayloadAdded) {
query.d_buffer.insert(query.d_buffer.begin(), query.d_proxyProtocolPayload.begin(), query.d_proxyProtocolPayload.end());
query.d_proxyProtocolPayloadAdded = true;
+ query.d_proxyProtocolPayloadAddedSize = query.d_proxyProtocolPayload.size();
}
}
else if (connectionState == ConnectionState::proxySent) {
if (query.d_proxyProtocolPayloadAdded) {
- if (query.d_buffer.size() < query.d_proxyProtocolPayload.size()) {
+ if (query.d_buffer.size() < query.d_proxyProtocolPayloadAddedSize) {
throw std::runtime_error("Trying to remove a proxy protocol payload of size " + std::to_string(query.d_proxyProtocolPayload.size()) + " from a buffer of size " + std::to_string(query.d_buffer.size()));
}
- query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayload.size());
+ query.d_buffer.erase(query.d_buffer.begin(), query.d_buffer.begin() + query.d_proxyProtocolPayloadAddedSize);
query.d_proxyProtocolPayloadAdded = false;
+ query.d_proxyProtocolPayloadAddedSize = 0;
}
}
-
- editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayload.size() : 0, true);
+ editPayloadID(query.d_buffer, id, query.d_proxyProtocolPayloadAdded ? query.d_proxyProtocolPayloadAddedSize : 0, true);
}
IOState TCPConnectionToBackend::queueNextQuery(std::shared_ptr<TCPConnectionToBackend>& conn)
{
}
- InternalQuery(InternalQuery&& rhs) :
- d_idstate(std::move(rhs.d_idstate)), d_proxyProtocolPayload(std::move(rhs.d_proxyProtocolPayload)), d_buffer(std::move(rhs.d_buffer)), d_xfrMasterSerial(rhs.d_xfrMasterSerial), d_xfrSerialCount(rhs.d_xfrSerialCount), d_downstreamFailures(rhs.d_downstreamFailures), d_xfrMasterSerialCount(rhs.d_xfrMasterSerialCount), d_proxyProtocolPayloadAdded(rhs.d_proxyProtocolPayloadAdded)
- {
- }
- InternalQuery& operator=(InternalQuery&& rhs)
- {
- d_idstate = std::move(rhs.d_idstate);
- d_buffer = std::move(rhs.d_buffer);
- d_proxyProtocolPayload = std::move(rhs.d_proxyProtocolPayload);
- d_xfrMasterSerial = rhs.d_xfrMasterSerial;
- d_xfrSerialCount = rhs.d_xfrSerialCount;
- d_downstreamFailures = rhs.d_downstreamFailures;
- d_xfrMasterSerialCount = rhs.d_xfrMasterSerialCount;
- d_proxyProtocolPayloadAdded = rhs.d_proxyProtocolPayloadAdded;
- return *this;
- }
+ InternalQuery(InternalQuery&& rhs) = default;
+ InternalQuery& operator=(InternalQuery&& rhs) = default;
InternalQuery(const InternalQuery& rhs) = delete;
InternalQuery& operator=(const InternalQuery& rhs) = delete;
IDState d_idstate;
std::string d_proxyProtocolPayload;
PacketBuffer d_buffer;
+ uint32_t d_proxyProtocolPayloadAddedSize{0};
uint32_t d_xfrMasterSerial{0};
uint32_t d_xfrSerialCount{0};
uint32_t d_downstreamFailures{0};
if (du->downstream->d_config.useProxyProtocol) {
size_t payloadSize = 0;
- if (addProxyProtocol(dq)) {
+ if (addProxyProtocol(dq, &payloadSize)) {
du->proxyProtocolPayloadSize = payloadSize;
}
}
--- /dev/null
+#!/usr/bin/env python
+import base64
+import dns
+import os
+import unittest
+
+from dnsdisttests import DNSDistTest
+
+import pycurl
+from io import BytesIO
+
+@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
+class DNSDistDOHTest(DNSDistTest):
+
+ @classmethod
+ def getDOHGetURL(cls, baseurl, query, rawQuery=False):
+ if rawQuery:
+ wire = query
+ else:
+ wire = query.to_wire()
+ param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
+ return baseurl + "?dns=" + param
+
+ @classmethod
+ def openDOHConnection(cls, port, caFile, timeout=2.0):
+ conn = pycurl.Curl()
+ conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
+
+ conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
+ "Accept: application/dns-message"])
+ return conn
+
+ @classmethod
+ def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None):
+ url = cls.getDOHGetURL(baseurl, query, rawQuery)
+ conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+ response_headers = BytesIO()
+ #conn.setopt(pycurl.VERBOSE, True)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+ if useHTTPS:
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ if caFile:
+ conn.setopt(pycurl.CAINFO, caFile)
+
+ conn.setopt(pycurl.HTTPHEADER, customHeaders)
+ conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
+
+ if response:
+ if toQueue:
+ toQueue.put(response, True, timeout)
+ else:
+ cls._toResponderQueue.put(response, True, timeout)
+
+ receivedQuery = None
+ message = None
+ cls._response_headers = ''
+ data = conn.perform_rb()
+ cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ if cls._rcode == 200 and not rawResponse:
+ message = dns.message.from_wire(data)
+ elif rawResponse:
+ message = data
+
+ if useQueue:
+ if fromQueue:
+ if not fromQueue.empty():
+ receivedQuery = fromQueue.get(True, timeout)
+ else:
+ if not cls._fromResponderQueue.empty():
+ receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+ cls._response_headers = response_headers.getvalue()
+ return (receivedQuery, message)
+
+ @classmethod
+ def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
+ url = baseurl
+ conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
+ response_headers = BytesIO()
+ #conn.setopt(pycurl.VERBOSE, True)
+ conn.setopt(pycurl.URL, url)
+ conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
+ if useHTTPS:
+ conn.setopt(pycurl.SSL_VERIFYPEER, 1)
+ conn.setopt(pycurl.SSL_VERIFYHOST, 2)
+ if caFile:
+ conn.setopt(pycurl.CAINFO, caFile)
+
+ conn.setopt(pycurl.HTTPHEADER, customHeaders)
+ conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
+ conn.setopt(pycurl.POST, True)
+ data = query
+ if not rawQuery:
+ data = data.to_wire()
+
+ conn.setopt(pycurl.POSTFIELDS, data)
+
+ if response:
+ cls._toResponderQueue.put(response, True, timeout)
+
+ receivedQuery = None
+ message = None
+ cls._response_headers = ''
+ data = conn.perform_rb()
+ cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
+ if cls._rcode == 200 and not rawResponse:
+ message = dns.message.from_wire(data)
+ elif rawResponse:
+ message = data
+
+ if useQueue and not cls._fromResponderQueue.empty():
+ receivedQuery = cls._fromResponderQueue.get(True, timeout)
+
+ cls._response_headers = response_headers.getvalue()
+ return (receivedQuery, message)
+
+ def getHeaderValue(self, name):
+ for header in self._response_headers.decode().splitlines(False):
+ values = header.split(':')
+ key = values[0]
+ if key.lower() == name.lower():
+ return values[1].strip()
+ return None
+
+ def checkHasHeader(self, name, value):
+ got = self.getHeaderValue(name)
+ self.assertEqual(got, value)
+
+ def checkNoHeader(self, name):
+ self.checkHasHeader(name, None)
+
+ @classmethod
+ def setUpClass(cls):
+
+ # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
+ if 'SKIP_DOH_TESTS' in os.environ:
+ raise unittest.SkipTest('DNS over HTTPS tests are disabled')
+
+ cls.startResponders()
+ cls.startDNSDist()
+ cls.setUpSockets()
+
+ print("Launching tests..")
import base64
import dns
import os
-import re
import time
import unittest
import clientsubnetoption
-from dnsdisttests import DNSDistTest
+
+from dnsdistdohtests import DNSDistDOHTest
import pycurl
from io import BytesIO
-@unittest.skipIf('SKIP_DOH_TESTS' in os.environ, 'DNS over HTTPS tests are disabled')
-class DNSDistDOHTest(DNSDistTest):
-
- @classmethod
- def getDOHGetURL(cls, baseurl, query, rawQuery=False):
- if rawQuery:
- wire = query
- else:
- wire = query.to_wire()
- param = base64.urlsafe_b64encode(wire).decode('UTF8').rstrip('=')
- return baseurl + "?dns=" + param
-
- @classmethod
- def openDOHConnection(cls, port, caFile, timeout=2.0):
- conn = pycurl.Curl()
- conn.setopt(pycurl.HTTP_VERSION, pycurl.CURL_HTTP_VERSION_2)
-
- conn.setopt(pycurl.HTTPHEADER, ["Content-type: application/dns-message",
- "Accept: application/dns-message"])
- return conn
-
- @classmethod
- def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
- url = cls.getDOHGetURL(baseurl, query, rawQuery)
- conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
- response_headers = BytesIO()
- #conn.setopt(pycurl.VERBOSE, True)
- conn.setopt(pycurl.URL, url)
- conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
- if useHTTPS:
- conn.setopt(pycurl.SSL_VERIFYPEER, 1)
- conn.setopt(pycurl.SSL_VERIFYHOST, 2)
- if caFile:
- conn.setopt(pycurl.CAINFO, caFile)
-
- conn.setopt(pycurl.HTTPHEADER, customHeaders)
- conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
-
- if response:
- cls._toResponderQueue.put(response, True, timeout)
-
- receivedQuery = None
- message = None
- cls._response_headers = ''
- data = conn.perform_rb()
- cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
- if cls._rcode == 200 and not rawResponse:
- message = dns.message.from_wire(data)
- elif rawResponse:
- message = data
-
- if useQueue and not cls._fromResponderQueue.empty():
- receivedQuery = cls._fromResponderQueue.get(True, timeout)
-
- cls._response_headers = response_headers.getvalue()
- return (receivedQuery, message)
-
- @classmethod
- def sendDOHPostQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True):
- url = baseurl
- conn = cls.openDOHConnection(port, caFile=caFile, timeout=timeout)
- response_headers = BytesIO()
- #conn.setopt(pycurl.VERBOSE, True)
- conn.setopt(pycurl.URL, url)
- conn.setopt(pycurl.RESOLVE, ["%s:%d:127.0.0.1" % (servername, port)])
- if useHTTPS:
- conn.setopt(pycurl.SSL_VERIFYPEER, 1)
- conn.setopt(pycurl.SSL_VERIFYHOST, 2)
- if caFile:
- conn.setopt(pycurl.CAINFO, caFile)
-
- conn.setopt(pycurl.HTTPHEADER, customHeaders)
- conn.setopt(pycurl.HEADERFUNCTION, response_headers.write)
- conn.setopt(pycurl.POST, True)
- data = query
- if not rawQuery:
- data = data.to_wire()
-
- conn.setopt(pycurl.POSTFIELDS, data)
-
- if response:
- cls._toResponderQueue.put(response, True, timeout)
-
- receivedQuery = None
- message = None
- cls._response_headers = ''
- data = conn.perform_rb()
- cls._rcode = conn.getinfo(pycurl.RESPONSE_CODE)
- if cls._rcode == 200 and not rawResponse:
- message = dns.message.from_wire(data)
- elif rawResponse:
- message = data
-
- if useQueue and not cls._fromResponderQueue.empty():
- receivedQuery = cls._fromResponderQueue.get(True, timeout)
-
- cls._response_headers = response_headers.getvalue()
- return (receivedQuery, message)
-
- def getHeaderValue(self, name):
- for header in self._response_headers.decode().splitlines(False):
- values = header.split(':')
- key = values[0]
- if key.lower() == name.lower():
- return values[1].strip()
- return None
-
- def checkHasHeader(self, name, value):
- got = self.getHeaderValue(name)
- self.assertEqual(got, value)
-
- def checkNoHeader(self, name):
- self.checkHasHeader(name, None)
-
- @classmethod
- def setUpClass(cls):
-
- # for some reason, @unittest.skipIf() is not applied to derived classes with some versions of Python
- if 'SKIP_DOH_TESTS' in os.environ:
- raise unittest.SkipTest('DNS over HTTPS tests are disabled')
-
- cls.startResponders()
- cls.startDNSDist()
- cls.setUpSockets()
-
- print("Launching tests..")
-
class TestDOH(DNSDistDOHTest):
_serverKey = 'server.key'
from dnsdisttests import DNSDistTest
from proxyprotocol import ProxyProtocol
+from dnsdistdohtests import DNSDistDOHTest
# Python2/3 compatibility hacks
try:
except socket.timeout:
print('timeout')
self.assertEqual(receivedResponse, None)
+
+class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
+
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _dohServerPort = 8443
+ _dohBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohServerPort))
+ _proxyResponderPort = proxyResponderPort
+ _config_template = """
+ newServer{address="127.0.0.1:%s", useProxyProtocol=true}
+
+ addDOHLocal("127.0.0.1:%s", "%s", "%s")
+ """
+ _config_params = ['_proxyResponderPort', '_dohServerPort', '_serverCert', '_serverKey']
+
+ def testTruncation(self):
+ """
+ DOH: Truncation over UDP (with cache)
+ """
+ # the query is first forwarded over UDP, leading to a TC=1 answer from the
+ # backend, then over TCP
+ name = 'truncated-udp.doh-with-cache.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ query.id = 42
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+ expectedQuery.id = 42
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ # first response is a TC=1
+ tcResponse = dns.message.make_response(query)
+ tcResponse.flags |= dns.flags.TC
+ toProxyQueue.put(tcResponse, True, 2.0)
+
+ ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
+ # first query, received by the responder over UDP
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
+
+ # check the response
+ self.assertTrue(receivedResponse)
+ self.assertEqual(response, receivedResponse)
+
+ # check the second query, received by the responder over TCP
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedDNSData)
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
+
+ # make sure we consumed everything
+ self.assertTrue(toProxyQueue.empty())
+ self.assertTrue(fromProxyQueue.empty())