From: Charles-Henri Bruyand Date: Fri, 3 Dec 2021 09:07:25 +0000 (+0100) Subject: dnsdist: add parser method to clear given record types in a packet X-Git-Tag: auth-4.7.0-alpha1~108^2~4 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=bd98d61bd20c114dbd4168460215bcfb8883a502;p=thirdparty%2Fpdns.git dnsdist: add parser method to clear given record types in a packet --- diff --git a/pdns/dnsparser.cc b/pdns/dnsparser.cc index 57267218db..008aca8af6 100644 --- a/pdns/dnsparser.cc +++ b/pdns/dnsparser.cc @@ -26,7 +26,7 @@ #include "namespaces.hh" -UnknownRecordContent::UnknownRecordContent(const string& zone) +UnknownRecordContent::UnknownRecordContent(const string& zone) { // parse the input vector parts; @@ -217,7 +217,7 @@ void MOADNSParser::init(bool query, const pdns_string_view& packet) { if (packet.size() < sizeof(dnsheader)) throw MOADNSException("Packet shorter than minimal header"); - + memcpy(&d_header, packet.data(), sizeof(dnsheader)); if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update) @@ -230,7 +230,7 @@ void MOADNSParser::init(bool query, const pdns_string_view& packet) if (query && (d_header.qdcount > 1)) throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")"); - + unsigned int n=0; PacketReader pr(packet); @@ -251,12 +251,12 @@ void MOADNSParser::init(bool query, const pdns_string_view& packet) d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount)); for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) { DNSRecord dr; - + if(n < d_header.ancount) dr.d_place=DNSResourceRecord::ANSWER; else if(n < d_header.ancount + d_header.nscount) dr.d_place=DNSResourceRecord::AUTHORITY; - else + else dr.d_place=DNSResourceRecord::ADDITIONAL; unsigned int recordStartPos=pr.getPosition(); @@ -347,10 +347,10 @@ void PacketReader::getDnsrecordheader(struct dnsrecordheader &ah) { unsigned int n; unsigned char *p=reinterpret_cast(&ah); - - for(n=0; n < sizeof(dnsrecordheader); ++n) + + for(n=0; n < sizeof(dnsrecordheader); ++n) p[n]=d_content.at(d_pos++); - + ah.d_type=ntohs(ah.d_type); ah.d_class=ntohs(ah.d_class); ah.d_clen=ntohs(ah.d_clen); @@ -416,7 +416,7 @@ uint32_t PacketReader::get32BitInt() ret+=static_cast(d_content.at(d_pos++)); ret<<=8; ret+=static_cast(d_content.at(d_pos++)); - + return ret; } @@ -427,7 +427,7 @@ uint16_t PacketReader::get16BitInt() ret+=static_cast(d_content.at(d_pos++)); ret<<=8; ret+=static_cast(d_content.at(d_pos++)); - + return ret; } @@ -441,7 +441,7 @@ DNSName PacketReader::getName() unsigned int consumed; try { DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader)); - + d_pos+=consumed; return dn; } @@ -488,7 +488,7 @@ string PacketReader::getText(bool multi, bool lenField) labellen=static_cast(d_content.at(d_pos++)); else labellen=d_recordlen - (d_pos - d_startrecordpos); - + ret.append(1,'"'); if(labellen) { // no need to do anything for an empty string string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1); @@ -576,7 +576,7 @@ void PacketReader::xfrSvcParamKeyVals(set &kvs) { auto key = static_cast(keyInt); uint16_t len; xfr16BitInt(len); - + if (d_pos + len > (d_startrecordpos + d_recordlen)) { throw std::out_of_range("record is shorter than SVCB lengthfield implies"); } @@ -681,7 +681,7 @@ string simpleCompress(const string& elabel, const string& root) if(strchr(label.c_str(), '\\')) { boost::replace_all(label, "\\.", "."); boost::replace_all(label, "\\032", " "); - boost::replace_all(label, "\\\\", "\\"); + boost::replace_all(label, "\\\\", "\\"); } typedef vector > parts_t; parts_t parts; @@ -745,12 +745,71 @@ void editDNSPacketTTL(char* packet, size_t length, const std::function& packet, const std::set& qtypes) +{ + size_t finalsize = packet.size(); + clearDNSPacketRecordTypes(reinterpret_cast(packet.data()), finalsize, qtypes); + packet.resize(finalsize); +} + +// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it +void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set& qtypes) +{ + if (length < sizeof(dnsheader)) { + return; + } + try + { + dnsheader* dh = reinterpret_cast(packet); + uint64_t ancount = ntohs(dh->ancount); + uint64_t nscount = ntohs(dh->nscount); + uint64_t arcount = ntohs(dh->arcount); + DNSPacketMangler dpm(packet, length); + + for(uint64_t n = 0; n < ntohs(dh->qdcount) ; ++n) { + dpm.skipDomainName(); + /* type and class */ + dpm.skipBytes(4); + } + auto processSection = [&dpm, &qtypes, &length](uint64_t& count) { + for (uint64_t n=0; n < count; ++n) { + uint32_t start = dpm.getOffset(); + dpm.skipDomainName(); + + const auto dnstype = QType{dpm.get16BitInt()}; + /* class, ttl */ + dpm.skipBytes(6); + dpm.skipRData(); + if (qtypes.find(dnstype) != qtypes.end()) { + uint32_t rlen = dpm.getOffset() - start; + dpm.rewindBytes(rlen); + dpm.shrinkBytes(rlen); + // update count + count -= 1; + length -= rlen; + n -= 1; + } + } + }; + processSection(ancount); + dh->ancount = htons(ancount); + processSection(nscount); + dh->nscount = htons(nscount); + processSection(arcount); + dh->arcount = htons(arcount); + } + catch(...) + { + return; + } +} + // method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it void ageDNSPacket(char* packet, size_t length, uint32_t seconds) { if(length < sizeof(dnsheader)) return; - try + try { const dnsheader* dh = reinterpret_cast(packet); const uint64_t dqcount = ntohs(dh->qdcount); @@ -766,14 +825,14 @@ void ageDNSPacket(char* packet, size_t length, uint32_t seconds) // cerr<<"Skipped "<& visitor); +void clearDNSPacketRecordTypes(vector& packet, const std::set& qtypes); +void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set& qtypes); uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr); uint32_t getDNSPacketLength(const char* packet, size_t length); uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type); @@ -541,6 +543,14 @@ public: return d_offset; } + void shrinkBytes(uint16_t by) + { + if (d_notyouroffset + by > d_length) { + throw std::out_of_range("shrinking dns packet out of range: " + std::to_string(by) + " bytes at " + std::to_string(d_notyouroffset) + " for a total of " + std::to_string(d_length) ); + } + memmove(d_packet + d_notyouroffset, d_packet + d_notyouroffset + by, d_length - (d_notyouroffset + by)); + d_length -= by; + } private: void moveOffset(uint16_t by) { diff --git a/pdns/test-dnsparser_cc.cc b/pdns/test-dnsparser_cc.cc index e71bf19fdb..f24bb26b59 100644 --- a/pdns/test-dnsparser_cc.cc +++ b/pdns/test-dnsparser_cc.cc @@ -479,8 +479,82 @@ BOOST_AUTO_TEST_CASE(test_getRecordsOfTypeCount) { BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::SOA), 0); BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 4, QType::SOA), 0); + } + } +BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) { + { + auto generatePacket = []() { + const DNSName name("powerdns.com."); + const ComboAddress v4("1.2.3.4"); + const ComboAddress v6("2001:db8::1"); + + vector packet; + DNSPacketWriter pwR(packet, name, QType::A, QClass::IN, 0); + pwR.getHeader()->qr = 1; + pwR.commit(); + + pwR.startRecord(name, QType::A, 255, QClass::IN, DNSResourceRecord::ANSWER); + pwR.xfrIP(v4.sin4.sin_addr.s_addr); + pwR.commit(); + + /* different type */ + pwR.startRecord(name, QType::AAAA, 42, QClass::IN, DNSResourceRecord::ANSWER); + pwR.xfrIP6(std::string(reinterpret_cast(v6.sin6.sin6_addr.s6_addr), 16)); + pwR.commit(); + + pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY); + pwR.commit(); + + pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL); + pwR.xfrIP(v4.sin4.sin_addr.s_addr); + pwR.commit(); + + pwR.addOpt(4096, 0, 0); + pwR.commit(); + return packet; + }; + + auto packet = generatePacket(); + + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 1); + + std::set toremove{QType::AAAA}; + clearDNSPacketRecordTypes(packet, toremove); + + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 0); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 1); + + toremove = {QType::A}; + clearDNSPacketRecordTypes(packet, toremove); + + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 0); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 0); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 0); + + packet = generatePacket(); + + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 1); + + toremove = {QType::A, QType::AAAA}; + clearDNSPacketRecordTypes(packet, toremove); + + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::A), 0); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 1, QType::AAAA), 0); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 2, QType::SOA), 1); + BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast(packet.data()), packet.size(), 3, QType::A), 0); + } + } BOOST_AUTO_TEST_SUITE_END()