]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: add parser method to clear given record types in a packet
authorCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Fri, 3 Dec 2021 09:07:25 +0000 (10:07 +0100)
committerCharles-Henri Bruyand <charles-henri.bruyand@open-xchange.com>
Thu, 16 Dec 2021 13:26:04 +0000 (14:26 +0100)
pdns/dnsparser.cc
pdns/dnsparser.hh
pdns/test-dnsparser_cc.cc

index 57267218db870f7f59abcf33ccf747544c5dc6f7..008aca8af65b60e5df718cf63a1f4cecfc33e1b2 100644 (file)
@@ -26,7 +26,7 @@
 
 #include "namespaces.hh"
 
-UnknownRecordContent::UnknownRecordContent(const string& zone) 
+UnknownRecordContent::UnknownRecordContent(const string& zone)
 {
   // parse the input
   vector<string> parts;
@@ -217,7 +217,7 @@ void MOADNSParser::init(bool query, const pdns_string_view& packet)
 {
   if (packet.size() < sizeof(dnsheader))
     throw MOADNSException("Packet shorter than minimal header");
-  
+
   memcpy(&d_header, packet.data(), sizeof(dnsheader));
 
   if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
@@ -230,7 +230,7 @@ void MOADNSParser::init(bool query, const pdns_string_view& packet)
 
   if (query && (d_header.qdcount > 1))
     throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
-  
+
   unsigned int n=0;
 
   PacketReader pr(packet);
@@ -251,12 +251,12 @@ void MOADNSParser::init(bool query, const pdns_string_view& packet)
     d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
     for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
       DNSRecord dr;
-      
+
       if(n < d_header.ancount)
         dr.d_place=DNSResourceRecord::ANSWER;
       else if(n < d_header.ancount + d_header.nscount)
         dr.d_place=DNSResourceRecord::AUTHORITY;
-      else 
+      else
         dr.d_place=DNSResourceRecord::ADDITIONAL;
 
       unsigned int recordStartPos=pr.getPosition();
@@ -347,10 +347,10 @@ void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah)
 {
   unsigned int n;
   unsigned char *p=reinterpret_cast<unsigned char*>(&ah);
-  
-  for(n=0; n < sizeof(dnsrecordheader); ++n) 
+
+  for(n=0; n < sizeof(dnsrecordheader); ++n)
     p[n]=d_content.at(d_pos++);
-  
+
   ah.d_type=ntohs(ah.d_type);
   ah.d_class=ntohs(ah.d_class);
   ah.d_clen=ntohs(ah.d_clen);
@@ -416,7 +416,7 @@ uint32_t PacketReader::get32BitInt()
   ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
   ret+=static_cast<uint8_t>(d_content.at(d_pos++));
-  
+
   return ret;
 }
 
@@ -427,7 +427,7 @@ uint16_t PacketReader::get16BitInt()
   ret+=static_cast<uint8_t>(d_content.at(d_pos++));
   ret<<=8;
   ret+=static_cast<uint8_t>(d_content.at(d_pos++));
-  
+
   return ret;
 }
 
@@ -441,7 +441,7 @@ DNSName PacketReader::getName()
   unsigned int consumed;
   try {
     DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader));
-    
+
     d_pos+=consumed;
     return dn;
   }
@@ -488,7 +488,7 @@ string PacketReader::getText(bool multi, bool lenField)
       labellen=static_cast<uint8_t>(d_content.at(d_pos++));
     else
       labellen=d_recordlen - (d_pos - d_startrecordpos);
-    
+
     ret.append(1,'"');
     if(labellen) { // no need to do anything for an empty string
       string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
@@ -576,7 +576,7 @@ void PacketReader::xfrSvcParamKeyVals(set<SvcParam> &kvs) {
     auto key = static_cast<SvcParam::SvcParamKey>(keyInt);
     uint16_t len;
     xfr16BitInt(len);
-    
+
     if (d_pos + len > (d_startrecordpos + d_recordlen)) {
       throw std::out_of_range("record is shorter than SVCB lengthfield implies");
     }
@@ -681,7 +681,7 @@ string simpleCompress(const string& elabel, const string& root)
   if(strchr(label.c_str(), '\\')) {
     boost::replace_all(label, "\\.", ".");
     boost::replace_all(label, "\\032", " ");
-    boost::replace_all(label, "\\\\", "\\");   
+    boost::replace_all(label, "\\\\", "\\");
   }
   typedef vector<pair<unsigned int, unsigned int> > parts_t;
   parts_t parts;
@@ -745,12 +745,71 @@ void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(
   }
 }
 
+void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes)
+{
+  size_t finalsize = packet.size();
+  clearDNSPacketRecordTypes(reinterpret_cast<char*>(packet.data()), finalsize, qtypes);
+  packet.resize(finalsize);
+}
+
+// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
+void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set<QType>& qtypes)
+{
+  if (length < sizeof(dnsheader)) {
+    return;
+  }
+  try
+  {
+    dnsheader* dh = reinterpret_cast<dnsheader*>(packet);
+    uint64_t ancount = ntohs(dh->ancount);
+    uint64_t nscount = ntohs(dh->nscount);
+    uint64_t arcount = ntohs(dh->arcount);
+    DNSPacketMangler dpm(packet, length);
+
+    for(uint64_t n = 0; n < ntohs(dh->qdcount) ; ++n) {
+      dpm.skipDomainName();
+      /* type and class */
+      dpm.skipBytes(4);
+    }
+    auto processSection = [&dpm, &qtypes, &length](uint64_t& count) {
+      for (uint64_t n=0; n < count; ++n) {
+        uint32_t start = dpm.getOffset();
+        dpm.skipDomainName();
+
+        const auto dnstype = QType{dpm.get16BitInt()};
+        /* class, ttl */
+        dpm.skipBytes(6);
+        dpm.skipRData();
+        if (qtypes.find(dnstype) != qtypes.end()) {
+          uint32_t rlen = dpm.getOffset() - start;
+          dpm.rewindBytes(rlen);
+          dpm.shrinkBytes(rlen);
+          // update count
+          count -= 1;
+          length -= rlen;
+          n -= 1;
+        }
+      }
+    };
+    processSection(ancount);
+    dh->ancount = htons(ancount);
+    processSection(nscount);
+    dh->nscount = htons(nscount);
+    processSection(arcount);
+    dh->arcount = htons(arcount);
+  }
+  catch(...)
+  {
+    return;
+  }
+}
+
 // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
 void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
 {
   if(length < sizeof(dnsheader))
     return;
-  try 
+  try
   {
     const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
     const uint64_t dqcount = ntohs(dh->qdcount);
@@ -766,14 +825,14 @@ void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
    // cerr<<"Skipped "<<n<<" questions, now parsing "<<numrecords<<" records"<<endl;
     for(n=0; n < numrecords; ++n) {
       dpm.skipDomainName();
-      
+
       uint16_t dnstype = dpm.get16BitInt();
       /* class */
       dpm.skipBytes(2);
-      
+
       if(dnstype == QType::OPT) // not aging that one with a stick
         break;
-      
+
       dpm.decreaseAndSkip32BitInt(seconds);
       dpm.skipRData();
     }
index 9cbe5a4e2bac8b80ed0eb318fd8becd380183e65..250e1223ba7e18c52b2f5938dfb489662d4c3f92 100644 (file)
@@ -434,6 +434,8 @@ string simpleCompress(const string& label, const string& root="");
 void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
 void ageDNSPacket(std::string& packet, uint32_t seconds);
 void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor);
+void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes);
+void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set<QType>& qtypes);
 uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
 uint32_t getDNSPacketLength(const char* packet, size_t length);
 uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
@@ -541,6 +543,14 @@ public:
     return d_offset;
   }
 
+  void shrinkBytes(uint16_t by)
+  {
+    if (d_notyouroffset + by > d_length) {
+      throw std::out_of_range("shrinking dns packet out of range: " + std::to_string(by) + " bytes at " + std::to_string(d_notyouroffset) + " for a total of " + std::to_string(d_length) );
+    }
+    memmove(d_packet + d_notyouroffset, d_packet + d_notyouroffset + by, d_length - (d_notyouroffset + by));
+    d_length -= by;
+  }
 private:
   void moveOffset(uint16_t by)
   {
index e71bf19fdb002e7fea2c5ab8a494985ca6986b2c..f24bb26b599d3778899ae3c8872311c59be61b09 100644 (file)
@@ -479,8 +479,82 @@ BOOST_AUTO_TEST_CASE(test_getRecordsOfTypeCount) {
      BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::SOA), 0);
 
      BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 4, QType::SOA), 0);
+  }
+
 }
 
+BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
+  {
+    auto generatePacket = []() {
+      const DNSName name("powerdns.com.");
+      const ComboAddress v4("1.2.3.4");
+      const ComboAddress v6("2001:db8::1");
+
+      vector<uint8_t> packet;
+      DNSPacketWriter pwR(packet, name, QType::A, QClass::IN, 0);
+      pwR.getHeader()->qr = 1;
+      pwR.commit();
+
+      pwR.startRecord(name, QType::A, 255, QClass::IN, DNSResourceRecord::ANSWER);
+      pwR.xfrIP(v4.sin4.sin_addr.s_addr);
+      pwR.commit();
+
+      /* different type */
+      pwR.startRecord(name, QType::AAAA, 42, QClass::IN, DNSResourceRecord::ANSWER);
+      pwR.xfrIP6(std::string(reinterpret_cast<const char*>(v6.sin6.sin6_addr.s6_addr), 16));
+      pwR.commit();
+
+      pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY);
+      pwR.commit();
+
+      pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL);
+      pwR.xfrIP(v4.sin4.sin_addr.s_addr);
+      pwR.commit();
+
+      pwR.addOpt(4096, 0, 0);
+      pwR.commit();
+      return packet;
+    };
+
+    auto packet = generatePacket();
+
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
+
+    std::set<QType> toremove{QType::AAAA};
+    clearDNSPacketRecordTypes(packet, toremove);
+
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
+
+    toremove = {QType::A};
+    clearDNSPacketRecordTypes(packet, toremove);
+
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
+
+    packet = generatePacket();
+
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
+
+    toremove = {QType::A, QType::AAAA};
+    clearDNSPacketRecordTypes(packet, toremove);
+
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+    BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
+  }
+
 }
 
 BOOST_AUTO_TEST_SUITE_END()