class dnsheader_aligned
{
public:
+ static bool isMemoryAligned(const void* mem)
+ {
+ return reinterpret_cast<uintptr_t>(mem) % sizeof(uint32_t) == 0; // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
+ }
+
dnsheader_aligned(const void* mem)
{
- if (reinterpret_cast<uintptr_t>(mem) % sizeof(uint32_t) == 0) { // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
+ if (isMemoryAligned(mem)) {
d_p = reinterpret_cast<const dnsheader*>(mem); // NOLINT(cppcoreguidelines-pro-type-reinterpret-cast)
}
else {
return d_p;
}
+ [[nodiscard]] const dnsheader& operator*() const
+ {
+ return *d_p;
+ }
+
+ [[nodiscard]] const dnsheader* operator->() const
+ {
+ return d_p;
+ }
+
private:
dnsheader d_h{};
- const dnsheader *d_p{};
+ const dnsheader* d_p{};
};
-inline uint16_t * getFlagsFromDNSHeader(struct dnsheader * dh)
+inline uint16_t* getFlagsFromDNSHeader(dnsheader* dh)
+{
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ return reinterpret_cast<uint16_t*>(reinterpret_cast<char*>(dh) + sizeof(uint16_t));
+}
+
+inline const uint16_t * getFlagsFromDNSHeader(const dnsheader* dh)
{
- return (uint16_t*) (((char *) dh) + sizeof(uint16_t));
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ return reinterpret_cast<const uint16_t*>(reinterpret_cast<const char*>(dh) + sizeof(uint16_t));
}
#define DNS_TYPE_SIZE (2)
return true;
}
+namespace PacketMangling
+{
+ bool editDNSHeaderFromPacket(PacketBuffer& packet, std::function<bool(dnsheader& header)> editFunction)
+ {
+ if (packet.size() < sizeof(dnsheader)) {
+ throw std::runtime_error("Trying to edit the DNS header of a too small packet");
+ }
+
+ return editDNSHeaderFromRawPacket(packet.data(), editFunction);
+ }
+
+ bool editDNSHeaderFromRawPacket(void* packet, std::function<bool(dnsheader& header)> editFunction)
+ {
+ if (dnsheader_aligned::isMemoryAligned(packet)) {
+ // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
+ auto* header = reinterpret_cast<dnsheader*>(packet);
+ return editFunction(*header);
+ }
+
+ dnsheader header;
+ memcpy(&header, packet, sizeof(header));
+ if (!editFunction(header)) {
+ return false;
+ }
+ memcpy(packet, &header, sizeof(header));
+ return true;
+ }
+}
}
#include "threadname.hh"
#include "dnsdist-ecs.hh"
+#include "dnsdist-dnsparser.hh"
#include "dnsdist-proxy-protocol.hh"
#include "dnsdist-tcp.hh"
#include "dnsdist-random.hh"
if (unit->query.size() < sizeof(dnsheader)) {
++dnsdist::metrics::g_stats.nonCompliantQueries;
++clientState.nonCompliantQueries;
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- auto* dnsHeader = reinterpret_cast<struct dnsheader*>(unit->query.data());
- dnsHeader->rcode = RCode::ServFail;
- dnsHeader->qr = true;
- unit->response = std::move(unit->query);
+ unit->response.clear();
handleImmediateResponse(std::move(unit), "DoQ non-compliant query");
return;
{
/* don't keep that pointer around, it will be invalidated if the buffer is ever resized */
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- auto* dnsHeader = reinterpret_cast<struct dnsheader*>(unit->query.data());
-
- if (!checkQueryHeaders(dnsHeader, clientState)) {
- dnsHeader->rcode = RCode::ServFail;
- dnsHeader->qr = true;
+ dnsheader_aligned dnsHeader(unit->query.data());
+
+ if (!checkQueryHeaders(dnsHeader.get(), clientState)) {
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+ header.rcode = RCode::ServFail;
+ header.qr = true;
+ return true;
+ });
unit->response = std::move(unit->query);
handleImmediateResponse(std::move(unit), "DoQ invalid headers");
}
if (dnsHeader->qdcount == 0) {
- dnsHeader->rcode = RCode::NotImp;
- dnsHeader->qr = true;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(unit->query, [](dnsheader& header) {
+ header.rcode = RCode::NotImp;
+ header.qr = true;
+ return true;
+ });
unit->response = std::move(unit->query);
handleImmediateResponse(std::move(unit), "DoQ empty query");
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
unit->ids.qname = DNSName(reinterpret_cast<const char*>(unit->query.data()), static_cast<int>(unit->query.size()), sizeof(dnsheader), false, &unit->ids.qtype, &unit->ids.qclass);
DNSQuestion dnsQuestion(unit->ids, unit->query);
- const uint16_t* flags = getFlagsFromDNSHeader(dnsQuestion.getHeader());
- ids.origFlags = *flags;
+ dnsdist::PacketMangling::editDNSHeaderFromPacket(dnsQuestion.getMutableData(), [&ids](dnsheader& header) {
+ const uint16_t* flags = getFlagsFromDNSHeader(&header);
+ ids.origFlags = *flags;
+ return true;
+ });
unit->ids.cs = &clientState;
auto result = processQuery(dnsQuestion, holders, downstream);
unit->response = std::move(unit->query);
}
if (unit->response.size() >= sizeof(dnsheader)) {
- // NOLINTNEXTLINE(cppcoreguidelines-pro-type-reinterpret-cast)
- const auto* dnsHeader = reinterpret_cast<const struct dnsheader*>(unit->response.data());
+ const dnsheader_aligned dnsHeader(unit->response.data());
handleResponseSent(unit->ids.qname, QType(unit->ids.qtype), 0., unit->ids.origDest, ComboAddress(), unit->response.size(), *dnsHeader, dnsdist::Protocol::DoQ, dnsdist::Protocol::DoQ, false);
}