]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: add non ffi interface to clear given record types in a response
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 10 Dec 2021 15:59:27 +0000 (16:59 +0100)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Thu, 16 Dec 2021 13:27:06 +0000 (14:27 +0100)
pdns/dnsdist-lua-actions.cc
pdns/dnsdistdist/docs/rules-actions.rst
pdns/dnsparser.cc
pdns/test-dnsparser_cc.cc
regression-tests.dnsdist/test_Responses.py

index a75ba3930047108b4c509da749a139ba347bc1b6..0598d222a42f660a9ca6a8eae34405a1e40e8990 100644 (file)
@@ -1720,6 +1720,32 @@ private:
   std::string d_value;
 };
 
+class ClearRecordTypesResponseAction : public DNSResponseAction, public boost::noncopyable
+{
+public:
+  ClearRecordTypesResponseAction() {}
+
+  ClearRecordTypesResponseAction(std::set<QType> qtypes) : d_qtypes(qtypes)
+  {
+  }
+
+  DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override
+  {
+    if (d_qtypes.size() > 0) {
+      clearDNSPacketRecordTypes(dr->getMutableData(), d_qtypes);
+    }
+    return DNSResponseAction::Action::HeaderModify;
+  }
+
+  std::string toString() const override
+  {
+    return "clear record types";
+  }
+
+private:
+  std::set<QType> d_qtypes{};
+};
+
 class ContinueAction : public DNSAction
 {
 public:
@@ -2234,6 +2260,19 @@ void setupLuaActions(LuaContext& luaCtx)
       return std::shared_ptr<DNSResponseAction>(new LimitTTLResponseAction(0, max));
     });
 
+  luaCtx.writeFunction("ClearRecordTypesResponseAction", [](boost::variant<int,vector<pair<int, int>>> types) {
+      std::set<QType> qtypes{};
+      if(auto t = boost::get<int>(types)) {
+        qtypes.insert(t);
+      } else {
+        const auto& v = boost::get<vector<pair<int,int>>>(types);
+        for(const auto& tpair: v) {
+          qtypes.insert(tpair.second);
+        }
+      }
+      return std::shared_ptr<DNSResponseAction>(new ClearRecordTypesResponseAction(qtypes));
+    });
+
   luaCtx.writeFunction("RCodeAction", [](uint8_t rcode, boost::optional<responseParams_t> vars) {
       auto ret = std::shared_ptr<DNSAction>(new RCodeAction(rcode));
       auto rca = std::dynamic_pointer_cast<RCodeAction>(ret);
index 997d785a6b1f147df91a58a7304e4ccaa74cfc07..582a2505cad66bb2f7595fb2f3134746fe95fc00 100644 (file)
@@ -838,6 +838,14 @@ The following actions exist.
 
   Let these packets go through.
 
+.. function::ClearRecordTypesResponseAction(types)
+
+  .. versionadded:: 1.8.0
+
+  Removes given type(s) records from the response.
+
+  :param int types: a single type or a list of types to remove
+
 .. function:: ContinueAction(action)
 
   .. versionadded:: 1.4.0
index a2a6634dbfd313408dd7ada81bdfc760b5fc3942..b5922059960a8676a1f5db2857be30f4f6222527 100644 (file)
@@ -745,69 +745,132 @@ void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(
   }
 }
 
-void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes)
-{
-  size_t finalsize = packet.size();
-  clearDNSPacketRecordTypes(reinterpret_cast<char*>(packet.data()), finalsize, qtypes);
-  packet.resize(finalsize);
-}
-
-void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::set<QType>& qtypes)
+int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, PacketBuffer& newContent, const std::set<QType>& qtypes)
 {
-  size_t finalsize = packet.size();
-  clearDNSPacketRecordTypes(reinterpret_cast<char*>(packet.data()), finalsize, qtypes);
-  packet.resize(finalsize);
-}
+  static const std::set<QType>& safeTypes{QType::A, QType::AAAA, QType::DHCID, QType::TXT, QType::LUA, QType::ENT, QType::OPT, QType::HINFO, QType::DNSKEY, QType::CDNSKEY, QType::DS, QType::CDS, QType::DLV, QType::SSHFP, QType::KEY, QType::CERT, QType::TLSA, QType::SMIMEA, QType::OPENPGPKEY, QType::SVCB, QType::HTTPS, QType::NSEC3, QType::CSYNC, QType::NSEC3PARAM, QType::LOC, QType::NID, QType::L32, QType::L64, QType::EUI48, QType::EUI64, QType::URI, QType::CAA};
 
-// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
-void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set<QType>& qtypes)
-{
-  if (length < sizeof(dnsheader)) {
-    return;
+  if (initialPacket.size() < sizeof(dnsheader)) {
+    return EINVAL;
   }
-  try
-  {
-    dnsheader* dh = reinterpret_cast<dnsheader*>(packet);
-    uint64_t ancount = ntohs(dh->ancount);
-    uint64_t nscount = ntohs(dh->nscount);
-    uint64_t arcount = ntohs(dh->arcount);
-    DNSPacketMangler dpm(packet, length);
+  try {
+    const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+
+    if (ntohs(dh->qdcount) == 0)
+      return ENOENT;
+    auto packetView = pdns_string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size());
+
+    PacketReader pr(packetView);
+
+    size_t idx = 0;
+    DNSName rrname;
+    uint16_t qdcount = ntohs(dh->qdcount);
+    uint16_t ancount = ntohs(dh->ancount);
+    uint16_t nscount = ntohs(dh->nscount);
+    uint16_t arcount = ntohs(dh->arcount);
+    uint16_t rrtype;
+    uint16_t rrclass;
+    string blob;
+    struct dnsrecordheader ah;
 
-    for(uint64_t n = 0; n < ntohs(dh->qdcount) ; ++n) {
-      dpm.skipDomainName();
-      /* type and class */
-      dpm.skipBytes(4);
+    rrname = pr.getName();
+    rrtype = pr.get16BitInt();
+    rrclass = pr.get16BitInt();
+
+    GenericDNSPacketWriter<PacketBuffer> pw(newContent, rrname, rrtype, rrclass, dh->opcode);
+    pw.getHeader()->id=dh->id;
+    pw.getHeader()->qr=dh->qr;
+    pw.getHeader()->aa=dh->aa;
+    pw.getHeader()->tc=dh->tc;
+    pw.getHeader()->rd=dh->rd;
+    pw.getHeader()->ra=dh->ra;
+    pw.getHeader()->ad=dh->ad;
+    pw.getHeader()->cd=dh->cd;
+    pw.getHeader()->rcode=dh->rcode;
+
+    /* consume remaining qd if any */
+    if (qdcount > 1) {
+      for(idx = 1; idx < qdcount; idx++) {
+        rrname = pr.getName();
+        rrtype = pr.get16BitInt();
+        rrclass = pr.get16BitInt();
+        (void) rrtype;
+        (void) rrclass;
+      }
+    }
+
+    /* copy AN */
+    for (idx = 0; idx < ancount; idx++) {
+      rrname = pr.getName();
+      pr.getDnsrecordheader(ah);
+      pr.xfrBlob(blob);
+
+      if (qtypes.find(ah.d_type) == qtypes.end()) {
+        // if this is not a safe type
+        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
+          // "unsafe" types might countain compressed data, so cancel rewrite
+          newContent.clear();
+          return EIO;
+        }
+        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ANSWER, true);
+        pw.xfrBlob(blob);
+      }
+    }
+
+    /* copy NS */
+    for (idx = 0; idx < nscount; idx++) {
+      rrname = pr.getName();
+      pr.getDnsrecordheader(ah);
+      pr.xfrBlob(blob);
+
+      if (qtypes.find(ah.d_type) == qtypes.end()) {
+        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
+          // "unsafe" types might countain compressed data, so cancel rewrite
+          newContent.clear();
+          return EIO;
+        }
+        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::AUTHORITY, true);
+        pw.xfrBlob(blob);
+      }
     }
-    auto processSection = [&dpm, &qtypes, &length](uint64_t& count) {
-      for (uint64_t n=0; n < count; ++n) {
-        uint32_t start = dpm.getOffset();
-        dpm.skipDomainName();
-
-        const auto dnstype = QType{dpm.get16BitInt()};
-        /* class, ttl */
-        dpm.skipBytes(6);
-        dpm.skipRData();
-        if (qtypes.find(dnstype) != qtypes.end()) {
-          uint32_t rlen = dpm.getOffset() - start;
-          dpm.rewindBytes(rlen);
-          dpm.shrinkBytes(rlen);
-          // update count
-          count -= 1;
-          length -= rlen;
-          n -= 1;
+    /* copy AR */
+    for (idx = 0; idx < arcount; idx++) {
+      rrname = pr.getName();
+      pr.getDnsrecordheader(ah);
+      pr.xfrBlob(blob);
+
+      if (qtypes.find(ah.d_type) == qtypes.end()) {
+        if (safeTypes.find(ah.d_type) == safeTypes.end()) {
+          // "unsafe" types might countain compressed data, so cancel rewrite
+          newContent.clear();
+          return EIO;
         }
+        pw.startRecord(rrname, ah.d_type, ah.d_ttl, ah.d_class, DNSResourceRecord::ADDITIONAL, true);
+        pw.xfrBlob(blob);
       }
-    };
-    processSection(ancount);
-    dh->ancount = htons(ancount);
-    processSection(nscount);
-    dh->nscount = htons(nscount);
-    processSection(arcount);
-    dh->arcount = htons(arcount);
+    }
+    pw.commit();
+
   }
-  catch(...)
+  catch (...)
   {
-    return;
+    newContent.clear();
+    return EIO;
+  }
+  return 0;
+}
+
+void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes)
+{
+  return clearDNSPacketRecordTypes(reinterpret_cast<PacketBuffer&>(packet), qtypes);
+}
+
+void clearDNSPacketRecordTypes(PacketBuffer& packet, const std::set<QType>& qtypes)
+{
+  PacketBuffer newContent;
+
+  auto result = rewritePacketWithoutRecordTypes(packet, newContent, qtypes);
+  if (!result) {
+    packet = std::move(newContent);
   }
 }
 
index f24bb26b599d3778899ae3c8872311c59be61b09..2ec3802a5e7b8373ed6875e74090aac7d70721a4 100644 (file)
@@ -504,9 +504,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
       pwR.xfrIP6(std::string(reinterpret_cast<const char*>(v6.sin6.sin6_addr.s6_addr), 16));
       pwR.commit();
 
-      pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY);
-      pwR.commit();
-
       pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL);
       pwR.xfrIP(v4.sin4.sin_addr.s_addr);
       pwR.commit();
@@ -520,7 +517,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
 
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
-    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
 
     std::set<QType> toremove{QType::AAAA};
@@ -528,7 +524,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
 
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
-    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
 
     toremove = {QType::A};
@@ -536,14 +531,12 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
 
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
-    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
 
     packet = generatePacket();
 
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
-    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
 
     toremove = {QType::A, QType::AAAA};
@@ -551,7 +544,6 @@ BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
 
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
-    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
     BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
   }
 
index cd76fbe815613ef749afffc7ac2c957466159db2..301048b164462ee92977c0c214053d2bbb4472b1 100644 (file)
@@ -414,7 +414,8 @@ class TestResponseClearRecordsType(DNSDistTest):
 
     newServer{address="127.0.0.1:%s"}
 
-    addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaResponseAction(luafct))
+    addResponseAction("ffi.clear-records-type.responses.tests.powerdns.com.", LuaFFIResponseAction(luafct))
+    addResponseAction("clear-records-type.responses.tests.powerdns.com.", ClearRecordTypesResponseAction(DNSQType.AAAA))
     """
 
     def testClearedFFI(self):
@@ -444,3 +445,31 @@ class TestResponseClearRecordsType(DNSDistTest):
             receivedQuery.id = query.id
             self.assertEqual(query, receivedQuery)
             self.assertEqual(expectedResponse, receivedResponse)
+
+    def testCleared(self):
+        """
+        Responses: Removes records of a given type
+        """
+        name = 'clear-records-type.responses.tests.powerdns.com.'
+        query = dns.message.make_query(name, 'A', 'IN')
+        response = dns.message.make_response(query)
+        expectedResponse = dns.message.make_response(query)
+        rrset = dns.rrset.from_text(name,
+                                    3600,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.A,
+                                    '192.0.2.1')
+        response.answer.append(rrset)
+        expectedResponse.answer.append(rrset)
+        rrset = dns.rrset.from_text(name,
+                                    3660,
+                                    dns.rdataclass.IN,
+                                    dns.rdatatype.AAAA,
+                                    '2001:DB8::1', '2001:DB8::2')
+        response.answer.append(rrset)
+        for method in ("sendUDPQuery", "sendTCPQuery"):
+            sender = getattr(self, method)
+            (receivedQuery, receivedResponse) = sender(query, response)
+            receivedQuery.id = query.id
+            self.assertEqual(query, receivedQuery)
+            self.assertEqual(expectedResponse, receivedResponse)