From: Remi Gacogne Date: Thu, 5 Oct 2023 14:38:56 +0000 (+0200) Subject: dnsdist: Prevent dnsheader alignment issues X-Git-Tag: dnsdist-1.9.0-alpha2^2~3 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=90686725bb6723ce5fc2a003573983aad59a739d;p=thirdparty%2Fpdns.git dnsdist: Prevent dnsheader alignment issues --- diff --git a/pdns/dnscrypt.cc b/pdns/dnscrypt.cc index 6db8613a3d..192dcfad18 100644 --- a/pdns/dnscrypt.cc +++ b/pdns/dnscrypt.cc @@ -399,9 +399,10 @@ bool DNSCryptQuery::parsePlaintextQuery(const PacketBuffer& packet) return false; } - const struct dnsheader * dh = reinterpret_cast(packet.data()); - if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query) + const dnsheader_aligned dh(packet.data()); + if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query) { return false; + } unsigned int qnameWireLength; uint16_t qtype, qclass; diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index 7ca9be2f6e..67de6226a0 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -252,7 +252,7 @@ bool DNSDistPacketCache::get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut } /* check for collision */ - if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) { + if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader().get())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) { ++d_lookupCollisions; return false; } diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 9e9d9c329e..2cad1945bc 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -21,6 +21,7 @@ */ #include "dolog.hh" #include "dnsdist.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" #include "dnsparser.hh" #include "dnswriter.hh" @@ -44,13 +45,15 @@ bool g_addEDNSToSelfGeneratedResponses{true}; int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent) { assert(initialPacket.size() >= sizeof(dnsheader)); - const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); + const dnsheader_aligned dh(initialPacket.data()); - if (ntohs(dh->arcount) == 0) + if (ntohs(dh->arcount) == 0) { return ENOENT; + } - if (ntohs(dh->qdcount) == 0) + if (ntohs(dh->qdcount) == 0) { return ENOENT; + } PacketReader pr(std::string_view(reinterpret_cast(initialPacket.data()), initialPacket.size())); @@ -152,7 +155,7 @@ static bool addOrReplaceEDNSOption(std::vector> bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent) { assert(initialPacket.size() >= sizeof(dnsheader)); - const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); + const dnsheader_aligned dh(initialPacket.data()); if (ntohs(dh->qdcount) == 0) { return false; @@ -269,7 +272,7 @@ static bool slowParseEDNSOptions(const PacketBuffer& packet, EDNSOptionViewMap& return false; } - const struct dnsheader* dh = reinterpret_cast(packet.data()); + const dnsheader_aligned dh(packet.data()); if (ntohs(dh->qdcount) == 0) { return false; @@ -324,10 +327,11 @@ int locateEDNSOptRR(const PacketBuffer& packet, uint16_t * optStart, size_t * op assert(optStart != NULL); assert(optLen != NULL); assert(last != NULL); - const struct dnsheader* dh = reinterpret_cast(packet.data()); + const dnsheader_aligned dh(packet.data()); - if (ntohs(dh->arcount) == 0) + if (ntohs(dh->arcount) == 0) { return ENOENT; + } PacketReader pr(std::string_view(reinterpret_cast(packet.data()), packet.size())); @@ -390,14 +394,15 @@ int getEDNSOptionsStart(const PacketBuffer& packet, const size_t offset, uint16_ { assert(optRDPosition != nullptr); assert(remaining != nullptr); - const struct dnsheader* dh = reinterpret_cast(packet.data()); + const dnsheader_aligned dh(packet.data()); if (offset >= packet.size()) { return ENOENT; } - if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0) + if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0) { return ENOENT; + } size_t pos = sizeof(dnsheader) + offset; pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE; @@ -571,10 +576,12 @@ static bool addEDNSWithECS(PacketBuffer& packet, size_t maximumSize, const strin return false; } - struct dnsheader* dh = reinterpret_cast(packet.data()); - uint16_t arcount = ntohs(dh->arcount); - arcount++; - dh->arcount = htons(arcount); + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) { + uint16_t arcount = ntohs(header.arcount); + arcount++; + header.arcount = htons(arcount); + return true; + }); ednsAdded = true; ecsAdded = true; @@ -585,7 +592,7 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons { assert(qnameWireLength <= packet.size()); - const struct dnsheader* dh = reinterpret_cast(packet.data()); + const dnsheader_aligned dh(packet.data()); if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) { PacketBuffer newContent; @@ -752,7 +759,7 @@ bool isEDNSOptionInOpt(const PacketBuffer& packet, const size_t optStart, const int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent) { assert(initialPacket.size() >= sizeof(dnsheader)); - const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); + const dnsheader_aligned dh(initialPacket.data()); if (ntohs(dh->arcount) == 0) return ENOENT; @@ -852,8 +859,10 @@ bool addEDNS(PacketBuffer& packet, size_t maximumSize, bool dnssecOK, uint16_t p return false; } - auto dh = reinterpret_cast(packet.data()); - dh->arcount = htons(ntohs(dh->arcount) + 1); + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) { + header.arcount = htons(ntohs(header.arcount) + 1); + return true; + }); return true; } @@ -894,17 +903,19 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, /* chop off everything after the question */ packet.resize(queryPartSize); - dh = dq.getHeader(); - if (nxd) { - dh->rcode = RCode::NXDomain; - } - else { - dh->rcode = RCode::NoError; - } - dh->qr = true; - dh->ancount = 0; - dh->nscount = 0; - dh->arcount = 0; + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) { + if (nxd) { + header.rcode = RCode::NXDomain; + } + else { + header.rcode = RCode::NoError; + } + header.qr = true; + header.ancount = 0; + header.nscount = 0; + header.arcount = 0; + return true; + }); rdLength = htons(rdLength); ttl = htonl(ttl); @@ -934,16 +945,18 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, } packet.insert(packet.end(), soa.begin(), soa.end()); - dh = dq.getHeader(); /* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here and have EDNS added to AR afterwards */ - if (soaInAuthoritySection) { - dh->nscount = htons(1); - } else { - dh->arcount = htons(1); - } + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) { + if (soaInAuthoritySection) { + header.nscount = htons(1); + } else { + header.arcount = htons(1); + } + return true; + }); if (hadEDNS) { /* now we need to add a new OPT record */ @@ -982,7 +995,10 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq) /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */ packet.resize(packet.size() - existingOptLen); - dq.getHeader()->arcount = 0; + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) { + header.arcount = 0; + return true; + }); if (g_addEDNSToSelfGeneratedResponses) { /* now we need to add a new OPT record */ @@ -1107,7 +1123,10 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa auto& data = dq.getMutableData(); if (generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) { - dq.getHeader()->arcount = htons(1); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.arcount = htons(1); + return true; + }); // make sure that any EDNS sent by the backend is removed before forwarding the response to the client dq.ids.ednsAdded = true; } @@ -1129,17 +1148,22 @@ bool setInternalQueryRCode(InternalQueryState& state, PacketBuffer& buffer, uin hadEDNS = getEDNS0Record(buffer, edns0); } - auto dh = reinterpret_cast(buffer.data()); - dh->rcode = rcode; - dh->ad = false; - dh->aa = false; - dh->ra = dh->rd; - dh->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode,clearAnswers](dnsheader& header) { + header.rcode = rcode; + header.ad = false; + header.aa = false; + header.ra = header.rd; + header.qr = true; + + if (clearAnswers) { + header.ancount = 0; + header.nscount = 0; + header.arcount = 0; + } + return true; + }); if (clearAnswers) { - dh->ancount = 0; - dh->nscount = 0; - dh->arcount = 0; buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t)); if (hadEDNS) { DNSQuestion dq(state, buffer); diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index b81600fe8c..b9d1b5b3af 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -23,6 +23,7 @@ #include "threadname.hh" #include "dnsdist.hh" #include "dnsdist-async.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" #include "dnsdist-edns.hh" #include "dnsdist-lua.hh" @@ -242,37 +243,48 @@ std::map TeeAction::getStats() const void TeeAction::worker() { setThreadName("dnsdist/TeeWork"); - char packet[1500]; - int res=0; - struct dnsheader* dh=(struct dnsheader*)packet; - for(;;) { - res=waitForData(d_fd, 0, 250000); - if(d_pleaseQuit) + std::array packet; + ssize_t res = 0; + const dnsheader_aligned dh(packet.data()); + for (;;) { + res = waitForData(d_fd, 0, 250000); + if (d_pleaseQuit) { break; - if(res < 0) { + } + + if (res < 0) { usleep(250000); continue; } - if(res==0) + if (res == 0) { continue; - res=recv(d_fd, packet, sizeof(packet), 0); - if(res <= (int)sizeof(struct dnsheader)) + } + res = recv(d_fd, packet.data(), packet.size(), 0); + if (static_cast(res) <= sizeof(struct dnsheader)) { d_recverrors++; - else + } + else { d_responses++; + } - if(dh->rcode == RCode::NoError) + if (dh->rcode == RCode::NoError) { d_noerrors++; - else if(dh->rcode == RCode::ServFail) + } + else if (dh->rcode == RCode::ServFail) { d_servfails++; - else if(dh->rcode == RCode::NXDomain) + } + else if (dh->rcode == RCode::NXDomain) { d_nxdomains++; - else if(dh->rcode == RCode::Refused) + } + else if (dh->rcode == RCode::Refused) { d_refuseds++; - else if(dh->rcode == RCode::FormErr) + } + else if (dh->rcode == RCode::FormErr) { d_formerrs++; - else if(dh->rcode == RCode::NotImp) + } + else if (dh->rcode == RCode::NotImp) { d_notimps++; + } } } @@ -343,9 +355,12 @@ public: RCodeAction(uint8_t rcode) : d_rcode(rcode) {} DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->getHeader()->rcode = d_rcode; - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) { + header.rcode = d_rcode; + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); return Action::HeaderModify; } std::string toString() const override @@ -364,10 +379,13 @@ public: ERCodeAction(uint8_t rcode) : d_rcode(rcode) {} DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->getHeader()->rcode = (d_rcode & 0xF); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) { + header.rcode = (d_rcode & 0xF); + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); dq->ednsRCode = ((d_rcode & 0xFFF0) >> 4); - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); return Action::HeaderModify; } std::string toString() const override @@ -819,7 +837,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu if (d_raw.size() >= sizeof(dnsheader)) { auto id = dq->getHeader()->id; dq->getMutableData() = d_raw; - dq->getHeader()->id = id; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [id](dnsheader& header) { + header.id = id; + return true; + }); return Action::HeaderModify; } vector addrs; @@ -875,10 +896,13 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen); // there goes your EDNS uint8_t* dest = &(data.at(sizeof(dnsheader) + qnameWireLength + 4)); - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); - dq->getHeader()->ancount = 0; - dq->getHeader()->arcount = 0; // for now, forget about your EDNS, we're marching over it + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) { + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + header.ancount = 0; + header.arcount = 0; // for now, forget about your EDNS, we're marching over it + return true; + }); uint32_t ttl = htonl(d_responseConfig.ttl); uint16_t qclass = htons(dq->ids.qclass); @@ -902,7 +926,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu memcpy(dest, recordstart, sizeof(recordstart)); dest += sizeof(recordstart); memcpy(dest, wireData.c_str(), wireData.length()); - dq->getHeader()->ancount++; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) { + header.ancount++; + return true; + }); } else if (!rawResponses.empty()) { qtype = htons(qtype); @@ -917,7 +944,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu memcpy(dest, rawResponse.c_str(), rawResponse.size()); dest += rawResponse.size(); - dq->getHeader()->ancount++; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) { + header.ancount++; + return true; + }); } raw = true; } @@ -935,11 +965,18 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu addr.sin4.sin_family == AF_INET ? reinterpret_cast(&addr.sin4.sin_addr.s_addr) : reinterpret_cast(&addr.sin6.sin6_addr.s6_addr), addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); dest += (addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); - dq->getHeader()->ancount++; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) { + header.ancount++; + return true; + }); } } - dq->getHeader()->ancount = htons(dq->getHeader()->ancount); + auto finalANCount = dq->getHeader()->ancount; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [finalANCount](dnsheader& header) { + header.ancount = htons(finalANCount); + return true; + }); if (hadEDNS && raw == false) { addEDNS(dq->getMutableData(), dq->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0); @@ -991,7 +1028,10 @@ public: auto& data = dq->getMutableData(); if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) { - dq->getHeader()->arcount = htons(1); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) { + header.arcount = htons(1); + return true; + }); // make sure that any EDNS sent by the backend is removed before forwarding the response to the client dq->ids.ednsAdded = true; } @@ -1036,7 +1076,10 @@ public: // this action does not stop the processing DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return Action::None; } std::string toString() const override @@ -1252,7 +1295,10 @@ public: // this action does not stop the processing DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->getHeader()->cd = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) { + header.cd = true; + return true; + }); return Action::None; } std::string toString() const override @@ -1922,8 +1968,11 @@ public: } dq->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); - dq->getHeader()->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) { + header.qr = true; // for good measure + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); return Action::HeaderModify; } @@ -2067,7 +2116,10 @@ public: return Action::None; } - setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) { + setResponseHeadersFromConfig(header, d_responseConfig); + return true; + }); return Action::Allow; } diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index 057f71c907..bf456075c7 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -36,10 +36,20 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerMember("qname", [](const DNSQuestion& dq) -> const DNSName { return dq.ids.qname; }, [](DNSQuestion& dq, const DNSName& newName) { (void) newName; }); luaCtx.registerMember("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; }); luaCtx.registerMember("qclass", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSQuestion& dq, uint16_t newClass) { (void) newClass; }); - luaCtx.registerMember("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.getHeader()->rcode = newRCode; }); + luaCtx.registerMember("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [newRCode](dnsheader& header) { + header.rcode = newRCode; + return true; + }); + }); luaCtx.registerMember("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; }); /* DNSDist DNSQuestion */ - luaCtx.registerMember("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; }); + luaCtx.registerMember("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast(dq).getMutableHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&dh](dnsheader& header) { + header = *dh; + return true; + }); + }); luaCtx.registerMember("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); }); luaCtx.registerMember("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; }); luaCtx.registerMember("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; }); @@ -100,7 +110,12 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) auto& buffer = dq.getMutableData(); buffer.clear(); buffer.insert(buffer.begin(), raw.begin(), raw.end()); - reinterpret_cast(buffer.data())->id = oldID; + + reinterpret_cast(buffer.data())->id = oldID; + dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [oldID](dnsheader& header) { + header.id = oldID; + return true; + }); }); luaCtx.registerFunction(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) { if (dq.ednsOptions == nullptr) { @@ -333,9 +348,19 @@ private: luaCtx.registerMember("qname", [](const DNSResponse& dq) -> const DNSName { return dq.ids.qname; }, [](DNSResponse& dq, const DNSName& newName) { (void) newName; }); luaCtx.registerMember("qtype", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSResponse& dq, uint16_t newType) { (void) newType; }); luaCtx.registerMember("qclass", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSResponse& dq, uint16_t newClass) { (void) newClass; }); - luaCtx.registerMember("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) { dq.getHeader()->rcode = newRCode; }); + luaCtx.registerMember("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [newRCode](dnsheader& header) { + header.rcode = newRCode; + return true; + }); + }); luaCtx.registerMember("remoteaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSResponse& dq, const ComboAddress newRemote) { (void) newRemote; }); - luaCtx.registerMember("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; }); + luaCtx.registerMember("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast(dr).getMutableHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [&dh](dnsheader& header) { + header = *dh; + return true; + }); + }); luaCtx.registerMember("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); }); luaCtx.registerMember("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; }); luaCtx.registerMember("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; }); @@ -355,7 +380,10 @@ private: auto& buffer = dr.getMutableData(); buffer.clear(); buffer.insert(buffer.begin(), raw.begin(), raw.end()); - reinterpret_cast(buffer.data())->id = oldID; + dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [oldID](dnsheader& header) { + header.id = oldID; + return true; + }); }); luaCtx.registerFunction(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) { diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index 3d118508cd..7b9ebc7e5f 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -26,6 +26,7 @@ #include "dnsdist.hh" #include "dnsdist-concurrent-connections.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" #include "dnsdist-nghttp2-in.hh" #include "dnsdist-proxy-protocol.hh" @@ -511,7 +512,7 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe DNSResponse dr(ids, response.d_buffer, ds); dr.d_incomingTCPState = state; - memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); + memcpy(&response.d_cleartextDH, dr.getHeader().get(), sizeof(response.d_cleartextDH)); if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) { state->terminateClientConnection(); @@ -668,16 +669,19 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha { /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ - auto* dh = reinterpret_cast(query.data()); - if (!checkQueryHeaders(dh, *d_ci.cs)) { + const dnsheader_aligned dh(query.data()); + if (!checkQueryHeaders(dh.get(), *d_ci.cs)) { return QueryProcessingResult::InvalidHeaders; } if (dh->qdcount == 0) { TCPResponse response; - dh->rcode = RCode::NotImp; - dh->qr = true; auto queryID = dh->id; + dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); response.d_idstate = std::move(ids); response.d_idstate.origID = queryID; response.d_idstate.selfGenerated = true; @@ -696,8 +700,11 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha } DNSQuestion dq(ids, query); - const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); - ids.origFlags = *flags; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&ids](dnsheader& header) { + const uint16_t* flags = getFlagsFromDNSHeader(&header); + ids.origFlags = *flags; + return true; + }); dq.d_incomingTCPState = state; dq.sni = d_handler.getServerNameIndication(); @@ -714,7 +721,7 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha if (forwardViaUDPFirst()) { // if there was no EDNS, we add it with a large buffer size // so we can use UDP to talk to the backend. - auto dh = const_cast(reinterpret_cast(query.data())); + const dnsheader_aligned dh(query.data()); if (!dh->arcount) { if (addEDNS(query, 4096, false, 4096, 0)) { dq.ids.ednsAdded = true; @@ -747,15 +754,15 @@ IncomingTCPConnectionState::QueryProcessingResult IncomingTCPConnectionState::ha // the buffer might have been invalidated by now uint16_t queryID; { - const dnsheader* dh = dq.getHeader(); + const auto dh = dq.getHeader(); queryID = dh->id; } if (result == ProcessQueryResult::SendAnswer) { TCPResponse response; { - const dnsheader* dh = dq.getHeader(); - memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH)); + const auto dh = dq.getHeader(); + memcpy(&response.d_cleartextDH, dh.get(), sizeof(response.d_cleartextDH)); } response.d_idstate = std::move(ids); response.d_idstate.origID = queryID; diff --git a/pdns/dnsdist-xpf.cc b/pdns/dnsdist-xpf.cc index 6f4cba5315..eb2ba57855 100644 --- a/pdns/dnsdist-xpf.cc +++ b/pdns/dnsdist-xpf.cc @@ -22,6 +22,7 @@ #include "dnsdist-xpf.hh" +#include "dnsdist-dnsparser.hh" #include "dnsparser.hh" #include "xpf.hh" @@ -54,7 +55,9 @@ bool addXPF(DNSQuestion& dq, uint16_t optionCode) pos += payload.size(); (void) pos; - dq.getHeader()->arcount = htons(ntohs(dq.getHeader()->arcount) + 1); - + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.arcount = htons(ntohs(header.arcount) + 1); + return true; + }); return true; } diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index ec4ca88ac0..d0c0dc6292 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -53,6 +53,7 @@ #include "dnsdist-carbon.hh" #include "dnsdist-console.hh" #include "dnsdist-discovery.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-dynblocks.hh" #include "dnsdist-ecs.hh" #include "dnsdist-healthchecks.hh" @@ -197,8 +198,12 @@ static void truncateTC(PacketBuffer& packet, size_t maximumSize, unsigned int qn } packet.resize(static_cast(sizeof(dnsheader)+qnameWireLength+DNS_TYPE_SIZE+DNS_CLASS_SIZE)); - struct dnsheader* dh = reinterpret_cast(packet.data()); - dh->ancount = dh->arcount = dh->nscount = 0; + dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) { + header.ancount = 0; + header.arcount = 0; + header.nscount = 0; + return true; + }); if (hadEDNS) { addEDNS(packet, maximumSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0); @@ -232,8 +237,8 @@ static std::unique_ptr> g_delay{nullptr}; std::string DNSQuestion::getTrailingData() const { - const char* message = reinterpret_cast(this->getHeader()); - const uint16_t messageLen = getDNSPacketLength(message, this->data.size()); + const char* message = reinterpret_cast(this->getData().data()); + const uint16_t messageLen = getDNSPacketLength(message, this->getData().size()); return std::string(message + messageLen, this->getData().size() - messageLen); } @@ -251,6 +256,14 @@ bool DNSQuestion::setTrailingData(const std::string& tail) return true; } +bool DNSQuestion::editHeader(std::function editFunction) +{ + if (data.size() < sizeof(dnsheader)) { + throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer"); + } + return dnsdist::PacketMangling::editDNSHeaderFromPacket(data, editFunction); +} + static void doLatencyStats(dnsdist::Protocol protocol, double udiff) { constexpr auto doAvg = [](double& var, double n, double weight) { @@ -311,7 +324,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, return false; } - const struct dnsheader* dh = reinterpret_cast(response.data()); + const dnsheader_aligned dh(response.data()); if (dh->qr == 0) { ++dnsdist::metrics::g_stats.nonCompliantResponses; if (remote) { @@ -370,11 +383,14 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags) *flags |= origFlags; } -static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags) +static bool fixUpQueryTurnedResponse(DNSQuestion& dnsQuestion, const uint16_t origFlags) { - restoreFlags(dq.getHeader(), origFlags); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [origFlags](dnsheader& header) { + restoreFlags(&header, origFlags); + return true; + }); - return addEDNSToQueryTurnedResponse(dq); + return addEDNSToQueryTurnedResponse(dnsQuestion); } static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, bool* zeroScope) @@ -383,8 +399,10 @@ static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t return false; } - struct dnsheader* dh = reinterpret_cast(response.data()); - restoreFlags(dh, origFlags); + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [origFlags](dnsheader& header) { + restoreFlags(&header, origFlags); + return true; + }); if (response.size() == sizeof(dnsheader)) { return true; @@ -422,10 +440,12 @@ static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t if (last) { /* simply remove the last AR */ response.resize(response.size() - optLen); - dh = reinterpret_cast(response.data()); - uint16_t arcount = ntohs(dh->arcount); - arcount--; - dh->arcount = htons(arcount); + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [](dnsheader& header) { + uint16_t arcount = ntohs(header.arcount); + arcount--; + header.arcount = htons(arcount); + return true; + }); } else { /* Removing an intermediary RR could lead to compression error */ @@ -499,7 +519,10 @@ static bool applyRulesToResponse(const std::vector& r return true; break; case DNSResponseAction::Action::ServFail: - dr.getHeader()->rcode = RCode::ServFail; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) { + header.rcode = RCode::ServFail; + return true; + }); return true; break; /* non-terminal actions follow */ @@ -660,7 +683,10 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re if (ids.udpPayloadSize > 0 && response.size() > ids.udpPayloadSize) { vinfolog("Got a response of size %d while the initial UDP payload size was %d, truncating", response.size(), ids.udpPayloadSize); truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.ids.qname.wirelength()); - dr.getHeader()->tc = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) { + header.tc = true; + return true; + }); } else if (dr.getHeader()->tc && g_truncateTC) { truncateTC(response, dr.getMaximumSize(), dr.ids.qname.wirelength()); @@ -669,7 +695,7 @@ static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& re /* when the answer is encrypted in place, we need to get a copy of the original header before encryption to fill the ring buffer */ dnsheader cleartextDH; - memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); + memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH)); if (!isAsync) { if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) { @@ -759,7 +785,7 @@ void responderThread(std::shared_ptr dss) } response.resize(static_cast(got)); - dnsheader* dh = reinterpret_cast(response.data()); + const dnsheader_aligned dh(response.data()); queryId = dh->id; auto ids = dss->getState(queryId); @@ -775,7 +801,10 @@ void responderThread(std::shared_ptr dss) auto du = std::move(ids->du); - dh->id = ids->origID; + dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) { + header.id = ids->origID; + return true; + }); ++dss->responses; double udiff = ids->queryRealTime.udiff(); @@ -869,7 +898,15 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return false; } - switch(action) { + auto setRCode = [&dq](uint8_t rcode) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) { + header.rcode = rcode; + header.qr = true; + return true; + }); + }; + + switch (action) { case DNSAction::Action::Allow: return true; break; @@ -879,18 +916,15 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return true; break; case DNSAction::Action::Nxdomain: - dq.getHeader()->rcode = RCode::NXDomain; - dq.getHeader()->qr = true; + setRCode(RCode::NXDomain); return true; break; case DNSAction::Action::Refused: - dq.getHeader()->rcode = RCode::Refused; - dq.getHeader()->qr = true; + setRCode(RCode::Refused); return true; break; case DNSAction::Action::ServFail: - dq.getHeader()->rcode = RCode::ServFail; - dq.getHeader()->qr = true; + setRCode(RCode::ServFail); return true; break; case DNSAction::Action::Spoof: @@ -907,11 +941,14 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s break; case DNSAction::Action::Truncate: if (!dq.overTCP()) { - dq.getHeader()->tc = true; - dq.getHeader()->qr = true; - dq.getHeader()->ra = dq.getHeader()->rd; - dq.getHeader()->aa = false; - dq.getHeader()->ad = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + header.ra = header.rd; + header.aa = false; + header.ad = false; + return true; + }); ++dnsdist::metrics::g_stats.ruleTruncated; return true; } @@ -926,7 +963,10 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return true; break; case DNSAction::Action::NoRecurse: - dq.getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return true; break; /* non-terminal actions follow */ @@ -946,6 +986,14 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now) { + auto setRCode = [&dq](uint8_t rcode) { + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) { + header.rcode = rcode; + header.qr = true; + return true; + }); + }; + if (g_rings.shouldRecordQueries()) { g_rings.insertQuery(now, dq.ids.origRemote, dq.ids.qname, dq.ids.qtype, dq.getData().size(), *dq.getHeader(), dq.getProtocol()); } @@ -980,6 +1028,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru if (action == DNSAction::Action::None) { action = g_dynBlockAction; } + switch (action) { case DNSAction::Action::NoOp: /* do nothing */ @@ -989,27 +1038,28 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort()); updateBlockStats(); - dq.getHeader()->rcode = RCode::NXDomain; - dq.getHeader()->qr=true; + setRCode(RCode::NXDomain); return true; case DNSAction::Action::Refused: vinfolog("Query from %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort()); updateBlockStats(); - dq.getHeader()->rcode = RCode::Refused; - dq.getHeader()->qr = true; + setRCode(RCode::Refused); return true; case DNSAction::Action::Truncate: if (!dq.overTCP()) { updateBlockStats(); vinfolog("Query from %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort()); - dq.getHeader()->tc = true; - dq.getHeader()->qr = true; - dq.getHeader()->ra = dq.getHeader()->rd; - dq.getHeader()->aa = false; - dq.getHeader()->ad = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + header.ra = header.rd; + header.aa = false; + header.ad = false; + return true; + }); return true; } else { @@ -1019,7 +1069,10 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::NoRecurse: updateBlockStats(); vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort()); - dq.getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return true; default: updateBlockStats(); @@ -1048,26 +1101,27 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); updateBlockStats(); - dq.getHeader()->rcode = RCode::NXDomain; - dq.getHeader()->qr = true; + setRCode(RCode::NXDomain); return true; case DNSAction::Action::Refused: vinfolog("Query from %s for %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); updateBlockStats(); - dq.getHeader()->rcode = RCode::Refused; - dq.getHeader()->qr = true; + setRCode(RCode::Refused); return true; case DNSAction::Action::Truncate: if (!dq.overTCP()) { updateBlockStats(); vinfolog("Query from %s for %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); - dq.getHeader()->tc = true; - dq.getHeader()->qr = true; - dq.getHeader()->ra = dq.getHeader()->rd; - dq.getHeader()->aa = false; - dq.getHeader()->ad = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.tc = true; + header.qr = true; + header.ra = header.rd; + header.aa = false; + header.ad = false; + return true; + }); return true; } else { @@ -1077,7 +1131,10 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::NoRecurse: updateBlockStats(); vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort()); - dq.getHeader()->rd = false; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rd = false; + return true; + }); return true; default: updateBlockStats(); @@ -1368,7 +1425,10 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders yet, as we will do a second-lookup */ if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired, false, true, dq.ids.protocol != dnsdist::Protocol::DoH || forwardedOverUDP)) { - restoreFlags(dq.getHeader(), dq.ids.origFlags); + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [flags=dq.ids.origFlags](dnsheader& header) { + restoreFlags(&header, flags); + return true; + }); vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); @@ -1403,8 +1463,11 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders vinfolog("%s query for %s|%s from %s, no downstream server available", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort()); if (g_servFailOnNoPolicy) { - dq.getHeader()->rcode = RCode::ServFail; - dq.getHeader()->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) { + header.rcode = RCode::ServFail; + header.qr = true; + return true; + }); fixUpQueryTurnedResponse(dq, dq.ids.origFlags); @@ -1421,7 +1484,7 @@ ProcessQueryResult processQueryAfterRules(DNSQuestion& dq, LocalHolders& holders } /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */ - dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader()); + dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader().get()); if (dq.addXPF && selectedBackend->d_config.xpfRRCode != 0) { addXPF(dq, selectedBackend->d_config.xpfRRCode); @@ -1647,16 +1710,20 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct { /* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */ - struct dnsheader* dh = reinterpret_cast(query.data()); + const dnsheader_aligned dh(query.data()); queryId = ntohs(dh->id); - if (!checkQueryHeaders(dh, cs)) { + if (!checkQueryHeaders(dh.get(), cs)) { return; } if (dh->qdcount == 0) { - dh->rcode = RCode::NotImp; - dh->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); + sendUDPResponse(cs.udpFD, query, 0, dest, remote); return; } @@ -1667,7 +1734,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids.protocol = dnsdist::Protocol::DNSCryptUDP; } DNSQuestion dq(ids, query); - const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); + const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader().get()); ids.origFlags = *flags; if (!proxyProtocolValues.empty()) { @@ -1682,7 +1749,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } // the buffer might have been invalidated by now (resized) - struct dnsheader* dh = dq.getHeader(); + const auto dh = dq.getHeader(); if (result == ProcessQueryResult::SendAnswer) { #ifndef DISABLE_RECVMMSG #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index dfaa3d1e58..e0e8da324e 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -89,20 +89,26 @@ struct DNSQuestion return data; } - dnsheader* getHeader() + bool editHeader(std::function editFunction); + + const dnsheader_aligned getHeader() const { if (data.size() < sizeof(dnsheader)) { throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer"); } - return reinterpret_cast(&data.at(0)); + dnsheader_aligned dh(data.data()); + return dh; } - const dnsheader* getHeader() const + /* this function is not safe against unaligned access, you should + use editHeader() instead, but we need it for the Lua bindings */ + dnsheader* getMutableHeader() { if (data.size() < sizeof(dnsheader)) { throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer"); } - return reinterpret_cast(&data.at(0)); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) + return reinterpret_cast(data.data()); } bool hasRoomFor(size_t more) const diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index f9a7bc297b..b0904b295c 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -513,6 +513,7 @@ fuzz_target_dnsdistcache_SOURCES = \ channel.hh channel.cc \ dns.cc dns.hh \ dnsdist-cache.cc dnsdist-cache.hh \ + dnsdist-dnsparser.cc dnsdist-dnsparser.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ dnsdist-idstate.hh \ dnsdist-protocols.cc dnsdist-protocols.hh \ diff --git a/pdns/dnsdistdist/dnsdist-discovery.cc b/pdns/dnsdistdist/dnsdist-discovery.cc index 5bb61ba450..889249d0c0 100644 --- a/pdns/dnsdistdist/dnsdist-discovery.cc +++ b/pdns/dnsdistdist/dnsdist-discovery.cc @@ -52,7 +52,7 @@ struct DesignatedResolvers static bool parseSVCParams(const PacketBuffer& answer, std::map& resolvers) { std::map> hints; - const struct dnsheader* dh = reinterpret_cast(answer.data()); + const dnsheader_aligned dh(answer.data()); PacketReader pr(std::string_view(reinterpret_cast(answer.data()), answer.size())); uint16_t qdcount = ntohs(dh->qdcount); uint16_t ancount = ntohs(dh->ancount); diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 48ce507da8..70f0ff2ab0 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -129,7 +129,7 @@ int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->getHeader(); + return dq->dq->getMutableHeader(); } uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) @@ -458,14 +458,20 @@ void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, ui #ifdef HAVE_DNS_OVER_HTTPS PacketBuffer bodyVect(body, body + bodyLen); dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType); - dq->dq->getHeader()->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) { + header.qr = true; + return true; + }); #endif } void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode) { - dq->dq->getHeader()->rcode = rcode; - dq->dq->getHeader()->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [rcode](dnsheader& header) { + header.rcode = rcode; + header.qr = true; + return true; + }); } void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len) @@ -950,11 +956,15 @@ bool dnsdist_ffi_set_answer_from_async(uint16_t asyncID, uint16_t queryID, const return false; } - auto oldId = reinterpret_cast(query->query.d_buffer.data())->id; + dnsheader_aligned alignedHeader(query->query.d_buffer.data()); + auto oldID = alignedHeader->id; query->query.d_buffer.clear(); query->query.d_buffer.insert(query->query.d_buffer.begin(), raw, raw + rawSize); - reinterpret_cast(query->query.d_buffer.data())->id = oldId; + dnsdist::PacketMangling::editDNSHeaderFromPacket(query->query.d_buffer, [oldID](dnsheader& header) { + header.id = oldID; + return true; + }); query->query.d_idstate.skipCache = true; return dnsdist::queueQueryResumptionEvent(std::move(query)); diff --git a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc index 35591ffd3e..ac1fe458cc 100644 --- a/pdns/dnsdistdist/dnsdist-nghttp2-in.cc +++ b/pdns/dnsdistdist/dnsdist-nghttp2-in.cc @@ -21,6 +21,7 @@ */ #include "base64.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-nghttp2-in.hh" #include "dnsdist-proxy-protocol.hh" #include "dnsparser.hh" @@ -197,10 +198,11 @@ void IncomingHTTP2Connection::handleResponse(const struct timeval& now, TCPRespo if (responseDH.get()->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) { vinfolog("Response received from backend %s via UDP, for query %d received from %s via DoH, is truncated, retrying over TCP", response.d_ds->getNameWithAddr(), state.d_streamID, state.origRemote.toStringWithPort()); auto& query = *state.d_packet; - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - auto* queryDH = reinterpret_cast(&query.at(state.d_proxyProtocolPayloadSize)); - /* restoring the original ID */ - queryDH->id = state.origID; + dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&query.at(state.d_proxyProtocolPayloadSize), [origID = state.origID](dnsheader& header) { + /* restoring the original ID */ + header.id = origID; + return true; + }); state.forwardedOverUDP = false; bool proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0; diff --git a/pdns/dnsdistdist/dnsdist-secpoll.cc b/pdns/dnsdistdist/dnsdist-secpoll.cc index 3bc9aeb40c..26c48ba901 100644 --- a/pdns/dnsdistdist/dnsdist-secpoll.cc +++ b/pdns/dnsdistdist/dnsdist-secpoll.cc @@ -49,7 +49,7 @@ static std::string getFirstTXTAnswer(const std::string& answer) throw std::runtime_error("Looking for a TXT record in an answer smaller than the DNS header"); } - const struct dnsheader* dh = reinterpret_cast(answer.data()); + const dnsheader_aligned dh(answer.data()); PacketReader pr(answer); uint16_t qdcount = ntohs(dh->qdcount); uint16_t ancount = ntohs(dh->ancount); diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index aef6cf6ec3..53c7bec262 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -126,7 +126,8 @@ struct TCPResponse : public TCPQuery TCPQuery(std::move(buffer), std::move(state)), d_connection(std::move(conn)), d_ds(std::move(ds)) { if (d_buffer.size() >= sizeof(dnsheader)) { - memcpy(&d_cleartextDH, reinterpret_cast(d_buffer.data()), sizeof(d_cleartextDH)); + dnsheader_aligned header(d_buffer.data()); + memcpy(&d_cleartextDH, header.get(), sizeof(d_cleartextDH)); } else { memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); @@ -137,7 +138,8 @@ struct TCPResponse : public TCPQuery TCPQuery(std::move(query)) { if (d_buffer.size() >= sizeof(dnsheader)) { - memcpy(&d_cleartextDH, reinterpret_cast(d_buffer.data()), sizeof(d_cleartextDH)); + dnsheader_aligned header(d_buffer.data()); + memcpy(&d_cleartextDH, header.get(), sizeof(d_cleartextDH)); } else { memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 94782c8121..a3ccd80160 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -26,6 +26,7 @@ #include "dns.hh" #include "dolog.hh" #include "dnsdist-concurrent-connections.hh" +#include "dnsdist-dnsparser.hh" #include "dnsdist-ecs.hh" #include "dnsdist-metrics.hh" #include "dnsdist-proxy-protocol.hh" @@ -499,7 +500,7 @@ public: DNSResponse dr(dohUnit->ids, dohUnit->response, dohUnit->downstream); dnsheader cleartextDH{}; - memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); + memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH)); if (!response.isAsync()) { static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); @@ -716,17 +717,20 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) { /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - auto* dnsHeader = reinterpret_cast(unit->query.data()); + const dnsheader_aligned dnsHeader(unit->query.data()); - if (!checkQueryHeaders(dnsHeader, clientState)) { + if (!checkQueryHeaders(dnsHeader.get(), clientState)) { unit->status_code = 400; handleImmediateResponse(std::move(unit), "DoH invalid headers"); return; } if (dnsHeader->qdcount == 0U) { - dnsHeader->rcode = RCode::NotImp; - dnsHeader->qr = true; + dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) { + header.rcode = RCode::NotImp; + header.qr = true; + return true; + }); unit->response = std::move(unit->query); handleImmediateResponse(std::move(unit), "DoH empty query"); @@ -751,7 +755,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& unit, bool inMainThread = false) // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) ids.qname = DNSName(reinterpret_cast(unit->query.data()), static_cast(unit->query.size()), static_cast(sizeof(dnsheader)), false, &ids.qtype, &ids.qclass); DNSQuestion dnsQuestion(ids, unit->query); - const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader()); + const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader().get()); ids.origFlags = *flags; ids.cs = &clientState; dnsQuestion.sni = std::move(unit->sni); @@ -1322,9 +1326,10 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) dohUnit->query.size() > dohUnit->ids.d_proxyProtocolPayloadSize && (dohUnit->query.size() - dohUnit->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) { /* restoring the original ID */ - // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast) - auto* queryDH = reinterpret_cast(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize)); - queryDH->id = dohUnit->ids.origID; + dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize), [oldID=dohUnit->ids.origID](dnsheader& header) { + header.id = oldID; + return true; + }); dohUnit->ids.forwardedOverUDP = false; dohUnit->tcp = true; dohUnit->truncated = false; @@ -1645,7 +1650,7 @@ void DOHUnit::handleUDPResponse(PacketBuffer&& udpResponse, InternalQueryState&& DNSResponse dnsResponse(dohUnit->ids, udpResponse, dohUnit->downstream); dnsheader cleartextDH{}; - memcpy(&cleartextDH, dnsResponse.getHeader(), sizeof(cleartextDH)); + memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); dnsResponse.ids.du = std::move(dohUnit); if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) { diff --git a/pdns/dnsdistdist/doq.cc b/pdns/dnsdistdist/doq.cc index 9489120191..e2fc597d13 100644 --- a/pdns/dnsdistdist/doq.cc +++ b/pdns/dnsdistdist/doq.cc @@ -140,7 +140,7 @@ public: DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream); dnsheader cleartextDH{}; - memcpy(&cleartextDH, dnsResponse.getHeader(), sizeof(cleartextDH)); + memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH)); if (!response.isAsync()) { diff --git a/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc b/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc index 81897a340d..b886b1fd49 100644 --- a/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/test-dnsdist-lua-ffi.cc @@ -466,7 +466,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCache) ids.queryRealTime.start(); DNSQuestion dq(ids, query); packetCache->get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); - packetCache->insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); + packetCache->insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); std::string poolName("test-pool"); auto testPool = std::make_shared(); diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc index b971e4ac15..bd3f3b75fa 100644 --- a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc @@ -251,7 +251,7 @@ private: auto& query = conn->d_queries.at(frame->hd.stream_id); BOOST_REQUIRE_GT(query.size(), sizeof(dnsheader)); - auto dh = reinterpret_cast(query.data()); + const dnsheader_aligned dh(query.data()); uint16_t id = ntohs(dh->id); // cerr<<"got query ID "<(response.d_buffer.data()); + const dnsheader_aligned dh(response.d_buffer.data()); uint16_t id = ntohs(dh->id); BOOST_REQUIRE_EQUAL(id, d_id); diff --git a/pdns/dnsparser.cc b/pdns/dnsparser.cc index 7b83ad55fc..e799c0d34a 100644 --- a/pdns/dnsparser.cc +++ b/pdns/dnsparser.cc @@ -768,7 +768,7 @@ static bool checkIfPacketContainsRecords(const PacketBuffer& packet, const std:: } try { - auto dh = reinterpret_cast(packet.data()); + const dnsheader_aligned dh(packet.data()); DNSPacketMangler dpm(const_cast(reinterpret_cast(packet.data())), length); const uint16_t qdcount = ntohs(dh->qdcount); @@ -804,7 +804,7 @@ static int rewritePacketWithoutRecordTypes(const PacketBuffer& initialPacket, Pa return EINVAL; } try { - const struct dnsheader* dh = reinterpret_cast(initialPacket.data()); + const dnsheader_aligned dh(initialPacket.data()); if (ntohs(dh->qdcount) == 0) return ENOENT; @@ -979,7 +979,7 @@ uint32_t getDNSPacketMinTTL(const char* packet, size_t length, bool* seenAuthSOA } try { - const dnsheader* dh = (const dnsheader*) packet; + const dnsheader_aligned dh(packet); DNSPacketMangler dpm(const_cast(packet), length); const uint16_t qdcount = ntohs(dh->qdcount); @@ -1026,7 +1026,7 @@ uint32_t getDNSPacketLength(const char* packet, size_t length) } try { - const dnsheader* dh = reinterpret_cast(packet); + const dnsheader_aligned dh(packet); DNSPacketMangler dpm(const_cast(packet), length); const uint16_t qdcount = ntohs(dh->qdcount); @@ -1058,7 +1058,7 @@ uint16_t getRecordsOfTypeCount(const char* packet, size_t length, uint8_t sectio } try { - const dnsheader* dh = (const dnsheader*) packet; + const dnsheader_aligned dh(packet); DNSPacketMangler dpm(const_cast(packet), length); const uint16_t qdcount = ntohs(dh->qdcount); @@ -1148,7 +1148,7 @@ bool getEDNSUDPPayloadSizeAndZ(const char* packet, size_t length, uint16_t* payl try { - const dnsheader* dh = (const dnsheader*) packet; + const dnsheader_aligned dh(packet); DNSPacketMangler dpm(const_cast(packet), length); const uint16_t qdcount = ntohs(dh->qdcount); @@ -1191,13 +1191,12 @@ bool visitDNSPacket(const std::string_view& packet, const std::function(packet.data()), sizeof(dh)); - uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount); + const dnsheader_aligned dh(packet.data()); + uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount); PacketReader reader(packet); uint64_t n; - for (n = 0; n < ntohs(dh.qdcount) ; ++n) { + for (n = 0; n < ntohs(dh->qdcount) ; ++n) { (void) reader.getName(); /* type and class */ reader.skip(4); @@ -1206,7 +1205,7 @@ bool visitDNSPacket(const std::string_view& packet, const std::functionancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3); uint16_t dnstype = reader.get16BitInt(); uint16_t dnsclass = reader.get16BitInt(); diff --git a/pdns/dnstap.cc b/pdns/dnstap.cc index 212c3b5f8f..909f08e0b7 100644 --- a/pdns/dnstap.cc +++ b/pdns/dnstap.cc @@ -74,7 +74,7 @@ DnstapMessage::DnstapMessage(std::string& buffer, DnstapMessage::MessageType typ } if (packet != nullptr && len >= sizeof(dnsheader)) { - const struct dnsheader* dh = reinterpret_cast(packet); + const dnsheader_aligned dh(packet); if (!dh->qr) { pbf_message.add_bytes(DnstapMessageFields::query_message, packet, len); } else { diff --git a/pdns/protozero.cc b/pdns/protozero.cc index 6f6fcf3503..e8c1dceae7 100644 --- a/pdns/protozero.cc +++ b/pdns/protozero.cc @@ -85,7 +85,7 @@ void pdns::ProtoZero::Message::addRRsFromPacket(const char* packet, const size_t return; } - const struct dnsheader* dh = reinterpret_cast(packet); + const dnsheader_aligned dh(packet); if (ntohs(dh->ancount) == 0) { return; diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 4bad3d0dbc..63a86fef76 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -57,7 +57,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); if (found == true) { @@ -168,7 +168,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); if (found == true) { @@ -265,7 +265,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); @@ -281,7 +281,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, !receivedOverUDP, RCode::NoError, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, !receivedOverUDP, RCode::NoError, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, !receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); @@ -328,13 +328,13 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { BOOST_CHECK(!subnet); // Insert with failure-TTL of 0 (-> should not enter cache). - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional(0)); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional(0)); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); // Insert with failure-TTL non-zero (-> should enter cache). - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional(300)); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional(300)); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); @@ -383,7 +383,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); @@ -438,7 +438,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); @@ -492,7 +492,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTruncated) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none); bool allowTruncated = true; found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true, allowTruncated); @@ -542,7 +542,7 @@ static void threadMangler(unsigned int offset) DNSQuestion dq(ids, query); g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); - g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); + g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); } } catch(PDNSException& e) { @@ -1074,7 +1074,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheXFR) { BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, ids.qtype, ids.qclass, response, receivedOverUDP, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, ids.qtype, ids.qclass, response, receivedOverUDP, 0, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); BOOST_CHECK_EQUAL(found, false); }