import dns
import selectors
import socket
+import ssl
import struct
import sys
import threading
backgroundThreads = {}
-def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort):
+def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort, serverCtx=None, ca=None, sni=None):
# this responder accepts TCP connections on the listening port,
# and relay the raw content to a second TCP connection to the
# forwarding port, after adding a Proxy Protocol v2 payload
# containing the initial source IP and port, destination IP
# and port.
backgroundThreads[threading.get_native_id()] = True
+
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+ if serverCtx is not None:
+ sock = serverCtx.wrap_socket(sock, server_side=True)
+
try:
sock.bind(("127.0.0.1", listeningPort))
except socket.error as e:
sys.exit(1)
sock.settimeout(0.5)
sock.listen(100)
+
while True:
try:
(incoming, _) = sock.accept()
outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
outgoing.settimeout(2.0)
+ if sni:
+ if hasattr(ssl, 'create_default_context'):
+ sslctx = ssl.create_default_context(cafile=ca)
+ if hasattr(sslctx, 'set_alpn_protocols'):
+ sslctx.set_alpn_protocols(['h2'])
+ outgoing = sslctx.wrap_socket(outgoing, server_hostname=sni)
+ else:
+ outgoing = ssl.wrap_socket(outgoing, ca_certs=ca, cert_reqs=ssl.CERT_REQUIRED)
+
outgoing.connect(('127.0.0.1', forwardingPort))
outgoing.send(payload)
"""
_config_template = """
- addDOHLocal("127.0.0.1:%s", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
setProxyProtocolACL( { "127.0.0.1/32" } )
newServer{address="127.0.0.1:%d", useProxyProtocol=true}
_serverCert = 'server.chain'
_serverName = 'tls.tests.dnsdist.org'
_caCert = 'ca.pem'
- _dohServerPort = 8443
- _dohBaseURL = ("https://%s:%d/" % (_serverName, _dohServerPort))
- _config_params = ['_dohServerPort', '_serverCert', '_serverKey', '_proxyResponderPort']
+ _dohServerPPOutsidePort = 8443
+ _dohServerPPInsidePort = 9443
+ _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
def testNoHeader(self):
"""
self.assertEqual(receivedResponse, response)
self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
- def testProxyDoHSeveralQueriesOverConnection(self):
+ def testProxyDoHSeveralQueriesOverConnectionPPOutside(self):
"""
- Incoming Proxy Protocol: Several queries over the same connection (DoH)
+ Incoming Proxy Protocol: Several queries over the same connection (DoH, PP outside TLS)
"""
- name = 'several-queries.proxy-protocol-incoming.tests.powerdns.com.'
+ name = 'several-queries.doh-outside.proxy-protocol-incoming.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN')
response = dns.message.make_response(query)
wire = query.to_wire()
reverseProxyPort = 13053
- reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPort])
+ reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPOutsidePort])
+ reverseProxy.start()
+
+ receivedResponse = None
+ conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+ reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+ for idx in range(5):
+ receivedResponse = None
+ toProxyQueue.put(response, True, 2.0)
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+ def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
+ """
+ Incoming Proxy Protocol: Several queries over the same connection (DoH, PP inside TLS)
+ """
+ name = 'several-queries.doh-inside.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ wire = query.to_wire()
+
+ reverseProxyPort = 14053
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain(self._serverCert, self._serverKey)
+ tlsContext.set_alpn_protocols(['h2'])
+ reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPInsidePort, tlsContext, self._caCert, self._serverName])
reverseProxy.start()
receivedResponse = None
+ time.sleep(1)
conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
receivedQuery.id = query.id
receivedResponse.id = response.id
self.assertEqual(receivedQuery, query)
- print(receivedResponse)
- print(response)
self.assertEqual(receivedResponse, response)
self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)