return false;
}
- const struct dnsheader * dh = reinterpret_cast<const struct dnsheader *>(packet.data());
- if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query)
+ const dnsheader_aligned dh(packet.data());
+ if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query) {
return false;
+ }
unsigned int qnameWireLength;
uint16_t qtype, qclass;
}
/* check for collision */
- if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) {
+ if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader().get())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) {
++d_lookupCollisions;
return false;
}
*/
#include "dolog.hh"
#include "dnsdist.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-ecs.hh"
#include "dnsparser.hh"
#include "dnswriter.hh"
int rewriteResponseWithoutEDNS(const PacketBuffer& initialPacket, PacketBuffer& newContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
- if (ntohs(dh->arcount) == 0)
+ if (ntohs(dh->arcount) == 0) {
return ENOENT;
+ }
- if (ntohs(dh->qdcount) == 0)
+ if (ntohs(dh->qdcount) == 0) {
return ENOENT;
+ }
PacketReader pr(std::string_view(reinterpret_cast<const char*>(initialPacket.data()), initialPacket.size()));
bool slowRewriteEDNSOptionInQueryWithRecords(const PacketBuffer& initialPacket, PacketBuffer& newContent, bool& ednsAdded, uint16_t optionToReplace, bool& optionAdded, bool overrideExisting, const string& newOptionContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
if (ntohs(dh->qdcount) == 0) {
return false;
return false;
}
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
if (ntohs(dh->qdcount) == 0) {
return false;
assert(optStart != NULL);
assert(optLen != NULL);
assert(last != NULL);
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
- if (ntohs(dh->arcount) == 0)
+ if (ntohs(dh->arcount) == 0) {
return ENOENT;
+ }
PacketReader pr(std::string_view(reinterpret_cast<const char*>(packet.data()), packet.size()));
{
assert(optRDPosition != nullptr);
assert(remaining != nullptr);
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
if (offset >= packet.size()) {
return ENOENT;
}
- if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0)
+ if (ntohs(dh->qdcount) != 1 || ntohs(dh->ancount) != 0 || ntohs(dh->arcount) != 1 || ntohs(dh->nscount) != 0) {
return ENOENT;
+ }
size_t pos = sizeof(dnsheader) + offset;
pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE;
return false;
}
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
- uint16_t arcount = ntohs(dh->arcount);
- arcount++;
- dh->arcount = htons(arcount);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ uint16_t arcount = ntohs(header.arcount);
+ arcount++;
+ header.arcount = htons(arcount);
+ return true;
+ });
ednsAdded = true;
ecsAdded = true;
{
assert(qnameWireLength <= packet.size());
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) {
PacketBuffer newContent;
int rewriteResponseWithoutEDNSOption(const PacketBuffer& initialPacket, const uint16_t optionCodeToSkip, PacketBuffer& newContent)
{
assert(initialPacket.size() >= sizeof(dnsheader));
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
if (ntohs(dh->arcount) == 0)
return ENOENT;
return false;
}
- auto dh = reinterpret_cast<dnsheader*>(packet.data());
- dh->arcount = htons(ntohs(dh->arcount) + 1);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ header.arcount = htons(ntohs(header.arcount) + 1);
+ return true;
+ });
return true;
}
/* chop off everything after the question */
packet.resize(queryPartSize);
- dh = dq.getHeader();
- if (nxd) {
- dh->rcode = RCode::NXDomain;
- }
- else {
- dh->rcode = RCode::NoError;
- }
- dh->qr = true;
- dh->ancount = 0;
- dh->nscount = 0;
- dh->arcount = 0;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [nxd](dnsheader& header) {
+ if (nxd) {
+ header.rcode = RCode::NXDomain;
+ }
+ else {
+ header.rcode = RCode::NoError;
+ }
+ header.qr = true;
+ header.ancount = 0;
+ header.nscount = 0;
+ header.arcount = 0;
+ return true;
+ });
rdLength = htons(rdLength);
ttl = htonl(ttl);
}
packet.insert(packet.end(), soa.begin(), soa.end());
- dh = dq.getHeader();
/* We are populating a response with only the query in place, order of sections is QD,AN,NS,AR
NS (authority) is before AR (additional) so we can just decide which section the SOA record is in here
and have EDNS added to AR afterwards */
- if (soaInAuthoritySection) {
- dh->nscount = htons(1);
- } else {
- dh->arcount = htons(1);
- }
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [soaInAuthoritySection](dnsheader& header) {
+ if (soaInAuthoritySection) {
+ header.nscount = htons(1);
+ } else {
+ header.arcount = htons(1);
+ }
+ return true;
+ });
if (hadEDNS) {
/* now we need to add a new OPT record */
/* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */
packet.resize(packet.size() - existingOptLen);
- dq.getHeader()->arcount = 0;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ header.arcount = 0;
+ return true;
+ });
if (g_addEDNSToSelfGeneratedResponses) {
/* now we need to add a new OPT record */
auto& data = dq.getMutableData();
if (generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
- dq.getHeader()->arcount = htons(1);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.arcount = htons(1);
+ return true;
+ });
// make sure that any EDNS sent by the backend is removed before forwarding the response to the client
dq.ids.ednsAdded = true;
}
hadEDNS = getEDNS0Record(buffer, edns0);
}
- auto dh = reinterpret_cast<dnsheader*>(buffer.data());
- dh->rcode = rcode;
- dh->ad = false;
- dh->aa = false;
- dh->ra = dh->rd;
- dh->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [rcode,clearAnswers](dnsheader& header) {
+ header.rcode = rcode;
+ header.ad = false;
+ header.aa = false;
+ header.ra = header.rd;
+ header.qr = true;
+
+ if (clearAnswers) {
+ header.ancount = 0;
+ header.nscount = 0;
+ header.arcount = 0;
+ }
+ return true;
+ });
if (clearAnswers) {
- dh->ancount = 0;
- dh->nscount = 0;
- dh->arcount = 0;
buffer.resize(sizeof(dnsheader) + qnameLength + sizeof(uint16_t) + sizeof(uint16_t));
if (hadEDNS) {
DNSQuestion dq(state, buffer);
#include "threadname.hh"
#include "dnsdist.hh"
#include "dnsdist-async.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-edns.hh"
#include "dnsdist-lua.hh"
void TeeAction::worker()
{
setThreadName("dnsdist/TeeWork");
- char packet[1500];
- int res=0;
- struct dnsheader* dh=(struct dnsheader*)packet;
- for(;;) {
- res=waitForData(d_fd, 0, 250000);
- if(d_pleaseQuit)
+ std::array<char, 1500> packet;
+ ssize_t res = 0;
+ const dnsheader_aligned dh(packet.data());
+ for (;;) {
+ res = waitForData(d_fd, 0, 250000);
+ if (d_pleaseQuit) {
break;
- if(res < 0) {
+ }
+
+ if (res < 0) {
usleep(250000);
continue;
}
- if(res==0)
+ if (res == 0) {
continue;
- res=recv(d_fd, packet, sizeof(packet), 0);
- if(res <= (int)sizeof(struct dnsheader))
+ }
+ res = recv(d_fd, packet.data(), packet.size(), 0);
+ if (static_cast<size_t>(res) <= sizeof(struct dnsheader)) {
d_recverrors++;
- else
+ }
+ else {
d_responses++;
+ }
- if(dh->rcode == RCode::NoError)
+ if (dh->rcode == RCode::NoError) {
d_noerrors++;
- else if(dh->rcode == RCode::ServFail)
+ }
+ else if (dh->rcode == RCode::ServFail) {
d_servfails++;
- else if(dh->rcode == RCode::NXDomain)
+ }
+ else if (dh->rcode == RCode::NXDomain) {
d_nxdomains++;
- else if(dh->rcode == RCode::Refused)
+ }
+ else if (dh->rcode == RCode::Refused) {
d_refuseds++;
- else if(dh->rcode == RCode::FormErr)
+ }
+ else if (dh->rcode == RCode::FormErr) {
d_formerrs++;
- else if(dh->rcode == RCode::NotImp)
+ }
+ else if (dh->rcode == RCode::NotImp) {
d_notimps++;
+ }
}
}
RCodeAction(uint8_t rcode) : d_rcode(rcode) {}
DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
{
- dq->getHeader()->rcode = d_rcode;
- dq->getHeader()->qr = true; // for good measure
- setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+ header.rcode = d_rcode;
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
return Action::HeaderModify;
}
std::string toString() const override
ERCodeAction(uint8_t rcode) : d_rcode(rcode) {}
DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
{
- dq->getHeader()->rcode = (d_rcode & 0xF);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+ header.rcode = (d_rcode & 0xF);
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
dq->ednsRCode = ((d_rcode & 0xFFF0) >> 4);
- dq->getHeader()->qr = true; // for good measure
- setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
return Action::HeaderModify;
}
std::string toString() const override
if (d_raw.size() >= sizeof(dnsheader)) {
auto id = dq->getHeader()->id;
dq->getMutableData() = d_raw;
- dq->getHeader()->id = id;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [id](dnsheader& header) {
+ header.id = id;
+ return true;
+ });
return Action::HeaderModify;
}
vector<ComboAddress> addrs;
data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen); // there goes your EDNS
uint8_t* dest = &(data.at(sizeof(dnsheader) + qnameWireLength + 4));
- dq->getHeader()->qr = true; // for good measure
- setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
- dq->getHeader()->ancount = 0;
- dq->getHeader()->arcount = 0; // for now, forget about your EDNS, we're marching over it
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ header.ancount = 0;
+ header.arcount = 0; // for now, forget about your EDNS, we're marching over it
+ return true;
+ });
uint32_t ttl = htonl(d_responseConfig.ttl);
uint16_t qclass = htons(dq->ids.qclass);
memcpy(dest, recordstart, sizeof(recordstart));
dest += sizeof(recordstart);
memcpy(dest, wireData.c_str(), wireData.length());
- dq->getHeader()->ancount++;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+ header.ancount++;
+ return true;
+ });
}
else if (!rawResponses.empty()) {
qtype = htons(qtype);
memcpy(dest, rawResponse.c_str(), rawResponse.size());
dest += rawResponse.size();
- dq->getHeader()->ancount++;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+ header.ancount++;
+ return true;
+ });
}
raw = true;
}
addr.sin4.sin_family == AF_INET ? reinterpret_cast<const void*>(&addr.sin4.sin_addr.s_addr) : reinterpret_cast<const void*>(&addr.sin6.sin6_addr.s6_addr),
addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr));
dest += (addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr));
- dq->getHeader()->ancount++;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+ header.ancount++;
+ return true;
+ });
}
}
- dq->getHeader()->ancount = htons(dq->getHeader()->ancount);
+ auto finalANCount = dq->getHeader()->ancount;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [finalANCount](dnsheader& header) {
+ header.ancount = htons(finalANCount);
+ return true;
+ });
if (hadEDNS && raw == false) {
addEDNS(dq->getMutableData(), dq->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
auto& data = dq->getMutableData();
if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
- dq->getHeader()->arcount = htons(1);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+ header.arcount = htons(1);
+ return true;
+ });
// make sure that any EDNS sent by the backend is removed before forwarding the response to the client
dq->ids.ednsAdded = true;
}
// this action does not stop the processing
DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
{
- dq->getHeader()->rd = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+ header.rd = false;
+ return true;
+ });
return Action::None;
}
std::string toString() const override
// this action does not stop the processing
DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override
{
- dq->getHeader()->cd = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [](dnsheader& header) {
+ header.cd = true;
+ return true;
+ });
return Action::None;
}
std::string toString() const override
}
dq->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType);
- dq->getHeader()->qr = true; // for good measure
- setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+ header.qr = true; // for good measure
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
return Action::HeaderModify;
}
return Action::None;
}
- setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->getMutableData(), [this](dnsheader& header) {
+ setResponseHeadersFromConfig(header, d_responseConfig);
+ return true;
+ });
return Action::Allow;
}
luaCtx.registerMember<const DNSName (DNSQuestion::*)>("qname", [](const DNSQuestion& dq) -> const DNSName { return dq.ids.qname; }, [](DNSQuestion& dq, const DNSName& newName) { (void) newName; });
luaCtx.registerMember<uint16_t (DNSQuestion::*)>("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; });
luaCtx.registerMember<uint16_t (DNSQuestion::*)>("qclass", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSQuestion& dq, uint16_t newClass) { (void) newClass; });
- luaCtx.registerMember<int (DNSQuestion::*)>("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.getHeader()->rcode = newRCode; });
+ luaCtx.registerMember<int (DNSQuestion::*)>("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [newRCode](dnsheader& header) {
+ header.rcode = newRCode;
+ return true;
+ });
+ });
luaCtx.registerMember<const ComboAddress (DNSQuestion::*)>("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; });
/* DNSDist DNSQuestion */
- luaCtx.registerMember<dnsheader* (DNSQuestion::*)>("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast<DNSQuestion&>(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; });
+ luaCtx.registerMember<dnsheader* (DNSQuestion::*)>("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast<DNSQuestion&>(dq).getMutableHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&dh](dnsheader& header) {
+ header = *dh;
+ return true;
+ });
+ });
luaCtx.registerMember<uint16_t (DNSQuestion::*)>("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
luaCtx.registerMember<uint8_t (DNSQuestion::*)>("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; });
luaCtx.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
auto& buffer = dq.getMutableData();
buffer.clear();
buffer.insert(buffer.begin(), raw.begin(), raw.end());
- reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+
+ reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [oldID](dnsheader& header) {
+ header.id = oldID;
+ return true;
+ });
});
luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSQuestion::*)()const>("getEDNSOptions", [](const DNSQuestion& dq) {
if (dq.ednsOptions == nullptr) {
luaCtx.registerMember<const DNSName (DNSResponse::*)>("qname", [](const DNSResponse& dq) -> const DNSName { return dq.ids.qname; }, [](DNSResponse& dq, const DNSName& newName) { (void) newName; });
luaCtx.registerMember<uint16_t (DNSResponse::*)>("qtype", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSResponse& dq, uint16_t newType) { (void) newType; });
luaCtx.registerMember<uint16_t (DNSResponse::*)>("qclass", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSResponse& dq, uint16_t newClass) { (void) newClass; });
- luaCtx.registerMember<int (DNSResponse::*)>("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) { dq.getHeader()->rcode = newRCode; });
+ luaCtx.registerMember<int (DNSResponse::*)>("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [newRCode](dnsheader& header) {
+ header.rcode = newRCode;
+ return true;
+ });
+ });
luaCtx.registerMember<const ComboAddress (DNSResponse::*)>("remoteaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSResponse& dq, const ComboAddress newRemote) { (void) newRemote; });
- luaCtx.registerMember<dnsheader* (DNSResponse::*)>("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast<DNSResponse&>(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; });
+ luaCtx.registerMember<dnsheader* (DNSResponse::*)>("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast<DNSResponse&>(dr).getMutableHeader(); }, [](DNSResponse& dr, const dnsheader* dh) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [&dh](dnsheader& header) {
+ header = *dh;
+ return true;
+ });
+ });
luaCtx.registerMember<uint16_t (DNSResponse::*)>("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
luaCtx.registerMember<uint8_t (DNSResponse::*)>("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; });
luaCtx.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
auto& buffer = dr.getMutableData();
buffer.clear();
buffer.insert(buffer.begin(), raw.begin(), raw.end());
- reinterpret_cast<dnsheader*>(buffer.data())->id = oldID;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(buffer, [oldID](dnsheader& header) {
+ header.id = oldID;
+ return true;
+ });
});
luaCtx.registerFunction<std::map<uint16_t, EDNSOptionView>(DNSResponse::*)()const>("getEDNSOptions", [](const DNSResponse& dq) {
#include "dnsdist.hh"
#include "dnsdist-concurrent-connections.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-nghttp2-in.hh"
#include "dnsdist-proxy-protocol.hh"
DNSResponse dr(ids, response.d_buffer, ds);
dr.d_incomingTCPState = state;
- memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH));
+ memcpy(&response.d_cleartextDH, dr.getHeader().get(), sizeof(response.d_cleartextDH));
if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) {
state->terminateClientConnection();
{
/* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
- auto* dh = reinterpret_cast<dnsheader*>(query.data());
- if (!checkQueryHeaders(dh, *d_ci.cs)) {
+ const dnsheader_aligned dh(query.data());
+ if (!checkQueryHeaders(dh.get(), *d_ci.cs)) {
return QueryProcessingResult::InvalidHeaders;
}
if (dh->qdcount == 0) {
TCPResponse response;
- dh->rcode = RCode::NotImp;
- dh->qr = true;
auto queryID = dh->id;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
+ header.rcode = RCode::NotImp;
+ header.qr = true;
+ return true;
+ });
response.d_idstate = std::move(ids);
response.d_idstate.origID = queryID;
response.d_idstate.selfGenerated = true;
}
DNSQuestion dq(ids, query);
- const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
- ids.origFlags = *flags;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [&ids](dnsheader& header) {
+ const uint16_t* flags = getFlagsFromDNSHeader(&header);
+ ids.origFlags = *flags;
+ return true;
+ });
dq.d_incomingTCPState = state;
dq.sni = d_handler.getServerNameIndication();
if (forwardViaUDPFirst()) {
// if there was no EDNS, we add it with a large buffer size
// so we can use UDP to talk to the backend.
- auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(query.data()));
+ const dnsheader_aligned dh(query.data());
if (!dh->arcount) {
if (addEDNS(query, 4096, false, 4096, 0)) {
dq.ids.ednsAdded = true;
// the buffer might have been invalidated by now
uint16_t queryID;
{
- const dnsheader* dh = dq.getHeader();
+ const auto dh = dq.getHeader();
queryID = dh->id;
}
if (result == ProcessQueryResult::SendAnswer) {
TCPResponse response;
{
- const dnsheader* dh = dq.getHeader();
- memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH));
+ const auto dh = dq.getHeader();
+ memcpy(&response.d_cleartextDH, dh.get(), sizeof(response.d_cleartextDH));
}
response.d_idstate = std::move(ids);
response.d_idstate.origID = queryID;
#include "dnsdist-xpf.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsparser.hh"
#include "xpf.hh"
pos += payload.size();
(void) pos;
- dq.getHeader()->arcount = htons(ntohs(dq.getHeader()->arcount) + 1);
-
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.arcount = htons(ntohs(header.arcount) + 1);
+ return true;
+ });
return true;
}
#include "dnsdist-carbon.hh"
#include "dnsdist-console.hh"
#include "dnsdist-discovery.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-dynblocks.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-healthchecks.hh"
}
packet.resize(static_cast<uint16_t>(sizeof(dnsheader)+qnameWireLength+DNS_TYPE_SIZE+DNS_CLASS_SIZE));
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
- dh->ancount = dh->arcount = dh->nscount = 0;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(packet, [](dnsheader& header) {
+ header.ancount = 0;
+ header.arcount = 0;
+ header.nscount = 0;
+ return true;
+ });
if (hadEDNS) {
addEDNS(packet, maximumSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0);
std::string DNSQuestion::getTrailingData() const
{
- const char* message = reinterpret_cast<const char*>(this->getHeader());
- const uint16_t messageLen = getDNSPacketLength(message, this->data.size());
+ const char* message = reinterpret_cast<const char*>(this->getData().data());
+ const uint16_t messageLen = getDNSPacketLength(message, this->getData().size());
return std::string(message + messageLen, this->getData().size() - messageLen);
}
return true;
}
+bool DNSQuestion::editHeader(std::function<bool(dnsheader&)> editFunction)
+{
+ if (data.size() < sizeof(dnsheader)) {
+ throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer");
+ }
+ return dnsdist::PacketMangling::editDNSHeaderFromPacket(data, editFunction);
+}
+
static void doLatencyStats(dnsdist::Protocol protocol, double udiff)
{
constexpr auto doAvg = [](double& var, double n, double weight) {
return false;
}
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(response.data());
+ const dnsheader_aligned dh(response.data());
if (dh->qr == 0) {
++dnsdist::metrics::g_stats.nonCompliantResponses;
if (remote) {
*flags |= origFlags;
}
-static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags)
+static bool fixUpQueryTurnedResponse(DNSQuestion& dnsQuestion, const uint16_t origFlags)
{
- restoreFlags(dq.getHeader(), origFlags);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [origFlags](dnsheader& header) {
+ restoreFlags(&header, origFlags);
+ return true;
+ });
- return addEDNSToQueryTurnedResponse(dq);
+ return addEDNSToQueryTurnedResponse(dnsQuestion);
}
static bool fixUpResponse(PacketBuffer& response, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, bool* zeroScope)
return false;
}
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
- restoreFlags(dh, origFlags);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [origFlags](dnsheader& header) {
+ restoreFlags(&header, origFlags);
+ return true;
+ });
if (response.size() == sizeof(dnsheader)) {
return true;
if (last) {
/* simply remove the last AR */
response.resize(response.size() - optLen);
- dh = reinterpret_cast<struct dnsheader*>(response.data());
- uint16_t arcount = ntohs(dh->arcount);
- arcount--;
- dh->arcount = htons(arcount);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [](dnsheader& header) {
+ uint16_t arcount = ntohs(header.arcount);
+ arcount--;
+ header.arcount = htons(arcount);
+ return true;
+ });
}
else {
/* Removing an intermediary RR could lead to compression error */
return true;
break;
case DNSResponseAction::Action::ServFail:
- dr.getHeader()->rcode = RCode::ServFail;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) {
+ header.rcode = RCode::ServFail;
+ return true;
+ });
return true;
break;
/* non-terminal actions follow */
if (ids.udpPayloadSize > 0 && response.size() > ids.udpPayloadSize) {
vinfolog("Got a response of size %d while the initial UDP payload size was %d, truncating", response.size(), ids.udpPayloadSize);
truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.ids.qname.wirelength());
- dr.getHeader()->tc = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dr.getMutableData(), [](dnsheader& header) {
+ header.tc = true;
+ return true;
+ });
}
else if (dr.getHeader()->tc && g_truncateTC) {
truncateTC(response, dr.getMaximumSize(), dr.ids.qname.wirelength());
/* when the answer is encrypted in place, we need to get a copy
of the original header before encryption to fill the ring buffer */
dnsheader cleartextDH;
- memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
+ memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH));
if (!isAsync) {
if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) {
}
response.resize(static_cast<size_t>(got));
- dnsheader* dh = reinterpret_cast<struct dnsheader*>(response.data());
+ const dnsheader_aligned dh(response.data());
queryId = dh->id;
auto ids = dss->getState(queryId);
auto du = std::move(ids->du);
- dh->id = ids->origID;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(response, [&ids](dnsheader& header) {
+ header.id = ids->origID;
+ return true;
+ });
++dss->responses;
double udiff = ids->queryRealTime.udiff();
return false;
}
- switch(action) {
+ auto setRCode = [&dq](uint8_t rcode) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) {
+ header.rcode = rcode;
+ header.qr = true;
+ return true;
+ });
+ };
+
+ switch (action) {
case DNSAction::Action::Allow:
return true;
break;
return true;
break;
case DNSAction::Action::Nxdomain:
- dq.getHeader()->rcode = RCode::NXDomain;
- dq.getHeader()->qr = true;
+ setRCode(RCode::NXDomain);
return true;
break;
case DNSAction::Action::Refused:
- dq.getHeader()->rcode = RCode::Refused;
- dq.getHeader()->qr = true;
+ setRCode(RCode::Refused);
return true;
break;
case DNSAction::Action::ServFail:
- dq.getHeader()->rcode = RCode::ServFail;
- dq.getHeader()->qr = true;
+ setRCode(RCode::ServFail);
return true;
break;
case DNSAction::Action::Spoof:
break;
case DNSAction::Action::Truncate:
if (!dq.overTCP()) {
- dq.getHeader()->tc = true;
- dq.getHeader()->qr = true;
- dq.getHeader()->ra = dq.getHeader()->rd;
- dq.getHeader()->aa = false;
- dq.getHeader()->ad = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.tc = true;
+ header.qr = true;
+ header.ra = header.rd;
+ header.aa = false;
+ header.ad = false;
+ return true;
+ });
++dnsdist::metrics::g_stats.ruleTruncated;
return true;
}
return true;
break;
case DNSAction::Action::NoRecurse:
- dq.getHeader()->rd = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.rd = false;
+ return true;
+ });
return true;
break;
/* non-terminal actions follow */
static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now)
{
+ auto setRCode = [&dq](uint8_t rcode) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [rcode](dnsheader& header) {
+ header.rcode = rcode;
+ header.qr = true;
+ return true;
+ });
+ };
+
if (g_rings.shouldRecordQueries()) {
g_rings.insertQuery(now, dq.ids.origRemote, dq.ids.qname, dq.ids.qtype, dq.getData().size(), *dq.getHeader(), dq.getProtocol());
}
if (action == DNSAction::Action::None) {
action = g_dynBlockAction;
}
+
switch (action) {
case DNSAction::Action::NoOp:
/* do nothing */
vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort());
updateBlockStats();
- dq.getHeader()->rcode = RCode::NXDomain;
- dq.getHeader()->qr=true;
+ setRCode(RCode::NXDomain);
return true;
case DNSAction::Action::Refused:
vinfolog("Query from %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort());
updateBlockStats();
- dq.getHeader()->rcode = RCode::Refused;
- dq.getHeader()->qr = true;
+ setRCode(RCode::Refused);
return true;
case DNSAction::Action::Truncate:
if (!dq.overTCP()) {
updateBlockStats();
vinfolog("Query from %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort());
- dq.getHeader()->tc = true;
- dq.getHeader()->qr = true;
- dq.getHeader()->ra = dq.getHeader()->rd;
- dq.getHeader()->aa = false;
- dq.getHeader()->ad = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.tc = true;
+ header.qr = true;
+ header.ra = header.rd;
+ header.aa = false;
+ header.ad = false;
+ return true;
+ });
return true;
}
else {
case DNSAction::Action::NoRecurse:
updateBlockStats();
vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort());
- dq.getHeader()->rd = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.rd = false;
+ return true;
+ });
return true;
default:
updateBlockStats();
vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString());
updateBlockStats();
- dq.getHeader()->rcode = RCode::NXDomain;
- dq.getHeader()->qr = true;
+ setRCode(RCode::NXDomain);
return true;
case DNSAction::Action::Refused:
vinfolog("Query from %s for %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString());
updateBlockStats();
- dq.getHeader()->rcode = RCode::Refused;
- dq.getHeader()->qr = true;
+ setRCode(RCode::Refused);
return true;
case DNSAction::Action::Truncate:
if (!dq.overTCP()) {
updateBlockStats();
vinfolog("Query from %s for %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString());
- dq.getHeader()->tc = true;
- dq.getHeader()->qr = true;
- dq.getHeader()->ra = dq.getHeader()->rd;
- dq.getHeader()->aa = false;
- dq.getHeader()->ad = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.tc = true;
+ header.qr = true;
+ header.ra = header.rd;
+ header.aa = false;
+ header.ad = false;
+ return true;
+ });
return true;
}
else {
case DNSAction::Action::NoRecurse:
updateBlockStats();
vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort());
- dq.getHeader()->rd = false;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.rd = false;
+ return true;
+ });
return true;
default:
updateBlockStats();
yet, as we will do a second-lookup */
if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired, false, true, dq.ids.protocol != dnsdist::Protocol::DoH || forwardedOverUDP)) {
- restoreFlags(dq.getHeader(), dq.ids.origFlags);
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [flags=dq.ids.origFlags](dnsheader& header) {
+ restoreFlags(&header, flags);
+ return true;
+ });
vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size());
vinfolog("%s query for %s|%s from %s, no downstream server available", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort());
if (g_servFailOnNoPolicy) {
- dq.getHeader()->rcode = RCode::ServFail;
- dq.getHeader()->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq.getMutableData(), [](dnsheader& header) {
+ header.rcode = RCode::ServFail;
+ header.qr = true;
+ return true;
+ });
fixUpQueryTurnedResponse(dq, dq.ids.origFlags);
}
/* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */
- dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader());
+ dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader().get());
if (dq.addXPF && selectedBackend->d_config.xpfRRCode != 0) {
addXPF(dq, selectedBackend->d_config.xpfRRCode);
{
/* this pointer will be invalidated the second the buffer is resized, don't hold onto it! */
- struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(query.data());
+ const dnsheader_aligned dh(query.data());
queryId = ntohs(dh->id);
- if (!checkQueryHeaders(dh, cs)) {
+ if (!checkQueryHeaders(dh.get(), cs)) {
return;
}
if (dh->qdcount == 0) {
- dh->rcode = RCode::NotImp;
- dh->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(query, [](dnsheader& header) {
+ header.rcode = RCode::NotImp;
+ header.qr = true;
+ return true;
+ });
+
sendUDPResponse(cs.udpFD, query, 0, dest, remote);
return;
}
ids.protocol = dnsdist::Protocol::DNSCryptUDP;
}
DNSQuestion dq(ids, query);
- const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader());
+ const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader().get());
ids.origFlags = *flags;
if (!proxyProtocolValues.empty()) {
}
// the buffer might have been invalidated by now (resized)
- struct dnsheader* dh = dq.getHeader();
+ const auto dh = dq.getHeader();
if (result == ProcessQueryResult::SendAnswer) {
#ifndef DISABLE_RECVMMSG
#if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE)
return data;
}
- dnsheader* getHeader()
+ bool editHeader(std::function<bool(dnsheader&)> editFunction);
+
+ const dnsheader_aligned getHeader() const
{
if (data.size() < sizeof(dnsheader)) {
throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer");
}
- return reinterpret_cast<dnsheader*>(&data.at(0));
+ dnsheader_aligned dh(data.data());
+ return dh;
}
- const dnsheader* getHeader() const
+ /* this function is not safe against unaligned access, you should
+ use editHeader() instead, but we need it for the Lua bindings */
+ dnsheader* getMutableHeader()
{
if (data.size() < sizeof(dnsheader)) {
throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer");
}
- return reinterpret_cast<const dnsheader*>(&data.at(0));
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ return reinterpret_cast<dnsheader*>(data.data());
}
bool hasRoomFor(size_t more) const
channel.hh channel.cc \
dns.cc dns.hh \
dnsdist-cache.cc dnsdist-cache.hh \
+ dnsdist-dnsparser.cc dnsdist-dnsparser.hh \
dnsdist-ecs.cc dnsdist-ecs.hh \
dnsdist-idstate.hh \
dnsdist-protocols.cc dnsdist-protocols.hh \
static bool parseSVCParams(const PacketBuffer& answer, std::map<uint16_t, DesignatedResolvers>& resolvers)
{
std::map<DNSName, std::vector<ComboAddress>> hints;
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(answer.data());
+ const dnsheader_aligned dh(answer.data());
PacketReader pr(std::string_view(reinterpret_cast<const char*>(answer.data()), answer.size()));
uint16_t qdcount = ntohs(dh->qdcount);
uint16_t ancount = ntohs(dh->ancount);
void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq)
{
- return dq->dq->getHeader();
+ return dq->dq->getMutableHeader();
}
uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq)
#ifdef HAVE_DNS_OVER_HTTPS
PacketBuffer bodyVect(body, body + bodyLen);
dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType);
- dq->dq->getHeader()->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [](dnsheader& header) {
+ header.qr = true;
+ return true;
+ });
#endif
}
void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode)
{
- dq->dq->getHeader()->rcode = rcode;
- dq->dq->getHeader()->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dq->dq->getMutableData(), [rcode](dnsheader& header) {
+ header.rcode = rcode;
+ header.qr = true;
+ return true;
+ });
}
void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len)
return false;
}
- auto oldId = reinterpret_cast<const dnsheader*>(query->query.d_buffer.data())->id;
+ dnsheader_aligned alignedHeader(query->query.d_buffer.data());
+ auto oldID = alignedHeader->id;
query->query.d_buffer.clear();
query->query.d_buffer.insert(query->query.d_buffer.begin(), raw, raw + rawSize);
- reinterpret_cast<dnsheader*>(query->query.d_buffer.data())->id = oldId;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(query->query.d_buffer, [oldID](dnsheader& header) {
+ header.id = oldID;
+ return true;
+ });
query->query.d_idstate.skipCache = true;
return dnsdist::queueQueryResumptionEvent(std::move(query));
*/
#include "base64.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-nghttp2-in.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsparser.hh"
if (responseDH.get()->tc && state.d_packet && state.d_packet->size() > state.d_proxyProtocolPayloadSize && state.d_packet->size() - state.d_proxyProtocolPayloadSize > sizeof(dnsheader)) {
vinfolog("Response received from backend %s via UDP, for query %d received from %s via DoH, is truncated, retrying over TCP", response.d_ds->getNameWithAddr(), state.d_streamID, state.origRemote.toStringWithPort());
auto& query = *state.d_packet;
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- auto* queryDH = reinterpret_cast<struct dnsheader*>(&query.at(state.d_proxyProtocolPayloadSize));
- /* restoring the original ID */
- queryDH->id = state.origID;
+ dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&query.at(state.d_proxyProtocolPayloadSize), [origID = state.origID](dnsheader& header) {
+ /* restoring the original ID */
+ header.id = origID;
+ return true;
+ });
state.forwardedOverUDP = false;
bool proxyProtocolPayloadAdded = state.d_proxyProtocolPayloadSize > 0;
throw std::runtime_error("Looking for a TXT record in an answer smaller than the DNS header");
}
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(answer.data());
+ const dnsheader_aligned dh(answer.data());
PacketReader pr(answer);
uint16_t qdcount = ntohs(dh->qdcount);
uint16_t ancount = ntohs(dh->ancount);
TCPQuery(std::move(buffer), std::move(state)), d_connection(std::move(conn)), d_ds(std::move(ds))
{
if (d_buffer.size() >= sizeof(dnsheader)) {
- memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+ dnsheader_aligned header(d_buffer.data());
+ memcpy(&d_cleartextDH, header.get(), sizeof(d_cleartextDH));
}
else {
memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
TCPQuery(std::move(query))
{
if (d_buffer.size() >= sizeof(dnsheader)) {
- memcpy(&d_cleartextDH, reinterpret_cast<const dnsheader*>(d_buffer.data()), sizeof(d_cleartextDH));
+ dnsheader_aligned header(d_buffer.data());
+ memcpy(&d_cleartextDH, header.get(), sizeof(d_cleartextDH));
}
else {
memset(&d_cleartextDH, 0, sizeof(d_cleartextDH));
#include "dns.hh"
#include "dolog.hh"
#include "dnsdist-concurrent-connections.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-ecs.hh"
#include "dnsdist-metrics.hh"
#include "dnsdist-proxy-protocol.hh"
DNSResponse dr(dohUnit->ids, dohUnit->response, dohUnit->downstream);
dnsheader cleartextDH{};
- memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
+ memcpy(&cleartextDH, dr.getHeader().get(), sizeof(cleartextDH));
if (!response.isAsync()) {
static thread_local LocalStateHolder<vector<DNSDistResponseRuleAction>> localRespRuleActions = g_respruleactions.getLocal();
{
/* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- auto* dnsHeader = reinterpret_cast<struct dnsheader*>(unit->query.data());
+ const dnsheader_aligned dnsHeader(unit->query.data());
- if (!checkQueryHeaders(dnsHeader, clientState)) {
+ if (!checkQueryHeaders(dnsHeader.get(), clientState)) {
unit->status_code = 400;
handleImmediateResponse(std::move(unit), "DoH invalid headers");
return;
}
if (dnsHeader->qdcount == 0U) {
- dnsHeader->rcode = RCode::NotImp;
- dnsHeader->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+ header.rcode = RCode::NotImp;
+ header.qr = true;
+ return true;
+ });
unit->response = std::move(unit->query);
handleImmediateResponse(std::move(unit), "DoH empty query");
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), static_cast<int>(sizeof(dnsheader)), false, &ids.qtype, &ids.qclass);
DNSQuestion dnsQuestion(ids, unit->query);
- const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader());
+ const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader().get());
ids.origFlags = *flags;
ids.cs = &clientState;
dnsQuestion.sni = std::move(unit->sni);
dohUnit->query.size() > dohUnit->ids.d_proxyProtocolPayloadSize &&
(dohUnit->query.size() - dohUnit->ids.d_proxyProtocolPayloadSize) > sizeof(dnsheader)) {
/* restoring the original ID */
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- auto* queryDH = reinterpret_cast<struct dnsheader*>(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize));
- queryDH->id = dohUnit->ids.origID;
+ dnsdist::PacketMangling::editDNSHeaderFromRawPacket(&dohUnit->query.at(dohUnit->ids.d_proxyProtocolPayloadSize), [oldID=dohUnit->ids.origID](dnsheader& header) {
+ header.id = oldID;
+ return true;
+ });
dohUnit->ids.forwardedOverUDP = false;
dohUnit->tcp = true;
dohUnit->truncated = false;
DNSResponse dnsResponse(dohUnit->ids, udpResponse, dohUnit->downstream);
dnsheader cleartextDH{};
- memcpy(&cleartextDH, dnsResponse.getHeader(), sizeof(cleartextDH));
+ memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
dnsResponse.ids.du = std::move(dohUnit);
if (!processResponse(udpResponse, *localRespRuleActions, *localCacheInsertedRespRuleActions, dnsResponse, false)) {
DNSResponse dnsResponse(unit->ids, unit->response, unit->downstream);
dnsheader cleartextDH{};
- memcpy(&cleartextDH, dnsResponse.getHeader(), sizeof(cleartextDH));
+ memcpy(&cleartextDH, dnsResponse.getHeader().get(), sizeof(cleartextDH));
if (!response.isAsync()) {
ids.queryRealTime.start();
DNSQuestion dq(ids, query);
packetCache->get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
- packetCache->insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
+ packetCache->insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
std::string poolName("test-pool");
auto testPool = std::make_shared<ServerPool>();
auto& query = conn->d_queries.at(frame->hd.stream_id);
BOOST_REQUIRE_GT(query.size(), sizeof(dnsheader));
- auto dh = reinterpret_cast<const dnsheader*>(query.data());
+ const dnsheader_aligned dh(query.data());
uint16_t id = ntohs(dh->id);
// cerr<<"got query ID "<<id<<endl;
}
BOOST_REQUIRE_GT(response.d_buffer.size(), sizeof(dnsheader));
- auto dh = reinterpret_cast<const dnsheader*>(response.d_buffer.data());
+ const dnsheader_aligned dh(response.d_buffer.data());
uint16_t id = ntohs(dh->id);
BOOST_REQUIRE_EQUAL(id, d_id);
}
try {
- auto dh = reinterpret_cast<const dnsheader*>(packet.data());
+ const dnsheader_aligned dh(packet.data());
DNSPacketMangler dpm(const_cast<char*>(reinterpret_cast<const char*>(packet.data())), length);
const uint16_t qdcount = ntohs(dh->qdcount);
return EINVAL;
}
try {
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
+ const dnsheader_aligned dh(initialPacket.data());
if (ntohs(dh->qdcount) == 0)
return ENOENT;
}
try
{
- const dnsheader* dh = (const dnsheader*) packet;
+ const dnsheader_aligned dh(packet);
DNSPacketMangler dpm(const_cast<char*>(packet), length);
const uint16_t qdcount = ntohs(dh->qdcount);
}
try
{
- const dnsheader* dh = reinterpret_cast<const dnsheader*>(packet);
+ const dnsheader_aligned dh(packet);
DNSPacketMangler dpm(const_cast<char*>(packet), length);
const uint16_t qdcount = ntohs(dh->qdcount);
}
try
{
- const dnsheader* dh = (const dnsheader*) packet;
+ const dnsheader_aligned dh(packet);
DNSPacketMangler dpm(const_cast<char*>(packet), length);
const uint16_t qdcount = ntohs(dh->qdcount);
try
{
- const dnsheader* dh = (const dnsheader*) packet;
+ const dnsheader_aligned dh(packet);
DNSPacketMangler dpm(const_cast<char*>(packet), length);
const uint16_t qdcount = ntohs(dh->qdcount);
try
{
- dnsheader dh;
- memcpy(&dh, reinterpret_cast<const dnsheader*>(packet.data()), sizeof(dh));
- uint64_t numrecords = ntohs(dh.ancount) + ntohs(dh.nscount) + ntohs(dh.arcount);
+ const dnsheader_aligned dh(packet.data());
+ uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount);
PacketReader reader(packet);
uint64_t n;
- for (n = 0; n < ntohs(dh.qdcount) ; ++n) {
+ for (n = 0; n < ntohs(dh->qdcount) ; ++n) {
(void) reader.getName();
/* type and class */
reader.skip(4);
for (n = 0; n < numrecords; ++n) {
(void) reader.getName();
- uint8_t section = n < ntohs(dh.ancount) ? 1 : (n < (ntohs(dh.ancount) + ntohs(dh.nscount)) ? 2 : 3);
+ uint8_t section = n < ntohs(dh->ancount) ? 1 : (n < (ntohs(dh->ancount) + ntohs(dh->nscount)) ? 2 : 3);
uint16_t dnstype = reader.get16BitInt();
uint16_t dnsclass = reader.get16BitInt();
}
if (packet != nullptr && len >= sizeof(dnsheader)) {
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+ const dnsheader_aligned dh(packet);
if (!dh->qr) {
pbf_message.add_bytes(DnstapMessageFields::query_message, packet, len);
} else {
return;
}
- const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(packet);
+ const dnsheader_aligned dh(packet);
if (ntohs(dh->ancount) == 0) {
return;
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
if (found == true) {
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
if (found == true) {
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, true);
BOOST_CHECK(!subnet);
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, !receivedOverUDP, RCode::NoError, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, !receivedOverUDP, RCode::NoError, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, !receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, true);
BOOST_CHECK(!subnet);
BOOST_CHECK(!subnet);
// Insert with failure-TTL of 0 (-> should not enter cache).
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(0));
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(0));
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
// Insert with failure-TTL non-zero (-> should enter cache).
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(300));
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, RCode::ServFail, boost::optional<uint32_t>(300));
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, true);
BOOST_CHECK(!subnet);
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, true);
BOOST_CHECK(!subnet);
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, true);
BOOST_CHECK(!subnet);
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none);
bool allowTruncated = true;
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true, allowTruncated);
DNSQuestion dq(ids, query);
g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP);
- g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
+ g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none);
}
}
catch(PDNSException& e) {
BOOST_CHECK_EQUAL(found, false);
BOOST_CHECK(!subnet);
- PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, ids.qtype, ids.qclass, response, receivedOverUDP, 0, boost::none);
+ PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader().get())), dnssecOK, ids.qname, ids.qtype, ids.qclass, response, receivedOverUDP, 0, boost::none);
found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true);
BOOST_CHECK_EQUAL(found, false);
}