#!/usr/bin/env python
-import copy
import dns
+import selectors
import socket
+import ssl
import struct
import sys
import threading
+import time
-from dnsdisttests import DNSDistTest
+from dnsdisttests import DNSDistTest, pickAvailablePort
from proxyprotocol import ProxyProtocol
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
from dnsdistdohtests import DNSDistDOHTest
# Python2/3 compatibility hacks
except ImportError:
from Queue import Queue
-def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
- sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
- try:
- sock.bind(("127.0.0.1", port))
- except socket.error as e:
- print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
- sys.exit(1)
+toProxyQueue = Queue()
+fromProxyQueue = Queue()
+proxyResponderPort = pickAvailablePort()
- while True:
- data, addr = sock.recvfrom(4096)
-
- proxy = ProxyProtocol()
- if len(data) < proxy.HEADER_SIZE:
- continue
-
- if not proxy.parseHeader(data):
- continue
-
- if proxy.local:
- # likely a healthcheck
- data = data[proxy.HEADER_SIZE:]
- request = dns.message.from_wire(data)
- response = dns.message.make_response(request)
- wire = response.to_wire()
- sock.settimeout(2.0)
- sock.sendto(wire, addr)
- sock.settimeout(None)
-
- continue
-
- payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
- dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
- toQueue.put([payload, dnsData], True, 2.0)
- # computing the correct ID for the response
- request = dns.message.from_wire(dnsData)
- response = fromQueue.get(True, 2.0)
- response.id = request.id
-
- sock.settimeout(2.0)
- sock.sendto(response.to_wire(), addr)
- sock.settimeout(None)
+udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+udpResponder.daemon = True
+udpResponder.start()
+tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+tcpResponder.daemon = True
+tcpResponder.start()
- sock.close()
+backgroundThreads = {}
+
+def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort, serverCtx=None, ca=None, sni=None):
+ # this responder accepts TCP connections on the listening port,
+ # and relay the raw content to a second TCP connection to the
+ # forwarding port, after adding a Proxy Protocol v2 payload
+ # containing the initial source IP and port, destination IP
+ # and port.
+ backgroundThreads[threading.get_native_id()] = True
-def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
- # be aware that this responder will not accept a new connection
- # until the last one has been closed. This is done on purpose to
- # to check for connection reuse, making sure that a lot of connections
- # are not opened in parallel.
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+ if serverCtx is not None:
+ sock = serverCtx.wrap_socket(sock, server_side=True)
+
try:
- sock.bind(("127.0.0.1", port))
+ sock.bind(("127.0.0.1", listeningPort))
except socket.error as e:
- print("Error binding in the TCP responder: %s" % str(e))
+ print("Error binding in the Mock TCP reverse proxy: %s" % str(e))
sys.exit(1)
-
+ sock.settimeout(0.5)
sock.listen(100)
+
while True:
- (conn, _) = sock.accept()
- conn.settimeout(5.0)
- # try to read the entire Proxy Protocol header
- proxy = ProxyProtocol()
- header = conn.recv(proxy.HEADER_SIZE)
- if not header:
- conn.close()
- continue
-
- if not proxy.parseHeader(header):
- conn.close()
- continue
-
- proxyContent = conn.recv(proxy.contentLen)
- if not proxyContent:
- conn.close()
- continue
-
- payload = header + proxyContent
- while True:
+ try:
+ (incoming, _) = sock.accept()
+ except socket.timeout:
+ if backgroundThreads.get(threading.get_native_id(), False) == False:
+ del backgroundThreads[threading.get_native_id()]
+ break
+ else:
+ continue
+
+ incoming.settimeout(5.0)
+ payload = ProxyProtocol.getPayload(False, True, False, '127.0.0.1', '127.0.0.1', incoming.getpeername()[1], listeningPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+ outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+ outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+ outgoing.settimeout(2.0)
+ if sni:
+ if hasattr(ssl, 'create_default_context'):
+ sslctx = ssl.create_default_context(cafile=ca)
+ if hasattr(sslctx, 'set_alpn_protocols'):
+ sslctx.set_alpn_protocols(['h2'])
+ outgoing = sslctx.wrap_socket(outgoing, server_hostname=sni)
+ else:
+ outgoing = ssl.wrap_socket(outgoing, ca_certs=ca, cert_reqs=ssl.CERT_REQUIRED)
+
+ outgoing.connect(('127.0.0.1', forwardingPort))
+
+ outgoing.send(payload)
+
+ sel = selectors.DefaultSelector()
+ def readFromClient(conn):
+ data = conn.recv(512)
+ if not data or len(data) == 0:
+ return False
+ outgoing.send(data)
+ return True
+
+ def readFromBackend(conn):
+ data = conn.recv(512)
+ if not data or len(data) == 0:
+ return False
+ incoming.send(data)
+ return True
+
+ sel.register(incoming, selectors.EVENT_READ, readFromClient)
+ sel.register(outgoing, selectors.EVENT_READ, readFromBackend)
+ done = False
+ while not done:
try:
- data = conn.recv(2)
+ events = sel.select()
+ for key, mask in events:
+ if not (key.data)(key.fileobj):
+ done = True
+ break
except socket.timeout:
- data = None
-
- if not data:
- conn.close()
break
-
- (datalen,) = struct.unpack("!H", data)
- data = conn.recv(datalen)
-
- toQueue.put([payload, data], True, 2.0)
-
- response = copy.deepcopy(fromQueue.get(True, 2.0))
- if not response:
- conn.close()
+ except:
break
- # computing the correct ID for the response
- request = dns.message.from_wire(data)
- response.id = request.id
-
- wire = response.to_wire()
- conn.send(struct.pack("!H", len(wire)))
- conn.send(wire)
-
- conn.close()
+ incoming.close()
+ outgoing.close()
sock.close()
-toProxyQueue = Queue()
-fromProxyQueue = Queue()
-proxyResponderPort = 5470
-
-udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
-udpResponder.setDaemon(True)
-udpResponder.start()
-tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
-tcpResponder.setDaemon(True)
-tcpResponder.start()
-
class ProxyProtocolTest(DNSDistTest):
_proxyResponderPort = proxyResponderPort
_config_params = ['_proxyResponderPort']
"""
_config_template = """
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
setProxyProtocolACL( { "127.0.0.1/32" } )
- newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+ newServer{address="127.0.0.1:%d", useProxyProtocol=true, proxyProtocolAdvertiseTLS=true}
function addValues(dq)
dq:addProxyProtocolValue(0, 'foo')
-- override all existing values
addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"}))
"""
- _config_params = ['_proxyResponderPort']
- _verboseMode = True
+ _serverKey = 'server.key'
+ _serverCert = 'server.chain'
+ _serverName = 'tls.tests.dnsdist.org'
+ _caCert = 'ca.pem'
+ _dohServerPPOutsidePort = pickAvailablePort()
+ _dohServerPPInsidePort = pickAvailablePort()
+ _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
def testNoHeader(self):
"""
name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN')
- for method in ("sendUDPQuery", "sendTCPQuery"):
+ for method in ("sendUDPQuery", "sendTCPQuery", "sendDOHQueryWrapper"):
sender = getattr(self, method)
- (_, receivedResponse) = sender(query, response=None)
+ try:
+ (_, receivedResponse) = sender(query, response=None)
+ except Exception:
+ receivedResponse = None
self.assertEqual(receivedResponse, None)
def testIncomingProxyDest(self):
self.assertEqual(receivedResponse, response)
self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
+ def testProxyDoHSeveralQueriesOverConnectionPPOutside(self):
+ """
+ Incoming Proxy Protocol: Several queries over the same connection (DoH, PP outside TLS)
+ """
+ name = 'several-queries.doh-outside.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ wire = query.to_wire()
+
+ reverseProxyPort = pickAvailablePort()
+ reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPOutsidePort])
+ reverseProxy.start()
+ time.sleep(1)
+
+ receivedResponse = None
+ conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+ reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+ for idx in range(5):
+ receivedResponse = None
+ toProxyQueue.put(response, True, 2.0)
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+ def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
+ """
+ Incoming Proxy Protocol: Several queries over the same connection (DoH, PP inside TLS)
+ """
+ name = 'several-queries.doh-inside.proxy-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+
+ toProxyQueue.put(response, True, 2.0)
+
+ wire = query.to_wire()
+
+ reverseProxyPort = pickAvailablePort()
+ tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+ tlsContext.load_cert_chain(self._serverCert, self._serverKey)
+ tlsContext.set_alpn_protocols(['h2'])
+ reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPInsidePort, tlsContext, self._caCert, self._serverName])
+ reverseProxy.start()
+
+ receivedResponse = None
+ time.sleep(1)
+ conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+ reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+ for idx in range(5):
+ receivedResponse = None
+ toProxyQueue.put(response, True, 2.0)
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ self.assertTrue(receivedResponse)
+
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ receivedQuery.id = query.id
+ receivedResponse.id = response.id
+ self.assertEqual(receivedQuery, query)
+ self.assertEqual(receivedResponse, response)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+ @classmethod
+ def tearDownClass(cls):
+ cls._sock.close()
+ for backgroundThread in cls._backgroundThreads:
+ cls._backgroundThreads[backgroundThread] = False
+ for backgroundThread in backgroundThreads:
+ backgroundThreads[backgroundThread] = False
+ cls.killProcess(cls._dnsdist)
+
class TestProxyProtocolNotExpected(DNSDistTest):
"""
dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1
print('timeout')
self.assertEqual(receivedResponse, None)
+class TestProxyProtocolNotAllowedOnBind(DNSDistTest):
+ """
+ dnsdist is configured to expect a Proxy Protocol header on incoming queries but not on the 127.0.0.1 bind
+ """
+ _skipListeningOnCL = True
+ _config_template = """
+ -- proxy protocol payloads are not allowed on this bind address!
+ addLocal('127.0.0.1:%d', {enableProxyProtocol=false})
+ setProxyProtocolACL( { "127.0.0.1/8" } )
+ newServer{address="127.0.0.1:%d"}
+ """
+ # NORMAL responder, does not expect a proxy protocol payload!
+ _config_params = ['_dnsDistPort', '_testServerPort']
+
+ def testNoHeader(self):
+ """
+ Unexpected Proxy Protocol: no header
+ """
+ # no proxy protocol header and none is expected from this source, should be passed on
+ name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 60,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+
+ response.answer.append(rrset)
+
+ for method in ("sendUDPQuery", "sendTCPQuery"):
+ sender = getattr(self, method)
+ (receivedQuery, receivedResponse) = sender(query, response)
+ receivedQuery.id = query.id
+ self.assertEqual(query, receivedQuery)
+ self.assertEqual(response, receivedResponse)
+
+ def testIncomingProxyDest(self):
+ """
+ Unexpected Proxy Protocol: should be dropped
+ """
+ name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN')
+
+ # Make sure that the proxy payload does NOT turn into a legal qname
+ destAddr = "ff:db8::ffff"
+ destPort = 65535
+ srcAddr = "ff:db8::ffff"
+ srcPort = 65535
+
+ udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+ self.assertEqual(receivedResponse, None)
+
+ tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+ wire = query.to_wire()
+
+ receivedResponse = None
+ try:
+ conn = self.openTCPConnection(2.0)
+ conn.send(tcpPayload)
+ conn.send(struct.pack("!H", len(wire)))
+ conn.send(wire)
+ receivedResponse = self.recvTCPResponseOverConnection(conn)
+ except socket.timeout:
+ print('timeout')
+ self.assertEqual(receivedResponse, None)
+
class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
_serverKey = 'server.key'
_serverCert = 'server.chain'
_serverName = 'tls.tests.dnsdist.org'
_caCert = 'ca.pem'
- _dohServerPort = 8443
- _dohBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohServerPort))
+ _dohWithNGHTTP2ServerPort = pickAvailablePort()
+ _dohWithNGHTTP2BaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohWithNGHTTP2ServerPort))
+ _dohWithH2OServerPort = pickAvailablePort()
+ _dohWithH2OBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohWithH2OServerPort))
_proxyResponderPort = proxyResponderPort
_config_template = """
newServer{address="127.0.0.1:%s", useProxyProtocol=true}
-
- addDOHLocal("127.0.0.1:%s", "%s", "%s")
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", { '/dns-query' }, { trustForwardedForHeader=true, library='nghttp2' })
+ addDOHLocal("127.0.0.1:%d", "%s", "%s", { '/dns-query' }, { trustForwardedForHeader=true, library='h2o' })
+ setACL( { "::1/128", "127.0.0.0/8" } )
"""
- _config_params = ['_proxyResponderPort', '_dohServerPort', '_serverCert', '_serverKey']
+ _config_params = ['_proxyResponderPort', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey']
+ _verboseMode = True
def testTruncation(self):
"""
- DOH: Truncation over UDP (with cache)
+ DOH: Truncation over UDP
"""
# the query is first forwarded over UDP, leading to a TC=1 answer from the
# backend, then over TCP
- name = 'truncated-udp.doh-with-cache.tests.powerdns.com.'
+ name = 'truncated-udp.doh.proxy-protocol.tests.powerdns.com.'
query = dns.message.make_query(name, 'A', 'IN')
query.id = 42
expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
'127.0.0.1')
response.answer.append(rrset)
- # first response is a TC=1
- tcResponse = dns.message.make_response(query)
- tcResponse.flags |= dns.flags.TC
- toProxyQueue.put(tcResponse, True, 2.0)
+ for (port,url) in [(self._dohWithNGHTTP2ServerPort, self._dohWithNGHTTP2BaseURL), (self._dohWithH2OServerPort, self._dohWithH2OBaseURL)]:
+ # first response is a TC=1
+ tcResponse = dns.message.make_response(query)
+ tcResponse.flags |= dns.flags.TC
+ toProxyQueue.put(tcResponse, True, 2.0)
- ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
- # first query, received by the responder over UDP
- self.assertTrue(receivedProxyPayload)
- self.assertTrue(receivedDNSData)
- receivedQuery = dns.message.from_wire(receivedDNSData)
- self.assertTrue(receivedQuery)
- receivedQuery.id = expectedQuery.id
- self.assertEqual(expectedQuery, receivedQuery)
- self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
- self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
+ ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
+ # first query, received by the responder over UDP
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=port)
- # check the response
- self.assertTrue(receivedResponse)
- self.assertEqual(response, receivedResponse)
+ # check the response
+ self.assertTrue(receivedResponse)
+ self.assertEqual(response, receivedResponse)
- # check the second query, received by the responder over TCP
- (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
- self.assertTrue(receivedDNSData)
- receivedQuery = dns.message.from_wire(receivedDNSData)
- self.assertTrue(receivedQuery)
- receivedQuery.id = expectedQuery.id
- self.assertEqual(expectedQuery, receivedQuery)
- self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
- self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
-
- # make sure we consumed everything
- self.assertTrue(toProxyQueue.empty())
- self.assertTrue(fromProxyQueue.empty())
+ # check the second query, received by the responder over TCP
+ (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+ self.assertTrue(receivedDNSData)
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=port)
+
+ # make sure we consumed everything
+ self.assertTrue(toProxyQueue.empty())
+ self.assertTrue(fromProxyQueue.empty())
+
+ def testAddressFamilyMismatch(self):
+ """
+ DOH with IPv6 X-Forwarded-For to an IPv4 endpoint
+ """
+ name = 'x-forwarded-for-af-mismatch.doh.outgoing-proxy-protocol.tests.powerdns.com.'
+ query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+ query.id = 0
+ expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+ expectedQuery.id = 0
+ response = dns.message.make_response(query)
+ rrset = dns.rrset.from_text(name,
+ 3600,
+ dns.rdataclass.IN,
+ dns.rdatatype.A,
+ '127.0.0.1')
+ response.answer.append(rrset)
+
+ for (port,url) in [(self._dohWithNGHTTP2ServerPort, self._dohWithNGHTTP2BaseURL), (self._dohWithH2OServerPort, self._dohWithH2OBaseURL)]:
+ # the query should be dropped
+ (receivedQuery, receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, customHeaders=['x-forwarded-for: [::1]:8080'], useQueue=False)
+ self.assertFalse(receivedQuery)
+ self.assertFalse(receivedResponse)
+
+ # make sure the timeout is detected, if any
+ time.sleep(4)
+
+ # this one should not
+ ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, customHeaders=['x-forwarded-for: 127.0.0.42:8080'], response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
+ self.assertTrue(receivedProxyPayload)
+ self.assertTrue(receivedDNSData)
+ receivedQuery = dns.message.from_wire(receivedDNSData)
+ self.assertTrue(receivedQuery)
+ receivedQuery.id = expectedQuery.id
+ self.assertEqual(expectedQuery, receivedQuery)
+ self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+ self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.42', '127.0.0.1', True, destinationPort=port)
+ # check the response
+ self.assertTrue(receivedResponse)
+ receivedResponse.id = response.id
+ self.assertEqual(response, receivedResponse)