]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Allow TTL alteration via Lua 4787/head
authorRemi Gacogne <remi.gacogne@powerdns.com>
Thu, 1 Dec 2016 16:15:27 +0000 (17:15 +0100)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 6 Dec 2016 09:37:52 +0000 (10:37 +0100)
pdns/README-dnsdist.md
pdns/dnsdist-console.cc
pdns/dnsdist-lua.cc
pdns/dnsparser.cc
pdns/dnsparser.hh
regression-tests.dnsdist/test_Responses.py

index 1f8455e9fbf45acbf7aa379982e271e0a8adf130..c03d9bb30843f756c0eadb83913562d20dcd4aa6 100644 (file)
@@ -400,6 +400,7 @@ Rules can be added via:
 Response rules can be added via:
 
  * addResponseAction(DNS rule, DNS Response Action)
+ * AddLuaResponseAction(DNS rule, Lua function)
 
 A DNS rule can be:
 
@@ -572,6 +573,8 @@ Valid return values for `LuaAction` functions are:
  * DNSAction.Pool: use the specified pool to forward this query
  * DNSAction.Spoof: spoof the response using the supplied IPv4 (A), IPv6 (AAAA) or string (CNAME) value
 
+The same feature exists to hand off some responses for Lua inspection, using `addLuaResponseAction(x, func)`.
+
 DNSSEC
 ------
 To provide DNSSEC service from a separate pool, try:
@@ -1379,6 +1382,8 @@ instantiate a server with additional parameters
  * Lua Action related:
     * `addLuaAction(x, func)`: where 'x' is all the combinations from `addPoolRule`, and func is a 
       function with the parameter `dq`, which returns an action to be taken on this packet.
+    * `addLuaResponseAction(x, func)`: where 'x' is all the combinations from `addPoolRule`, and func is a
+      function with the parameter `dr`, which returns an action to be taken on this response packet.
       Good for rare packets but where you want to do a lot of processing.
  * Server selection policy related:
     * `setServerPolicy(policy)`: set server selection policy to that policy
@@ -1464,6 +1469,10 @@ instantiate a server with additional parameters
         * member `skipCache`: whether to skip cache lookup / storing the answer for this question (settable)
         * member `tcp`: whether this question was received over a TCP socket
         * member `useECS`: whether to send ECS to the backend (settable)
+    * DNSResponse gets the same member than DNSQuestion, plus some:
+        * member `editTTLs(func)`: the function `func` is invoked for every entries in the answer, authority
+        and additional section taking the section number (1 for answer, 2 for authority, 3 for additional),
+        the qclass and qtype values and the current TTL, and returning the new TTL or 0 to leave it unchanged
     * DNSHeader related
         * member `getRD()`: get recursion desired flag
         * member `setRD(bool)`: set recursion desired flag
index 437c4b552587ddf35c0c360866431854c4ef7e6f..61f3b7acba265de417223abacf5a61202d70c52d 100644 (file)
@@ -257,6 +257,7 @@ const std::vector<ConsoleKeyword> g_consoleKeywords{
   { "addDynBlocks", true, "addresses, message[, seconds]", "block the set of addresses with message `msg`, for `seconds` seconds (10 by default)" },
   { "addLocal", true, "netmask, [true], [false], [TCP Fast Open queue size]", "add to addresses we listen on. Second optional parameter sets TCP or not. Third optional parameter sets SO_REUSEPORT when available. Last parameter sets the TCP Fast Open queue size, enabling TCP Fast Open when available and the value is larger than 0" },
   { "addLuaAction", true, "x, func", "where 'x' is all the combinations from `addPoolRule`, and func is a function with the parameter `dq`, which returns an action to be taken on this packet. Good for rare packets but where you want to do a lot of processing" },
+  { "addLuaResponseAction", true, "x, func", "where 'x' is all the combinations from `addPoolRule`, and func is a function with the parameter `dr`, which returns an action to be taken on this response packet. Good for rare packets but where you want to do a lot of processing" },
   { "addNoRecurseRule", true, "domain", "clear the RD flag for all queries matching the specified domain" },
   { "addPoolRule", true, "domain, pool", "send queries to this domain to that pool" },
   { "addQPSLimit", true, "domain, n", "limit queries within that domain to n per second" },
index ecfc53e8003bfcd083e5d5c51b97c401f5f8f22a..250440eeca4d8b85b9404ed5cdcf7f7d669775c8 100644 (file)
@@ -65,6 +65,31 @@ private:
   func_t d_func;
 };
 
+class LuaResponseAction : public DNSResponseAction
+{
+public:
+  typedef std::function<std::tuple<int, string>(DNSResponse* dr)> func_t;
+  LuaResponseAction(LuaResponseAction::func_t func) : d_func(func)
+  {}
+
+  Action operator()(DNSResponse* dr, string* ruleresult) const override
+  {
+    std::lock_guard<std::mutex> lock(g_luamutex);
+    auto ret = d_func(dr);
+    if(ruleresult)
+      *ruleresult=std::get<1>(ret);
+    return (Action)std::get<0>(ret);
+  }
+
+  string toString() const override
+  {
+    return "Lua response script";
+  }
+
+private:
+  func_t d_func;
+};
+
 typedef boost::variant<string,vector<pair<int, string>>, std::shared_ptr<DNSRule> > luadnsrule_t;
 std::shared_ptr<DNSRule> makeRule(const luadnsrule_t& var)
 {
@@ -676,6 +701,14 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
                          });
                      });
 
+  g_lua.writeFunction("addLuaResponseAction", [](luadnsrule_t var, LuaResponseAction::func_t func) {
+      setLuaSideEffect();
+      auto rule=makeRule(var);
+      g_resprulactions.modify([rule,func](decltype(g_resprulactions)::value_type& rulactions){
+          rulactions.push_back({rule,
+                std::make_shared<LuaResponseAction>(func)});
+        });
+    });
 
   g_lua.writeFunction("NoRecurseAction", []() {
       return std::shared_ptr<DNSAction>(new NoRecurseAction);
@@ -1498,6 +1531,9 @@ vector<std::function<void(void)>> setupLua(bool client, const std::string& confi
   g_lua.registerMember<size_t (DNSResponse::*)>("size", [](const DNSResponse& dq) -> size_t { return dq.size; }, [](DNSResponse& dq, size_t newSize) { (void) newSize; });
   g_lua.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.tcp; }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
   g_lua.registerMember<bool (DNSResponse::*)>("skipCache", [](const DNSResponse& dq) -> bool { return dq.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.skipCache = newSkipCache; });
+  g_lua.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](const DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
+        editDNSPacketTTL((char*) dr.dh, dr.len, editFunc);
+      });
 
   g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) {
       if (!g_configurationDone) {
index ebdcf0a067821a7c5cbe24cd9672c38283ce8e61..6f2eab7959bf69218fc87790c15f390e3f9fac1a 100644 (file)
@@ -573,7 +573,11 @@ public:
   }
   void skipBytes(uint16_t bytes)
   {
-      moveOffset(bytes);
+    moveOffset(bytes);
+  }
+  void rewindBytes(uint16_t by)
+  {
+    rewindOffset(by);
   }
   uint32_t get32BitInt()
   {
@@ -604,11 +608,12 @@ public:
     int toskip = get16BitInt();
     moveOffset(toskip);
   }
+
   void decreaseAndSkip32BitInt(uint32_t decrease)
   {
     const char *p = d_packet + d_offset;
     moveOffset(4);
-    
+
     uint32_t tmp;
     memcpy(&tmp, (void*) p, sizeof(tmp));
     tmp = ntohl(tmp);
@@ -616,6 +621,13 @@ public:
     tmp = htonl(tmp);
     memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp));
   }
+  void setAndSkip32BitInt(uint32_t value)
+  {
+    moveOffset(4);
+
+    value = htonl(value);
+    memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value));
+  }
   uint32_t getOffset() const
   {
     return d_offset;
@@ -628,6 +640,16 @@ private:
       throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > " 
       + std::to_string(d_length) );
   }
+  void rewindOffset(uint16_t by)
+  {
+    if(d_notyouroffset < by)
+      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
+                              + std::to_string(by));
+    d_notyouroffset -= by;
+    if(d_notyouroffset < 12)
+      throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < "
+                              + std::to_string(12));
+  }
   char* d_packet;
   size_t d_length;
   
@@ -636,6 +658,50 @@ private:
   
 };
 
+// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
+void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor)
+{
+  if(length < sizeof(dnsheader))
+    return;
+  try
+  {
+    dnsheader dh;
+    memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh));
+    uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
+    DNSPacketMangler dpm(packet, length);
+
+    uint64_t n;
+    for(n=0; n < ntohs(dh.qdcount) ; ++n) {
+      dpm.skipLabel();
+      /* type and class */
+      dpm.skipBytes(4);
+    }
+
+    for(n=0; n < numrecords; ++n) {
+      dpm.skipLabel();
+
+      uint8_t section = n < dh.ancount ? 1 : (n < (dh.ancount + dh.nscount) ? 2 : 3);
+      uint16_t dnstype = dpm.get16BitInt();
+      uint16_t dnsclass = dpm.get16BitInt();
+
+      if(dnstype == QType::OPT) // not getting near that one with a stick
+        break;
+
+      uint32_t dnsttl = dpm.get32BitInt();
+      uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl);
+      if (newttl) {
+        dpm.rewindBytes(sizeof(newttl));
+        dpm.setAndSkip32BitInt(newttl);
+      }
+      dpm.skipRData();
+    }
+  }
+  catch(...)
+  {
+    return;
+  }
+}
+
 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
 void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
 {
index e64985f62bafb12bf9eece8eb33b6507e9df9aba..7c89655d15f74851ea65177471678ea6f7aea533 100644 (file)
@@ -385,6 +385,7 @@ private:
 string simpleCompress(const string& label, const string& root="");
 void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
 void ageDNSPacket(std::string& packet, uint32_t seconds);
+void editDNSPacketTTL(char* packet, size_t length, std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)> visitor);
 uint32_t getDNSPacketMinTTL(const char* packet, size_t length);
 uint32_t getDNSPacketLength(const char* packet, size_t length);
 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
index 096b09d7ed28b8b7555c75b3dc196d04c4a61668..1982b60d5bcc390697ddcbd807e7ce47af20d05a 100644 (file)
@@ -151,3 +151,50 @@ class TestResponseRuleQNameAllowed(DNSDistTest):
         receivedQuery.id = query.id
         self.assertEquals(query, receivedQuery)
         self.assertEquals(receivedResponse, None)
+
+class TestResponseRuleEditTTL(DNSDistTest):
+
+    _ttl = 5
+    _config_params = ['_testServerPort', '_ttl']
+    _config_template = """
+    newServer{address="127.0.0.1:%s"}
+
+    function editTTLCallback(section, class, type, ttl)
+      return %d
+    end
+
+    function editTTLFunc(dr)
+      dr:editTTLs(editTTLCallback)
+      return DNSAction.None, ""
+    end
+
+    addLuaResponseAction(AllRule(), editTTLFunc)
+    """
+
+    def testTTLEdited(self):
+        """
+        Responses: Alter the TTLs
+        """
+        name = 'editttl.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+
+        (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)
+
+        (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response)
+        receivedQuery.id = query.id
+        self.assertEquals(query, receivedQuery)
+        self.assertEquals(response, receivedResponse)
+        self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl)
+        self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)