]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Fix the harvesting of destination addresses
authorRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 24 Feb 2023 10:30:44 +0000 (11:30 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Fri, 24 Feb 2023 10:30:44 +0000 (11:30 +0100)
The original destination was not properly updated: 'hopLocal' contains
the destination address of the packet we received, and matches 'origDest'
unless the proxy protocol is used, in which case 'origDest' will be
updated by the 'real' destination address as seen by the client and
the first hop.
Reported by phonedph1 (many thanks!).

pdns/dnsdist.cc
regression-tests.dnsdist/test_Advanced.py

index f2e5cd8c2ae691aa39d1f9ee8d51fd8eb9313624..1e8ff8b6779ad0db37842470bb6c01b381ba9be2 100644 (file)
@@ -1165,6 +1165,10 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s
     dest.sin4.sin_family = 0;
   }
 
+  if (dest.sin4.sin_family == 0) {
+    dest = cs.local;
+  }
+
   ++cs.queries;
   ++g_stats.queries;
 
@@ -1591,8 +1595,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
   ids.cs = &cs;
   ids.origRemote = remote;
   ids.hopRemote = remote;
-  ids.origDest = dest;
-  ids.hopLocal = dest;
   ids.protocol = dnsdist::Protocol::DoUDP;
 
   try {
@@ -1602,6 +1604,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
     /* dest might have been updated, if we managed to harvest the destination address */
     ids.origDest = dest;
+    ids.hopLocal = dest;
 
     std::vector<ProxyProtocolValue> proxyProtocolValues;
     if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, ids.origRemote, ids.origDest, proxyProtocolValues)) {
@@ -1634,9 +1637,6 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct
     }
 
     ids.qname = DNSName(reinterpret_cast<const char*>(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass);
-    if (ids.origDest.sin4.sin_family == 0) {
-      ids.origDest = cs.local;
-    }
     if (ids.dnsCryptQuery) {
       ids.protocol = dnsdist::Protocol::DNSCryptUDP;
     }
index 0c3947ac890801429bd87ac9615f668f6e29419a..fc1f6eb7d9a82319bee5aad6f84e48b1bde1b61c 100644 (file)
@@ -1,6 +1,7 @@
 #!/usr/bin/env python
 import base64
 import os
+import socket
 import time
 import unittest
 import dns
@@ -266,9 +267,9 @@ class TestAdvancedGetLocalPortOnAnyBind(DNSDistTest):
       return DNSAction.Spoof, "port-was-"..port..".local-port-any.advanced.tests.powerdns.com."
     end
     addAction("local-port-any.advanced.tests.powerdns.com.", LuaAction(answerBasedOnLocalPort))
-    newServer{address="127.0.0.1:%s"}
+    newServer{address="127.0.0.1:%d"}
     """
-    _dnsDistListeningAddr = "0.0.0.0"
+    _dnsDistListeningAddr = '0.0.0.0'
 
     def testAdvancedGetLocalPortOnAnyBind(self):
         """
@@ -304,8 +305,12 @@ class TestAdvancedGetLocalAddressOnAnyBind(DNSDistTest):
     end
     addAction("local-address-any.advanced.tests.powerdns.com.", LuaAction(answerBasedOnLocalAddress))
     newServer{address="127.0.0.1:%s"}
+    addLocal('0.0.0.0:%d')
+    addLocal('[::]:%d')
     """
-    _dnsDistListeningAddr = "0.0.0.0"
+    _config_params = ['_testServerPort', '_dnsDistPort', '_dnsDistPort']
+    _acl = ['127.0.0.1/32', '::1/128']
+    _skipListeningOnCL = True
 
     def testAdvancedGetLocalAddressOnAnyBind(self):
         """
@@ -329,6 +334,86 @@ class TestAdvancedGetLocalAddressOnAnyBind(DNSDistTest):
             (_, receivedResponse) = sender(query, response=None, useQueue=False)
             self.assertEqual(receivedResponse, response)
 
+        # now a bit more tricky, UDP-only IPv4
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.CNAME,
+                                    'address-was-127-0-0-2.local-address-any.advanced.tests.powerdns.com.')
+        response.answer.append(rrset)
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.settimeout(1.0)
+        sock.connect(('127.0.0.2', self._dnsDistPort))
+        try:
+            query = query.to_wire()
+            sock.send(query)
+            (data, remote) = sock.recvfrom(4096)
+            self.assertEquals(remote[0], '127.0.0.2')
+        except socket.timeout:
+            data = None
+
+        self.assertTrue(data)
+        receivedResponse = dns.message.from_wire(data)
+        self.assertEqual(receivedResponse, response)
+
+    def testAdvancedCheckSourceAddrOnAnyBind(self):
+        """
+        Advanced: Check the source address on responses for an ANY bind
+        """
+        name = 'source-addr-any.advanced.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        # dnsdist set RA = RD for spoofed responses
+        query.flags &= ~dns.flags.RD
+
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    60,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.42')
+        response.answer.append(rrset)
+
+        # a bit more tricky, UDP-only IPv4
+        sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+        sock.settimeout(1.0)
+        sock.connect(('127.0.0.2', self._dnsDistPort))
+        self._toResponderQueue.put(response, True, 1.0)
+        try:
+            data = query.to_wire()
+            sock.send(data)
+            (data, remote) = sock.recvfrom(4096)
+            self.assertEquals(remote[0], '127.0.0.2')
+        except socket.timeout:
+            data = None
+
+        self.assertTrue(data)
+        receivedResponse = dns.message.from_wire(data)
+        receivedQuery = self._fromResponderQueue.get(True, 1.0)
+        receivedQuery.id = query.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+
+        # a bit more tricky, UDP-only IPv6
+        sock = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
+        sock.settimeout(1.0)
+        sock.connect(('::1', self._dnsDistPort))
+        self._toResponderQueue.put(response, True, 1.0)
+        try:
+            data = query.to_wire()
+            sock.send(data)
+            (data, remote) = sock.recvfrom(4096)
+            self.assertEquals(remote[0], '::1')
+        except socket.timeout:
+            data = None
+
+        self.assertTrue(data)
+        receivedResponse = dns.message.from_wire(data)
+        receivedQuery = self._fromResponderQueue.get(True, 1.0)
+        receivedQuery.id = query.id
+        self.assertEqual(receivedQuery, query)
+        self.assertEqual(receivedResponse, response)
+
 class TestAdvancedAllowHeaderOnly(DNSDistTest):
 
     _config_template = """