From: Remi Gacogne Date: Thu, 1 Dec 2016 16:15:27 +0000 (+0100) Subject: dnsdist: Allow TTL alteration via Lua X-Git-Tag: rec-4.1.0-alpha1~316^2 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=153d506594fa942007914d06934da59aea10ccdf;p=thirdparty%2Fpdns.git dnsdist: Allow TTL alteration via Lua --- diff --git a/pdns/README-dnsdist.md b/pdns/README-dnsdist.md index 1f8455e9fb..c03d9bb308 100644 --- a/pdns/README-dnsdist.md +++ b/pdns/README-dnsdist.md @@ -400,6 +400,7 @@ Rules can be added via: Response rules can be added via: * addResponseAction(DNS rule, DNS Response Action) + * AddLuaResponseAction(DNS rule, Lua function) A DNS rule can be: @@ -572,6 +573,8 @@ Valid return values for `LuaAction` functions are: * DNSAction.Pool: use the specified pool to forward this query * DNSAction.Spoof: spoof the response using the supplied IPv4 (A), IPv6 (AAAA) or string (CNAME) value +The same feature exists to hand off some responses for Lua inspection, using `addLuaResponseAction(x, func)`. + DNSSEC ------ To provide DNSSEC service from a separate pool, try: @@ -1379,6 +1382,8 @@ instantiate a server with additional parameters * Lua Action related: * `addLuaAction(x, func)`: where 'x' is all the combinations from `addPoolRule`, and func is a function with the parameter `dq`, which returns an action to be taken on this packet. + * `addLuaResponseAction(x, func)`: where 'x' is all the combinations from `addPoolRule`, and func is a + function with the parameter `dr`, which returns an action to be taken on this response packet. Good for rare packets but where you want to do a lot of processing. * Server selection policy related: * `setServerPolicy(policy)`: set server selection policy to that policy @@ -1464,6 +1469,10 @@ instantiate a server with additional parameters * member `skipCache`: whether to skip cache lookup / storing the answer for this question (settable) * member `tcp`: whether this question was received over a TCP socket * member `useECS`: whether to send ECS to the backend (settable) + * DNSResponse gets the same member than DNSQuestion, plus some: + * member `editTTLs(func)`: the function `func` is invoked for every entries in the answer, authority + and additional section taking the section number (1 for answer, 2 for authority, 3 for additional), + the qclass and qtype values and the current TTL, and returning the new TTL or 0 to leave it unchanged * DNSHeader related * member `getRD()`: get recursion desired flag * member `setRD(bool)`: set recursion desired flag diff --git a/pdns/dnsdist-console.cc b/pdns/dnsdist-console.cc index 437c4b5525..61f3b7acba 100644 --- a/pdns/dnsdist-console.cc +++ b/pdns/dnsdist-console.cc @@ -257,6 +257,7 @@ const std::vector g_consoleKeywords{ { "addDynBlocks", true, "addresses, message[, seconds]", "block the set of addresses with message `msg`, for `seconds` seconds (10 by default)" }, { "addLocal", true, "netmask, [true], [false], [TCP Fast Open queue size]", "add to addresses we listen on. Second optional parameter sets TCP or not. Third optional parameter sets SO_REUSEPORT when available. Last parameter sets the TCP Fast Open queue size, enabling TCP Fast Open when available and the value is larger than 0" }, { "addLuaAction", true, "x, func", "where 'x' is all the combinations from `addPoolRule`, and func is a function with the parameter `dq`, which returns an action to be taken on this packet. Good for rare packets but where you want to do a lot of processing" }, + { "addLuaResponseAction", true, "x, func", "where 'x' is all the combinations from `addPoolRule`, and func is a function with the parameter `dr`, which returns an action to be taken on this response packet. Good for rare packets but where you want to do a lot of processing" }, { "addNoRecurseRule", true, "domain", "clear the RD flag for all queries matching the specified domain" }, { "addPoolRule", true, "domain, pool", "send queries to this domain to that pool" }, { "addQPSLimit", true, "domain, n", "limit queries within that domain to n per second" }, diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index ecfc53e800..250440eeca 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -65,6 +65,31 @@ private: func_t d_func; }; +class LuaResponseAction : public DNSResponseAction +{ +public: + typedef std::function(DNSResponse* dr)> func_t; + LuaResponseAction(LuaResponseAction::func_t func) : d_func(func) + {} + + Action operator()(DNSResponse* dr, string* ruleresult) const override + { + std::lock_guard lock(g_luamutex); + auto ret = d_func(dr); + if(ruleresult) + *ruleresult=std::get<1>(ret); + return (Action)std::get<0>(ret); + } + + string toString() const override + { + return "Lua response script"; + } + +private: + func_t d_func; +}; + typedef boost::variant>, std::shared_ptr > luadnsrule_t; std::shared_ptr makeRule(const luadnsrule_t& var) { @@ -676,6 +701,14 @@ vector> setupLua(bool client, const std::string& confi }); }); + g_lua.writeFunction("addLuaResponseAction", [](luadnsrule_t var, LuaResponseAction::func_t func) { + setLuaSideEffect(); + auto rule=makeRule(var); + g_resprulactions.modify([rule,func](decltype(g_resprulactions)::value_type& rulactions){ + rulactions.push_back({rule, + std::make_shared(func)}); + }); + }); g_lua.writeFunction("NoRecurseAction", []() { return std::shared_ptr(new NoRecurseAction); @@ -1498,6 +1531,9 @@ vector> setupLua(bool client, const std::string& confi g_lua.registerMember("size", [](const DNSResponse& dq) -> size_t { return dq.size; }, [](DNSResponse& dq, size_t newSize) { (void) newSize; }); g_lua.registerMember("tcp", [](const DNSResponse& dq) -> bool { return dq.tcp; }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; }); g_lua.registerMember("skipCache", [](const DNSResponse& dq) -> bool { return dq.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.skipCache = newSkipCache; }); + g_lua.registerFunction editFunc)>("editTTLs", [](const DNSResponse& dr, std::function editFunc) { + editDNSPacketTTL((char*) dr.dh, dr.len, editFunc); + }); g_lua.writeFunction("setMaxTCPClientThreads", [](uint64_t max) { if (!g_configurationDone) { diff --git a/pdns/dnsparser.cc b/pdns/dnsparser.cc index ebdcf0a067..6f2eab7959 100644 --- a/pdns/dnsparser.cc +++ b/pdns/dnsparser.cc @@ -573,7 +573,11 @@ public: } void skipBytes(uint16_t bytes) { - moveOffset(bytes); + moveOffset(bytes); + } + void rewindBytes(uint16_t by) + { + rewindOffset(by); } uint32_t get32BitInt() { @@ -604,11 +608,12 @@ public: int toskip = get16BitInt(); moveOffset(toskip); } + void decreaseAndSkip32BitInt(uint32_t decrease) { const char *p = d_packet + d_offset; moveOffset(4); - + uint32_t tmp; memcpy(&tmp, (void*) p, sizeof(tmp)); tmp = ntohl(tmp); @@ -616,6 +621,13 @@ public: tmp = htonl(tmp); memcpy(d_packet + d_offset-4, (const char*)&tmp, sizeof(tmp)); } + void setAndSkip32BitInt(uint32_t value) + { + moveOffset(4); + + value = htonl(value); + memcpy(d_packet + d_offset-4, (const char*)&value, sizeof(value)); + } uint32_t getOffset() const { return d_offset; @@ -628,6 +640,16 @@ private: throw std::out_of_range("dns packet out of range: "+std::to_string(d_notyouroffset) +" > " + std::to_string(d_length) ); } + void rewindOffset(uint16_t by) + { + if(d_notyouroffset < by) + throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < " + + std::to_string(by)); + d_notyouroffset -= by; + if(d_notyouroffset < 12) + throw std::out_of_range("Rewinding dns packet out of range: "+std::to_string(d_notyouroffset) +" < " + + std::to_string(12)); + } char* d_packet; size_t d_length; @@ -636,6 +658,50 @@ private: }; +// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it +void editDNSPacketTTL(char* packet, size_t length, std::function visitor) +{ + if(length < sizeof(dnsheader)) + return; + try + { + dnsheader dh; + memcpy((void*)&dh, (const dnsheader*)packet, sizeof(dh)); + uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount); + DNSPacketMangler dpm(packet, length); + + uint64_t n; + for(n=0; n < ntohs(dh.qdcount) ; ++n) { + dpm.skipLabel(); + /* type and class */ + dpm.skipBytes(4); + } + + for(n=0; n < numrecords; ++n) { + dpm.skipLabel(); + + uint8_t section = n < dh.ancount ? 1 : (n < (dh.ancount + dh.nscount) ? 2 : 3); + uint16_t dnstype = dpm.get16BitInt(); + uint16_t dnsclass = dpm.get16BitInt(); + + if(dnstype == QType::OPT) // not getting near that one with a stick + break; + + uint32_t dnsttl = dpm.get32BitInt(); + uint32_t newttl = visitor(section, dnsclass, dnstype, dnsttl); + if (newttl) { + dpm.rewindBytes(sizeof(newttl)); + dpm.setAndSkip32BitInt(newttl); + } + dpm.skipRData(); + } + } + catch(...) + { + return; + } +} + // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it void ageDNSPacket(char* packet, size_t length, uint32_t seconds) { diff --git a/pdns/dnsparser.hh b/pdns/dnsparser.hh index e64985f62b..7c89655d15 100644 --- a/pdns/dnsparser.hh +++ b/pdns/dnsparser.hh @@ -385,6 +385,7 @@ private: string simpleCompress(const string& label, const string& root=""); void ageDNSPacket(char* packet, size_t length, uint32_t seconds); void ageDNSPacket(std::string& packet, uint32_t seconds); +void editDNSPacketTTL(char* packet, size_t length, std::function visitor); uint32_t getDNSPacketMinTTL(const char* packet, size_t length); uint32_t getDNSPacketLength(const char* packet, size_t length); uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type); diff --git a/regression-tests.dnsdist/test_Responses.py b/regression-tests.dnsdist/test_Responses.py index 096b09d7ed..1982b60d5b 100644 --- a/regression-tests.dnsdist/test_Responses.py +++ b/regression-tests.dnsdist/test_Responses.py @@ -151,3 +151,50 @@ class TestResponseRuleQNameAllowed(DNSDistTest): receivedQuery.id = query.id self.assertEquals(query, receivedQuery) self.assertEquals(receivedResponse, None) + +class TestResponseRuleEditTTL(DNSDistTest): + + _ttl = 5 + _config_params = ['_testServerPort', '_ttl'] + _config_template = """ + newServer{address="127.0.0.1:%s"} + + function editTTLCallback(section, class, type, ttl) + return %d + end + + function editTTLFunc(dr) + dr:editTTLs(editTTLCallback) + return DNSAction.None, "" + end + + addLuaResponseAction(AllRule(), editTTLFunc) + """ + + def testTTLEdited(self): + """ + Responses: Alter the TTLs + """ + name = 'editttl.responses.tests.powerdns.com.' + query = dns.message.make_query(name, 'A', 'IN') + response = dns.message.make_response(query) + rrset = dns.rrset.from_text(name, + 3600, + dns.rdataclass.IN, + dns.rdatatype.A, + '192.0.2.1') + response.answer.append(rrset) + + (receivedQuery, receivedResponse) = self.sendUDPQuery(query, response) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl) + self.assertEquals(receivedResponse.answer[0].ttl, self._ttl) + + (receivedQuery, receivedResponse) = self.sendTCPQuery(query, response) + receivedQuery.id = query.id + self.assertEquals(query, receivedQuery) + self.assertEquals(response, receivedResponse) + self.assertNotEquals(response.answer[0].ttl, receivedResponse.answer[0].ttl) + self.assertEquals(receivedResponse.answer[0].ttl, self._ttl)