]> git.ipfire.org Git - thirdparty/pdns.git/blobdiff - regression-tests.dnsdist/test_ProxyProtocol.py
Merge pull request #8945 from rgacogne/ddist-x-forwarded-for
[thirdparty/pdns.git] / regression-tests.dnsdist / test_ProxyProtocol.py
index 2d4e0069ec58e35e01e99ae27774334489de0c81..ab6c1a208c8b9632430ba527d0e1ace9c38d3445 100644 (file)
@@ -91,26 +91,34 @@ def ProxyProtocolTCPResponder(port, fromQueue, toQueue):
             continue
 
         payload = header + proxyContent
+        while True:
+          try:
+            data = conn.recv(2)
+          except socket.timeout:
+            data = None
 
-        data = conn.recv(2)
-        (datalen,) = struct.unpack("!H", data)
+          if not data:
+            conn.close()
+            break
 
-        data = conn.recv(datalen)
+          (datalen,) = struct.unpack("!H", data)
+          data = conn.recv(datalen)
 
-        toQueue.put([payload, data], True, 2.0)
+          toQueue.put([payload, data], True, 2.0)
 
-        response = fromQueue.get(True, 2.0)
-        if not response:
+          response = fromQueue.get(True, 2.0)
+          if not response:
             conn.close()
-            continue
+            break
 
-        # computing the correct ID for the response
-        request = dns.message.from_wire(data)
-        response.id = request.id
+          # computing the correct ID for the response
+          request = dns.message.from_wire(data)
+          response.id = request.id
+
+          wire = response.to_wire()
+          conn.send(struct.pack("!H", len(wire)))
+          conn.send(wire)
 
-        wire = response.to_wire()
-        conn.send(struct.pack("!H", len(wire)))
-        conn.send(wire)
         conn.close()
 
     sock.close()
@@ -162,7 +170,7 @@ class TestProxyProtocol(ProxyProtocolTest):
     newServer{address="127.0.0.1:%d", useProxyProtocol=true}
 
     function addValues(dq)
-      local values = { ["0"]="foo", ["42"]="bar" }
+      local values = { [0]="foo", [42]="bar" }
       dq:setProxyProtocolValues(values)
       return DNSAction.None
     end
@@ -366,3 +374,35 @@ class TestProxyProtocol(ProxyProtocolTest):
       self.assertEquals(receivedQuery, query)
       self.assertEquals(receivedResponse, response)
       self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [ [1, b'dnsdist'] , [ 255, b'proxy-protocol'] ])
+
+    def testProxyTCPSeveralQueriesOnSameConnection(self):
+      """
+        Proxy Protocol: Several queries on the same TCP connection
+      """
+      name = 'several-queries-same-conn.proxy.tests.powerdns.com.'
+      query = dns.message.make_query(name, 'A', 'IN')
+      response = dns.message.make_response(query)
+
+      conn = self.openTCPConnection(2.0)
+      data = query.to_wire()
+
+      for idx in range(10):
+        toProxyQueue.put(response, True, 2.0)
+        self.sendTCPQueryOverConnection(conn, data, rawQuery=True)
+        receivedResponse = None
+        try:
+          receivedResponse = self.recvTCPResponseOverConnection(conn)
+        except socket.timeout:
+          print('timeout')
+
+        (receivedProxyPayload, receivedDNSData) = fromProxyQueue.get(True, 2.0)
+        self.assertTrue(receivedProxyPayload)
+        self.assertTrue(receivedDNSData)
+        self.assertTrue(receivedResponse)
+
+        receivedQuery = dns.message.from_wire(receivedDNSData)
+        receivedQuery.id = query.id
+        receivedResponse.id = response.id
+        self.assertEquals(receivedQuery, query)
+        self.assertEquals(receivedResponse, response)
+        self.checkMessageProxyProtocol(receivedProxyPayload, '127.0.0.1', '127.0.0.1', True, [])