From: Remi Gacogne Date: Thu, 13 Jul 2023 12:45:38 +0000 (+0200) Subject: dnsdist: Add a test for DoH incoming proxy protocol outside of TLS X-Git-Tag: rec-5.0.0-alpha1~19^2~17 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=41f3676508f033cc0b0a5fb456cce0f917c7149a;p=thirdparty%2Fpdns.git dnsdist: Add a test for DoH incoming proxy protocol outside of TLS --- diff --git a/regression-tests.dnsdist/dnsdisttests.py b/regression-tests.dnsdist/dnsdisttests.py index 75068bea02..e2787d577f 100644 --- a/regression-tests.dnsdist/dnsdisttests.py +++ b/regression-tests.dnsdist/dnsdisttests.py @@ -987,7 +987,7 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): return conn @classmethod - def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, useProxyProtocol=False, conn=None): + def sendDOHQuery(cls, port, servername, baseurl, query, response=None, timeout=2.0, caFile=None, useQueue=True, rawQuery=False, rawResponse=False, customHeaders=[], useHTTPS=True, fromQueue=None, toQueue=None, conn=None): url = cls.getDOHGetURL(baseurl, query, rawQuery) if not conn: @@ -1003,11 +1003,6 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): if caFile: conn.setopt(pycurl.CAINFO, caFile) - if useProxyProtocol: - print('enabling PP') - # 274 is CURLOPT_HAPROXYPROTOCOL - conn.setopt(274, 1) - response_headers = BytesIO() #conn.setopt(pycurl.VERBOSE, True) conn.setopt(pycurl.URL, url) @@ -1088,8 +1083,8 @@ class DNSDistTest(AssertEqualDNSMessageMixin, unittest.TestCase): cls._response_headers = response_headers.getvalue() return (receivedQuery, message) - def sendDOHQueryWrapper(self, query, response, useQueue=True, useProxyProtocol=False): - return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue, useProxyProtocol=useProxyProtocol) + def sendDOHQueryWrapper(self, query, response, useQueue=True): + return self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) def sendDOHWithNGHTTP2QueryWrapper(self, query, response, useQueue=True): return self.sendDOHQuery(self._dohWithNGHTTP2ServerPort, self._serverName, self._dohWithNGHTTP2BaseURL, query, response=response, caFile=self._caCert, useQueue=useQueue) diff --git a/regression-tests.dnsdist/test_ProxyProtocol.py b/regression-tests.dnsdist/test_ProxyProtocol.py index 2ff4a2d741..744d6a7c79 100644 --- a/regression-tests.dnsdist/test_ProxyProtocol.py +++ b/regression-tests.dnsdist/test_ProxyProtocol.py @@ -2,6 +2,7 @@ import copy import dns +import selectors import socket import struct import sys @@ -141,6 +142,80 @@ tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=Prox tcpResponder.daemon = True tcpResponder.start() +backgroundThreads = {} + +def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort): + # this responder accepts TCP connections on the listening port, + # and relay the raw content to a second TCP connection to the + # forwarding port, after adding a Proxy Protocol v2 payload + # containing the initial source IP and port, destination IP + # and port. + backgroundThreads[threading.get_native_id()] = True + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + try: + sock.bind(("127.0.0.1", listeningPort)) + except socket.error as e: + print("Error binding in the Mock TCP reverse proxy: %s" % str(e)) + sys.exit(1) + sock.settimeout(0.5) + sock.listen(100) + while True: + try: + (incoming, _) = sock.accept() + except socket.timeout: + if backgroundThreads.get(threading.get_native_id(), False) == False: + del backgroundThreads[threading.get_native_id()] + break + else: + continue + + incoming.settimeout(5.0) + payload = ProxyProtocol.getPayload(False, True, False, '127.0.0.1', '127.0.0.1', incoming.getpeername()[1], listeningPort, [ [ 2, b'foo'], [ 3, b'proxy'] ]) + + outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + outgoing.settimeout(2.0) + outgoing.connect(('127.0.0.1', forwardingPort)) + + outgoing.send(payload) + + sel = selectors.DefaultSelector() + def readFromClient(conn): + data = conn.recv(512) + if not data or len(data) == 0: + return False + outgoing.send(data) + return True + + def readFromBackend(conn): + data = conn.recv(512) + if not data or len(data) == 0: + return False + incoming.send(data) + return True + + sel.register(incoming, selectors.EVENT_READ, readFromClient) + sel.register(outgoing, selectors.EVENT_READ, readFromBackend) + done = False + while not done: + try: + events = sel.select() + for key, mask in events: + if not (key.data)(key.fileobj): + done = True + break + except socket.timeout: + break + except: + break + + incoming.close() + outgoing.close() + + sock.close() + class ProxyProtocolTest(DNSDistTest): _proxyResponderPort = proxyResponderPort _config_params = ['_proxyResponderPort'] @@ -398,6 +473,7 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): """ _config_template = """ + addDOHLocal("127.0.0.1:%s", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true}) setProxyProtocolACL( { "127.0.0.1/32" } ) newServer{address="127.0.0.1:%d", useProxyProtocol=true} @@ -434,8 +510,13 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): -- override all existing values addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"})) """ - _config_params = ['_proxyResponderPort'] - _verboseMode = True + _serverKey = 'server.key' + _serverCert = 'server.chain' + _serverName = 'tls.tests.dnsdist.org' + _caCert = 'ca.pem' + _dohServerPort = 8443 + _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort)) + _config_params = ['_dohServerPort', '_serverCert', '_serverKey', '_proxyResponderPort'] def testNoHeader(self): """ @@ -445,9 +526,12 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.' query = dns.message.make_query(name, 'A', 'IN') - for method in ("sendUDPQuery", "sendTCPQuery"): + for method in ("sendUDPQuery", "sendTCPQuery", "sendDOHQueryWrapper"): sender = getattr(self, method) - (_, receivedResponse) = sender(query, response=None) + try: + (_, receivedResponse) = sender(query, response=None) + except Exception: + receivedResponse = None self.assertEqual(receivedResponse, None) def testIncomingProxyDest(self): @@ -656,6 +740,66 @@ class TestProxyProtocolIncoming(ProxyProtocolTest): self.assertEqual(receivedResponse, response) self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort) + def testProxyDoHSeveralQueriesOverConnection(self): + """ + Incoming Proxy Protocol: Several queries over the same connection (DoH) + """ + name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + + toProxyQueue.put(response, True, 2.0) + + wire = query.to_wire() + + reverseProxyPort = 13053 + reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPort]) + reverseProxy.start() + + receivedResponse = None + conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0) + + reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort)) + (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn) + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEqual(receivedQuery, query) + self.assertEqual(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort) + + for idx in range(5): + receivedResponse = None + toProxyQueue.put(response, True, 2.0) + (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn) + (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0) + self.assertTrue(receivedProxyPayload) + self.assertTrue(receivedDNSData) + self.assertTrue(receivedResponse) + + receivedQuery = dns.message.from_wire(receivedDNSData) + receivedQuery.id = query.id + receivedResponse.id = response.id + self.assertEqual(receivedQuery, query) + print(receivedResponse) + print(response) + self.assertEqual(receivedResponse, response) + self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort) + + @classmethod + def tearDownClass(cls): + cls._sock.close() + for backgroundThread in cls._backgroundThreads: + cls._backgroundThreads[backgroundThread] = False + for backgroundThread in backgroundThreads: + backgroundThreads[backgroundThread] = False + cls.killProcess(cls._dnsdist) + class TestProxyProtocolNotExpected(DNSDistTest): """ dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1