#include "namespaces.hh"
-UnknownRecordContent::UnknownRecordContent(const string& zone)
+UnknownRecordContent::UnknownRecordContent(const string& zone)
{
// parse the input
vector<string> parts;
{
if (packet.size() < sizeof(dnsheader))
throw MOADNSException("Packet shorter than minimal header");
-
+
memcpy(&d_header, packet.data(), sizeof(dnsheader));
if(d_header.opcode != Opcode::Query && d_header.opcode != Opcode::Notify && d_header.opcode != Opcode::Update)
if (query && (d_header.qdcount > 1))
throw MOADNSException("Query with QD > 1 ("+std::to_string(d_header.qdcount)+")");
-
+
unsigned int n=0;
PacketReader pr(packet);
d_answers.reserve((unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount));
for(n=0;n < (unsigned int)(d_header.ancount + d_header.nscount + d_header.arcount); ++n) {
DNSRecord dr;
-
+
if(n < d_header.ancount)
dr.d_place=DNSResourceRecord::ANSWER;
else if(n < d_header.ancount + d_header.nscount)
dr.d_place=DNSResourceRecord::AUTHORITY;
- else
+ else
dr.d_place=DNSResourceRecord::ADDITIONAL;
unsigned int recordStartPos=pr.getPosition();
{
unsigned int n;
unsigned char *p=reinterpret_cast<unsigned char*>(&ah);
-
- for(n=0; n < sizeof(dnsrecordheader); ++n)
+
+ for(n=0; n < sizeof(dnsrecordheader); ++n)
p[n]=d_content.at(d_pos++);
-
+
ah.d_type=ntohs(ah.d_type);
ah.d_class=ntohs(ah.d_class);
ah.d_clen=ntohs(ah.d_clen);
ret+=static_cast<uint8_t>(d_content.at(d_pos++));
ret<<=8;
ret+=static_cast<uint8_t>(d_content.at(d_pos++));
-
+
return ret;
}
ret+=static_cast<uint8_t>(d_content.at(d_pos++));
ret<<=8;
ret+=static_cast<uint8_t>(d_content.at(d_pos++));
-
+
return ret;
}
unsigned int consumed;
try {
DNSName dn((const char*) d_content.data(), d_content.size(), d_pos, true /* uncompress */, nullptr /* qtype */, nullptr /* qclass */, &consumed, sizeof(dnsheader));
-
+
d_pos+=consumed;
return dn;
}
labellen=static_cast<uint8_t>(d_content.at(d_pos++));
else
labellen=d_recordlen - (d_pos - d_startrecordpos);
-
+
ret.append(1,'"');
if(labellen) { // no need to do anything for an empty string
string val(&d_content.at(d_pos), &d_content.at(d_pos+labellen-1)+1);
auto key = static_cast<SvcParam::SvcParamKey>(keyInt);
uint16_t len;
xfr16BitInt(len);
-
+
if (d_pos + len > (d_startrecordpos + d_recordlen)) {
throw std::out_of_range("record is shorter than SVCB lengthfield implies");
}
if(strchr(label.c_str(), '\\')) {
boost::replace_all(label, "\\.", ".");
boost::replace_all(label, "\\032", " ");
- boost::replace_all(label, "\\\\", "\\");
+ boost::replace_all(label, "\\\\", "\\");
}
typedef vector<pair<unsigned int, unsigned int> > parts_t;
parts_t parts;
}
}
+void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes)
+{
+ size_t finalsize = packet.size();
+ clearDNSPacketRecordTypes(reinterpret_cast<char*>(packet.data()), finalsize, qtypes);
+ packet.resize(finalsize);
+}
+
+// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
+void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set<QType>& qtypes)
+{
+ if (length < sizeof(dnsheader)) {
+ return;
+ }
+ try
+ {
+ dnsheader* dh = reinterpret_cast<dnsheader*>(packet);
+ uint64_t ancount = ntohs(dh->ancount);
+ uint64_t nscount = ntohs(dh->nscount);
+ uint64_t arcount = ntohs(dh->arcount);
+ DNSPacketMangler dpm(packet, length);
+
+ for(uint64_t n = 0; n < ntohs(dh->qdcount) ; ++n) {
+ dpm.skipDomainName();
+ /* type and class */
+ dpm.skipBytes(4);
+ }
+ auto processSection = [&dpm, &qtypes, &length](uint64_t& count) {
+ for (uint64_t n=0; n < count; ++n) {
+ uint32_t start = dpm.getOffset();
+ dpm.skipDomainName();
+
+ const auto dnstype = QType{dpm.get16BitInt()};
+ /* class, ttl */
+ dpm.skipBytes(6);
+ dpm.skipRData();
+ if (qtypes.find(dnstype) != qtypes.end()) {
+ uint32_t rlen = dpm.getOffset() - start;
+ dpm.rewindBytes(rlen);
+ dpm.shrinkBytes(rlen);
+ // update count
+ count -= 1;
+ length -= rlen;
+ n -= 1;
+ }
+ }
+ };
+ processSection(ancount);
+ dh->ancount = htons(ancount);
+ processSection(nscount);
+ dh->nscount = htons(nscount);
+ processSection(arcount);
+ dh->arcount = htons(arcount);
+ }
+ catch(...)
+ {
+ return;
+ }
+}
+
// method of operation: silently fail if it doesn't work - we're only trying to be nice, don't fall over on it
void ageDNSPacket(char* packet, size_t length, uint32_t seconds)
{
if(length < sizeof(dnsheader))
return;
- try
+ try
{
const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
const uint64_t dqcount = ntohs(dh->qdcount);
// cerr<<"Skipped "<<n<<" questions, now parsing "<<numrecords<<" records"<<endl;
for(n=0; n < numrecords; ++n) {
dpm.skipDomainName();
-
+
uint16_t dnstype = dpm.get16BitInt();
/* class */
dpm.skipBytes(2);
-
+
if(dnstype == QType::OPT) // not aging that one with a stick
break;
-
+
dpm.decreaseAndSkip32BitInt(seconds);
dpm.skipRData();
}
void ageDNSPacket(char* packet, size_t length, uint32_t seconds);
void ageDNSPacket(std::string& packet, uint32_t seconds);
void editDNSPacketTTL(char* packet, size_t length, const std::function<uint32_t(uint8_t, uint16_t, uint16_t, uint32_t)>& visitor);
+void clearDNSPacketRecordTypes(vector<uint8_t>& packet, const std::set<QType>& qtypes);
+void clearDNSPacketRecordTypes(char* packet, size_t& length, const std::set<QType>& qtypes);
uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA=nullptr);
uint32_t getDNSPacketLength(const char* packet, size_t length);
uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t section, uint16_t type);
return d_offset;
}
+ void shrinkBytes(uint16_t by)
+ {
+ if (d_notyouroffset + by > d_length) {
+ throw std::out_of_range("shrinking dns packet out of range: " + std::to_string(by) + " bytes at " + std::to_string(d_notyouroffset) + " for a total of " + std::to_string(d_length) );
+ }
+ memmove(d_packet + d_notyouroffset, d_packet + d_notyouroffset + by, d_length - (d_notyouroffset + by));
+ d_length -= by;
+ }
private:
void moveOffset(uint16_t by)
{
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::SOA), 0);
BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 4, QType::SOA), 0);
+ }
+
}
+BOOST_AUTO_TEST_CASE(test_clearDNSPacketRecordTypes) {
+ {
+ auto generatePacket = []() {
+ const DNSName name("powerdns.com.");
+ const ComboAddress v4("1.2.3.4");
+ const ComboAddress v6("2001:db8::1");
+
+ vector<uint8_t> packet;
+ DNSPacketWriter pwR(packet, name, QType::A, QClass::IN, 0);
+ pwR.getHeader()->qr = 1;
+ pwR.commit();
+
+ pwR.startRecord(name, QType::A, 255, QClass::IN, DNSResourceRecord::ANSWER);
+ pwR.xfrIP(v4.sin4.sin_addr.s_addr);
+ pwR.commit();
+
+ /* different type */
+ pwR.startRecord(name, QType::AAAA, 42, QClass::IN, DNSResourceRecord::ANSWER);
+ pwR.xfrIP6(std::string(reinterpret_cast<const char*>(v6.sin6.sin6_addr.s6_addr), 16));
+ pwR.commit();
+
+ pwR.startRecord(name, QType::SOA, 257, QClass::IN, DNSResourceRecord::AUTHORITY);
+ pwR.commit();
+
+ pwR.startRecord(name, QType::A, 256, QClass::IN, DNSResourceRecord::ADDITIONAL);
+ pwR.xfrIP(v4.sin4.sin_addr.s_addr);
+ pwR.commit();
+
+ pwR.addOpt(4096, 0, 0);
+ pwR.commit();
+ return packet;
+ };
+
+ auto packet = generatePacket();
+
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
+
+ std::set<QType> toremove{QType::AAAA};
+ clearDNSPacketRecordTypes(packet, toremove);
+
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
+
+ toremove = {QType::A};
+ clearDNSPacketRecordTypes(packet, toremove);
+
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
+
+ packet = generatePacket();
+
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 1);
+
+ toremove = {QType::A, QType::AAAA};
+ clearDNSPacketRecordTypes(packet, toremove);
+
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::A), 0);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 1, QType::AAAA), 0);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 2, QType::SOA), 1);
+ BOOST_CHECK_EQUAL(getRecordsOfTypeCount(reinterpret_cast<char*>(packet.data()), packet.size(), 3, QType::A), 0);
+ }
+
}
BOOST_AUTO_TEST_SUITE_END()