]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_ProxyProtocol.py
Merge pull request #13702 from omoerbeek/rec-log-ref-wrapper
[thirdparty/pdns.git] / regression-tests.dnsdist / test_ProxyProtocol.py
index 7a507819493aca4216951548fac9c8abac316801..2ed60e08bc9f4351b7a59da384030f4b3f7691a6 100644 (file)
@@ -1,14 +1,17 @@
 #!/usr/bin/env python
 
-import copy
 import dns
+import selectors
 import socket
+import ssl
 import struct
 import sys
 import threading
+import time
 
-from dnsdisttests import DNSDistTest
+from dnsdisttests import DNSDistTest, pickAvailablePort
 from proxyprotocol import ProxyProtocol
+from proxyprotocolutils import ProxyProtocolUDPResponder, ProxyProtocolTCPResponder
 from dnsdistdohtests import DNSDistDOHTest
 
 # Python2/3 compatibility hacks
@@ -17,129 +20,106 @@ try:
 except ImportError:
   from Queue import Queue
 
-def ProxyProtocolUDPResponder(port, fromQueue, toQueue):
-    sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
-    try:
-        sock.bind(("127.0.0.1", port))
-    except socket.error as e:
-        print("Error binding in the Proxy Protocol UDP responder: %s" % str(e))
-        sys.exit(1)
+toProxyQueue = Queue()
+fromProxyQueue = Queue()
+proxyResponderPort = pickAvailablePort()
 
-    while True:
-        data, addr = sock.recvfrom(4096)
-
-        proxy = ProxyProtocol()
-        if len(data) < proxy.HEADER_SIZE:
-            continue
-
-        if not proxy.parseHeader(data):
-            continue
-
-        if proxy.local:
-            # likely a healthcheck
-            data = data[proxy.HEADER_SIZE:]
-            request = dns.message.from_wire(data)
-            response = dns.message.make_response(request)
-            wire = response.to_wire()
-            sock.settimeout(2.0)
-            sock.sendto(wire, addr)
-            sock.settimeout(None)
-
-            continue
-
-        payload = data[:(proxy.HEADER_SIZE + proxy.contentLen)]
-        dnsData = data[(proxy.HEADER_SIZE + proxy.contentLen):]
-        toQueue.put([payload, dnsData], True, 2.0)
-        # computing the correct ID for the response
-        request = dns.message.from_wire(dnsData)
-        response = fromQueue.get(True, 2.0)
-        response.id = request.id
-
-        sock.settimeout(2.0)
-        sock.sendto(response.to_wire(), addr)
-        sock.settimeout(None)
+udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+udpResponder.daemon = True
+udpResponder.start()
+tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
+tcpResponder.daemon = True
+tcpResponder.start()
 
-    sock.close()
+backgroundThreads = {}
+
+def MockTCPReverseProxyAddingProxyProtocol(listeningPort, forwardingPort, serverCtx=None, ca=None, sni=None):
+    # this responder accepts TCP connections on the listening port,
+    # and relay the raw content to a second TCP connection to the
+    # forwarding port, after adding a Proxy Protocol v2 payload
+    # containing the initial source IP and port, destination IP
+    # and port.
+    backgroundThreads[threading.get_native_id()] = True
 
-def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
-    # be aware that this responder will not accept a new connection
-    # until the last one has been closed. This is done on purpose to
-    # to check for connection reuse, making sure that a lot of connections
-    # are not opened in parallel.
     sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
     sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1)
     sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+
+    if serverCtx is not None:
+        sock = serverCtx.wrap_socket(sock, server_side=True)
+
     try:
-        sock.bind(("127.0.0.1", port))
+        sock.bind(("127.0.0.1", listeningPort))
     except socket.error as e:
-        print("Error binding in the TCP responder: %s" % str(e))
+        print("Error binding in the Mock TCP reverse proxy: %s" % str(e))
         sys.exit(1)
-
+    sock.settimeout(0.5)
     sock.listen(100)
+
     while True:
-        (conn, _) = sock.accept()
-        conn.settimeout(5.0)
-        # try to read the entire Proxy Protocol header
-        proxy = ProxyProtocol()
-        header = conn.recv(proxy.HEADER_SIZE)
-        if not header:
-            conn.close()
-            continue
-
-        if not proxy.parseHeader(header):
-            conn.close()
-            continue
-
-        proxyContent = conn.recv(proxy.contentLen)
-        if not proxyContent:
-            conn.close()
-            continue
-
-        payload = header + proxyContent
-        while True:
+        try:
+            (incoming, _) = sock.accept()
+        except socket.timeout:
+            if backgroundThreads.get(threading.get_native_id(), False) == False:
+                del backgroundThreads[threading.get_native_id()]
+                break
+            else:
+              continue
+
+        incoming.settimeout(5.0)
+        payload = ProxyProtocol.getPayload(False, True, False, '127.0.0.1', '127.0.0.1', incoming.getpeername()[1], listeningPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+
+        outgoing = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        outgoing.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
+        outgoing.settimeout(2.0)
+        if sni:
+            if hasattr(ssl, 'create_default_context'):
+                sslctx = ssl.create_default_context(cafile=ca)
+                if hasattr(sslctx, 'set_alpn_protocols'):
+                    sslctx.set_alpn_protocols(['h2'])
+                outgoing = sslctx.wrap_socket(outgoing, server_hostname=sni)
+            else:
+                outgoing = ssl.wrap_socket(outgoing, ca_certs=ca, cert_reqs=ssl.CERT_REQUIRED)
+
+        outgoing.connect(('127.0.0.1', forwardingPort))
+
+        outgoing.send(payload)
+
+        sel = selectors.DefaultSelector()
+        def readFromClient(conn):
+            data = conn.recv(512)
+            if not data or len(data) == 0:
+              return False
+            outgoing.send(data)
+            return True
+
+        def readFromBackend(conn):
+            data = conn.recv(512)
+            if not data or len(data) == 0:
+              return False
+            incoming.send(data)
+            return True
+
+        sel.register(incoming, selectors.EVENT_READ, readFromClient)
+        sel.register(outgoing, selectors.EVENT_READ, readFromBackend)
+        done = False
+        while not done:
           try:
-            data = conn.recv(2)
+            events = sel.select()
+            for key, mask in events:
+              if not (key.data)(key.fileobj):
+                done = True
+                break
           except socket.timeout:
-            data = None
-
-          if not data:
-            conn.close()
             break
-
-          (datalen,) = struct.unpack("!H", data)
-          data = conn.recv(datalen)
-
-          toQueue.put([payload, data], True, 2.0)
-
-          response = copy.deepcopy(fromQueue.get(True, 2.0))
-          if not response:
-            conn.close()
+          except:
             break
 
-          # computing the correct ID for the response
-          request = dns.message.from_wire(data)
-          response.id = request.id
-
-          wire = response.to_wire()
-          conn.send(struct.pack("!H", len(wire)))
-          conn.send(wire)
-
-        conn.close()
+        incoming.close()
+        outgoing.close()
 
     sock.close()
 
-toProxyQueue = Queue()
-fromProxyQueue = Queue()
-proxyResponderPort = 5470
-
-udpResponder = threading.Thread(name='UDP Proxy Protocol Responder', target=ProxyProtocolUDPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
-udpResponder.setDaemon(True)
-udpResponder.start()
-tcpResponder = threading.Thread(name='TCP Proxy Protocol Responder', target=ProxyProtocolTCPResponder, args=[proxyResponderPort, toProxyQueue, fromProxyQueue])
-tcpResponder.setDaemon(True)
-tcpResponder.start()
-
 class ProxyProtocolTest(DNSDistTest):
     _proxyResponderPort = proxyResponderPort
     _config_params = ['_proxyResponderPort']
@@ -397,8 +377,10 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     """
 
     _config_template = """
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=true})
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", {"/"}, {library='nghttp2', proxyProtocolOutsideTLS=false})
     setProxyProtocolACL( { "127.0.0.1/32" } )
-    newServer{address="127.0.0.1:%d", useProxyProtocol=true}
+    newServer{address="127.0.0.1:%d", useProxyProtocol=true, proxyProtocolAdvertiseTLS=true}
 
     function addValues(dq)
       dq:addProxyProtocolValue(0, 'foo')
@@ -433,8 +415,13 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
     -- override all existing values
     addAction("override.proxy-protocol-incoming.tests.powerdns.com.", SetProxyProtocolValuesAction({["50"]="overridden"}))
     """
-    _config_params = ['_proxyResponderPort']
-    _verboseMode = True
+    _serverKey = 'server.key'
+    _serverCert = 'server.chain'
+    _serverName = 'tls.tests.dnsdist.org'
+    _caCert = 'ca.pem'
+    _dohServerPPOutsidePort = pickAvailablePort()
+    _dohServerPPInsidePort = pickAvailablePort()
+    _config_params = ['_dohServerPPOutsidePort', '_serverCert', '_serverKey', '_dohServerPPInsidePort', '_serverCert', '_serverKey', '_proxyResponderPort']
 
     def testNoHeader(self):
         """
@@ -444,9 +431,12 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
         name = 'no-header.incoming-proxy-protocol.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
 
-        for method in ("sendUDPQuery", "sendTCPQuery"):
+        for method in ("sendUDPQuery", "sendTCPQuery", "sendDOHQueryWrapper"):
           sender = getattr(self, method)
-          (_, receivedResponse) = sender(query, response=None)
+          try:
+            (_, receivedResponse) = sender(query, response=None)
+          except Exception:
+            receivedResponse = None
           self.assertEqual(receivedResponse, None)
 
     def testIncomingProxyDest(self):
@@ -655,6 +645,118 @@ class TestProxyProtocolIncoming(ProxyProtocolTest):
           self.assertEqual(receivedResponse, response)
           self.checkMessageProxyProtocol(receivedProxyPayload, srcAddr, destAddr, True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [ 42, b'bar'], [255, b'proxy-protocol'] ], True, srcPort, destPort)
 
+    def testProxyDoHSeveralQueriesOverConnectionPPOutside(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (DoH, PP outside TLS)
+        """
+        name = 'several-queries.doh-outside.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        reverseProxyPort = pickAvailablePort()
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPOutsidePort])
+        reverseProxy.start()
+        time.sleep(1)
+
+        receivedResponse = None
+        conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+        reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEqual(receivedQuery, query)
+          self.assertEqual(receivedResponse, response)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+    def testProxyDoHSeveralQueriesOverConnectionPPInside(self):
+        """
+        Incoming Proxy Protocol: Several queries over the same connection (DoH, PP inside TLS)
+        """
+        name = 'several-queries.doh-inside.proxy-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+
+        toProxyQueue.put(response, True, 2.0)
+
+        wire = query.to_wire()
+
+        reverseProxyPort = pickAvailablePort()
+        tlsContext = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+        tlsContext.load_cert_chain(self._serverCert, self._serverKey)
+        tlsContext.set_alpn_protocols(['h2'])
+        reverseProxy = threading.Thread(name='Mock Proxy Protocol Reverse Proxy', target=MockTCPReverseProxyAddingProxyProtocol, args=[reverseProxyPort, self._dohServerPPInsidePort, tlsContext, self._caCert, self._serverName])
+        reverseProxy.start()
+
+        receivedResponse = None
+        time.sleep(1)
+        conn = self.openDOHConnection(reverseProxyPort, self._caCert, timeout=2.0)
+
+        reverseProxyBaseURL = ("https://%s:%d/" % (self._serverName, reverseProxyPort))
+        (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+        for idx in range(5):
+          receivedResponse = None
+          toProxyQueue.put(response, True, 2.0)
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(reverseProxyPort, self._serverName, reverseProxyBaseURL, query, response=response, caFile=self._caCert, useQueue=True, conn=conn)
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          self.assertTrue(receivedResponse)
+
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          receivedQuery.id = query.id
+          receivedResponse.id = response.id
+          self.assertEqual(receivedQuery, query)
+          self.assertEqual(receivedResponse, response)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [0, b'foo'], [1, b'dnsdist'], [ 2, b'foo'], [3, b'proxy'], [32, ''], [ 42, b'bar'], [255, b'proxy-protocol'] ], v6=False, sourcePort=None, destinationPort=reverseProxyPort)
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._sock.close()
+        for backgroundThread in cls._backgroundThreads:
+            cls._backgroundThreads[backgroundThread] = False
+        for backgroundThread in backgroundThreads:
+            backgroundThreads[backgroundThread] = False
+        cls.killProcess(cls._dnsdist)
+
 class TestProxyProtocolNotExpected(DNSDistTest):
     """
     dnsdist is configured to expect a Proxy Protocol header on incoming queries but not from 127.0.0.1
@@ -722,29 +824,101 @@ class TestProxyProtocolNotExpected(DNSDistTest):
           print('timeout')
         self.assertEqual(receivedResponse, None)
 
+class TestProxyProtocolNotAllowedOnBind(DNSDistTest):
+    """
+    dnsdist is configured to expect a Proxy Protocol header on incoming queries but not on the 127.0.0.1 bind
+    """
+    _skipListeningOnCL = True
+    _config_template = """
+    -- proxy protocol payloads are not allowed on this bind address!
+    addLocal('127.0.0.1:%d', {enableProxyProtocol=false})
+    setProxyProtocolACL( { "127.0.0.1/8" } )
+    newServer{address="127.0.0.1:%d"}
+    """
+    # NORMAL responder, does not expect a proxy protocol payload!
+    _config_params = ['_dnsDistPort', '_testServerPort']
+
+    def testNoHeader(self):
+        """
+        Unexpected Proxy Protocol: no header
+        """
+        # no proxy protocol header and none is expected from this source, should be passed on
+        name = 'no-header.unexpected-proxy-protocol.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+
+        response.answer.append(rrset)
+
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+          sender = getattr(self, method)
+          (receivedQuery, receivedResponse) = sender(query, response)
+          receivedQuery.id = query.id
+          self.assertEqual(query, receivedQuery)
+          self.assertEqual(response, receivedResponse)
+
+    def testIncomingProxyDest(self):
+        """
+        Unexpected Proxy Protocol: should be dropped
+        """
+        name = 'with-proxy-payload.unexpected-protocol-incoming.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+
+        # Make sure that the proxy payload does NOT turn into a legal qname
+        destAddr = "ff:db8::ffff"
+        destPort = 65535
+        srcAddr = "ff:db8::ffff"
+        srcPort = 65535
+
+        udpPayload = ProxyProtocol.getPayload(False, False, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        (_, receivedResponse) = self.sendUDPQuery(udpPayload + query.to_wire(), response=None, useQueue=False, rawQuery=True)
+        self.assertEqual(receivedResponse, None)
+
+        tcpPayload = ProxyProtocol.getPayload(False, True, True, srcAddr, destAddr, srcPort, destPort, [ [ 2, b'foo'], [ 3, b'proxy'] ])
+        wire = query.to_wire()
+
+        receivedResponse = None
+        try:
+          conn = self.openTCPConnection(2.0)
+          conn.send(tcpPayload)
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+        self.assertEqual(receivedResponse, None)
+
 class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
 
     _serverKey = 'server.key'
     _serverCert = 'server.chain'
     _serverName = 'tls.tests.dnsdist.org'
     _caCert = 'ca.pem'
-    _dohServerPort = 8443
-    _dohBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohServerPort))
+    _dohWithNGHTTP2ServerPort = pickAvailablePort()
+    _dohWithNGHTTP2BaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohWithNGHTTP2ServerPort))
+    _dohWithH2OServerPort = pickAvailablePort()
+    _dohWithH2OBaseURL = ("https://%s:%d/dns-query" % (_serverName, _dohWithH2OServerPort))
     _proxyResponderPort = proxyResponderPort
     _config_template = """
     newServer{address="127.0.0.1:%s", useProxyProtocol=true}
-
-    addDOHLocal("127.0.0.1:%s", "%s", "%s")
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", { '/dns-query' }, { trustForwardedForHeader=true, library='nghttp2' })
+    addDOHLocal("127.0.0.1:%d", "%s", "%s", { '/dns-query' }, { trustForwardedForHeader=true, library='h2o' })
+    setACL( { "::1/128", "127.0.0.0/8" } )
     """
-    _config_params = ['_proxyResponderPort', '_dohServerPort', '_serverCert', '_serverKey']
+    _config_params = ['_proxyResponderPort', '_dohWithNGHTTP2ServerPort', '_serverCert', '_serverKey', '_dohWithH2OServerPort', '_serverCert', '_serverKey']
+    _verboseMode = True
 
     def testTruncation(self):
         """
-        DOH: Truncation over UDP (with cache)
+        DOH: Truncation over UDP
         """
         # the query is first forwarded over UDP, leading to a TC=1 answer from the
         # backend, then over TCP
-        name = 'truncated-udp.doh-with-cache.tests.powerdns.com.'
+        name = 'truncated-udp.doh.proxy-protocol.tests.powerdns.com.'
         query = dns.message.make_query(name, 'A', 'IN')
         query.id = 42
         expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
@@ -757,36 +931,78 @@ class TestDOHWithOutgoingProxyProtocol(DNSDistDOHTest):
                                     '127.0.0.1')
         response.answer.append(rrset)
 
-        # first response is a TC=1
-        tcResponse = dns.message.make_response(query)
-        tcResponse.flags |= dns.flags.TC
-        toProxyQueue.put(tcResponse, True, 2.0)
+        for (port,url) in [(self._dohWithNGHTTP2ServerPort, self._dohWithNGHTTP2BaseURL), (self._dohWithH2OServerPort, self._dohWithH2OBaseURL)]:
+          # first response is a TC=1
+          tcResponse = dns.message.make_response(query)
+          tcResponse.flags |= dns.flags.TC
+          toProxyQueue.put(tcResponse, True, 2.0)
 
-        ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(self._dohServerPort, self._serverName, self._dohBaseURL, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
-        # first query, received by the responder over UDP
-        self.assertTrue(receivedProxyPayload)
-        self.assertTrue(receivedDNSData)
-        receivedQuery = dns.message.from_wire(receivedDNSData)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = expectedQuery.id
-        self.assertEqual(expectedQuery, receivedQuery)
-        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
-        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
+          ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
+          # first query, received by the responder over UDP
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          self.assertTrue(receivedQuery)
+          receivedQuery.id = expectedQuery.id
+          self.assertEqual(expectedQuery, receivedQuery)
+          self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=port)
 
-        # check the response
-        self.assertTrue(receivedResponse)
-        self.assertEqual(response, receivedResponse)
+          # check the response
+          self.assertTrue(receivedResponse)
+          self.assertEqual(response, receivedResponse)
 
-        # check the second query, received by the responder over TCP
-        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
-        self.assertTrue(receivedDNSData)
-        receivedQuery = dns.message.from_wire(receivedDNSData)
-        self.assertTrue(receivedQuery)
-        receivedQuery.id = expectedQuery.id
-        self.assertEqual(expectedQuery, receivedQuery)
-        self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
-        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=self._dohServerPort)
-
-        # make sure we consumed everything
-        self.assertTrue(toProxyQueue.empty())
-        self.assertTrue(fromProxyQueue.empty())
+          # check the second query, received by the responder over TCP
+          (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+          self.assertTrue(receivedDNSData)
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          self.assertTrue(receivedQuery)
+          receivedQuery.id = expectedQuery.id
+          self.assertEqual(expectedQuery, receivedQuery)
+          self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, destinationPort=port)
+
+          # make sure we consumed everything
+          self.assertTrue(toProxyQueue.empty())
+          self.assertTrue(fromProxyQueue.empty())
+
+    def testAddressFamilyMismatch(self):
+        """
+        DOH with IPv6 X-Forwarded-For to an IPv4 endpoint
+        """
+        name = 'x-forwarded-for-af-mismatch.doh.outgoing-proxy-protocol.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN', use_edns=False)
+        query.id = 0
+        expectedQuery = dns.message.make_query(name, 'A', 'IN', use_edns=True, payload=4096)
+        expectedQuery.id = 0
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '127.0.0.1')
+        response.answer.append(rrset)
+
+        for (port,url) in [(self._dohWithNGHTTP2ServerPort, self._dohWithNGHTTP2BaseURL), (self._dohWithH2OServerPort, self._dohWithH2OBaseURL)]:
+          # the query should be dropped
+          (receivedQuery, receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, customHeaders=['x-forwarded-for: [::1]:8080'], useQueue=False)
+          self.assertFalse(receivedQuery)
+          self.assertFalse(receivedResponse)
+
+          # make sure the timeout is detected, if any
+          time.sleep(4)
+
+          # this one should not
+          ((receivedProxyPayload, receivedDNSData), receivedResponse) = self.sendDOHQuery(port, self._serverName, url, query, caFile=self._caCert, customHeaders=['x-forwarded-for: 127.0.0.42:8080'], response=response, fromQueue=fromProxyQueue, toQueue=toProxyQueue)
+          self.assertTrue(receivedProxyPayload)
+          self.assertTrue(receivedDNSData)
+          receivedQuery = dns.message.from_wire(receivedDNSData)
+          self.assertTrue(receivedQuery)
+          receivedQuery.id = expectedQuery.id
+          self.assertEqual(expectedQuery, receivedQuery)
+          self.checkQueryEDNSWithoutECS(expectedQuery, receivedQuery)
+          self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.42', '127.0.0.1', True, destinationPort=port)
+          # check the response
+          self.assertTrue(receivedResponse)
+          receivedResponse.id = response.id
+          self.assertEqual(response, receivedResponse)