From 592b1d997f567d5d04a6bca734eeeea03037178f Mon Sep 17 00:00:00 2001 From: Remi Gacogne Date: Fri, 25 Nov 2022 17:38:07 +0100 Subject: [PATCH] dnsdist: Turn DNSQuestion and DNSResponse into IDState overlays Remaining: queryRealTime udpPayloadSize --- pdns/dnsdist-cache.cc | 8 +- pdns/dnsdist-ecs.cc | 21 +- pdns/dnsdist-idstate.hh | 131 ++++--- pdns/dnsdist-lua-actions.cc | 98 +++-- pdns/dnsdist-lua-bindings-dnsquestion.cc | 72 ++-- pdns/dnsdist-lua-rules.cc | 30 +- pdns/dnsdist-protobuf.cc | 4 +- pdns/dnsdist-snmp.cc | 12 +- pdns/dnsdist-tcp.cc | 55 +-- pdns/dnsdist-xpf.cc | 2 +- pdns/dnsdist.cc | 340 ++++++++++-------- pdns/dnsdist.hh | 76 ++-- pdns/dnsdistdist/Makefile.am | 4 +- pdns/dnsdistdist/dnsdist-backend.cc | 23 +- pdns/dnsdistdist/dnsdist-healthchecks.cc | 4 +- pdns/dnsdistdist/dnsdist-idstate.cc | 75 ---- pdns/dnsdistdist/dnsdist-kvs.hh | 12 +- pdns/dnsdistdist/dnsdist-lbpolicies.cc | 4 +- pdns/dnsdistdist/dnsdist-lua-ffi.cc | 72 ++-- pdns/dnsdistdist/dnsdist-proxy-protocol.cc | 2 +- pdns/dnsdistdist/dnsdist-rules.hh | 40 +-- pdns/dnsdistdist/dnsdist-tcp-downstream.cc | 2 +- pdns/dnsdistdist/dnsdist-tcp-upstream.hh | 2 +- pdns/dnsdistdist/dnsdist-tcp.hh | 9 +- pdns/dnsdistdist/doh.cc | 219 +++++------ pdns/dnsdistdist/test-dnsdistkvs_cc.cc | 60 ++-- pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc | 15 +- pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc | 30 +- pdns/dnsdistdist/test-dnsdistrules_cc.cc | 34 +- pdns/dnsdistdist/test-dnsdisttcp_cc.cc | 4 +- pdns/doh.hh | 16 +- pdns/test-dnsdist_cc.cc | 279 +++++++------- pdns/test-dnsdistpacketcache_cc.cc | 197 ++++++---- 33 files changed, 927 insertions(+), 1025 deletions(-) delete mode 100644 pdns/dnsdistdist/dnsdist-idstate.cc diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index 9a4169af25..cf2dab9e57 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -194,15 +194,15 @@ void DNSDistPacketCache::insert(uint32_t key, const boost::optional& su bool DNSDistPacketCache::get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut, boost::optional& subnet, bool dnssecOK, bool receivedOverUDP, uint32_t allowExpired, bool skipAging, bool truncatedOK) { - const auto& dnsQName = dq.qname->getStorage(); - uint32_t key = getKey(dnsQName, dq.qname->wirelength(), dq.getData(), receivedOverUDP); + const auto& dnsQName = dq.ids.qname.getStorage(); + uint32_t key = getKey(dnsQName, dq.ids.qname.wirelength(), dq.getData(), receivedOverUDP); if (keyOut) { *keyOut = key; } if (d_parseECS) { - getClientSubnet(dq.getData(), dq.qname->wirelength(), subnet); + getClientSubnet(dq.getData(), dq.ids.qname.wirelength(), subnet); } uint32_t shardIndex = getShardIndex(key); @@ -240,7 +240,7 @@ bool DNSDistPacketCache::get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut } /* check for collision */ - if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), *dq.qname, dq.qtype, dq.qclass, receivedOverUDP, dnssecOK, subnet)) { + if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), dq.ids.qname, dq.ids.qtype, dq.ids.qclass, receivedOverUDP, dnssecOK, subnet)) { d_lookupCollisions++; return false; } diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 60f075cc9d..af3240a63f 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -528,7 +528,7 @@ bool parseEDNSOptions(const DNSQuestion& dq) size_t remaining = 0; uint16_t optRDPosition; - int res = getEDNSOptionsStart(dq.getData(), dq.qname->wirelength(), &optRDPosition, &remaining); + int res = getEDNSOptionsStart(dq.getData(), dq.ids.qname.wirelength(), &optRDPosition, &remaining); if (res == 0) { res = getEDNSOptions(reinterpret_cast(&dq.getData().at(optRDPosition)), remaining, *dq.ednsOptions); @@ -649,11 +649,10 @@ bool handleEDNSClientSubnet(PacketBuffer& packet, const size_t maximumSize, cons bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded) { - assert(dq.remote != nullptr); string newECSOption; - generateECSOption(dq.ecsSet ? dq.ecs.getNetwork() : *dq.remote, newECSOption, dq.ecsSet ? dq.ecs.getBits() : dq.ecsPrefixLength); + generateECSOption(dq.ecs ? dq.ecs->getNetwork() : dq.ids.origRemote, newECSOption, dq.ecs ? dq.ecs->getBits() : dq.ecsPrefixLength); - return handleEDNSClientSubnet(dq.getMutableData(), dq.getMaximumSize(), dq.qname->wirelength(), ednsAdded, ecsAdded, dq.ecsOverride, newECSOption); + return handleEDNSClientSubnet(dq.getMutableData(), dq.getMaximumSize(), dq.ids.qname.wirelength(), ednsAdded, ecsAdded, dq.ecsOverride, newECSOption); } static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen) @@ -870,7 +869,7 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, return false; } - size_t queryPartSize = sizeof(dnsheader) + dq.qname->wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE; + size_t queryPartSize = sizeof(dnsheader) + dq.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE; if (packet.size() < queryPartSize) { /* something is already wrong, don't build on flawed foundations */ return false; @@ -960,7 +959,7 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq) size_t remaining = 0; auto& packet = dq.getMutableData(); - int res = getEDNSOptionsStart(packet, dq.qname->wirelength(), &optRDPosition, &remaining); + int res = getEDNSOptionsStart(packet, dq.ids.qname.wirelength(), &optRDPosition, &remaining); if (res != 0) { /* if the initial query did not have EDNS0, we are done */ @@ -1007,7 +1006,7 @@ int getEDNSZ(const DNSQuestion& dq) return 0; } - size_t pos = sizeof(dnsheader) + dq.qname->wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE; + size_t pos = sizeof(dnsheader) + dq.ids.qname.wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE; if (dq.getData().size() <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) { return 0; @@ -1044,7 +1043,7 @@ bool queryHasEDNS(const DNSQuestion& dq) uint16_t optRDPosition; size_t ecsRemaining = 0; - int res = getEDNSOptionsStart(dq.getData(), dq.qname->wirelength(), &optRDPosition, &ecsRemaining); + int res = getEDNSOptionsStart(dq.getData(), dq.ids.qname.wirelength(), &optRDPosition, &ecsRemaining); if (res == 0) { return true; } @@ -1099,8 +1098,8 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa } dq.getMutableData() = std::move(newContent); - if (!dq.ednsAdded && ednsAdded) { - dq.ednsAdded = true; + if (!dq.ids.ednsAdded && ednsAdded) { + dq.ids.ednsAdded = true; } return true; @@ -1110,7 +1109,7 @@ bool setEDNSOption(DNSQuestion& dq, uint16_t ednsCode, const std::string& ednsDa if (generateOptRR(optRData, data, dq.getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) { dq.getHeader()->arcount = htons(1); // make sure that any EDNS sent by the backend is removed before forwarding the response to the client - dq.ednsAdded = true; + dq.ids.ednsAdded = true; } return true; diff --git a/pdns/dnsdist-idstate.hh b/pdns/dnsdist-idstate.hh index 98520c0f12..33699761b2 100644 --- a/pdns/dnsdist-idstate.hh +++ b/pdns/dnsdist-idstate.hh @@ -93,24 +93,79 @@ private: #endif #endif +struct InternalQueryState +{ + static void DeleterPlaceHolder(DOHUnit*) + { + } + + InternalQueryState() : + du(std::unique_ptr(nullptr, DeleterPlaceHolder)) + { + origDest.sin4.sin_family = 0; + } + + InternalQueryState(InternalQueryState&& rhs) = default; + InternalQueryState& operator=(InternalQueryState&& rhs) = default; + + InternalQueryState(const InternalQueryState& orig) = delete; + InternalQueryState& operator=(const InternalQueryState& orig) = delete; + + boost::optional subnet{boost::none}; // 40 + ComboAddress origRemote; // 28 + ComboAddress origDest; // 28 + ComboAddress hopRemote; + ComboAddress hopLocal; + DNSName qname; // 24 + std::string poolName; // 24 + StopWatch sentTime; // 16 + std::shared_ptr packetCache{nullptr}; // 16 + std::unique_ptr dnsCryptQuery{nullptr}; // 8 + std::unique_ptr qTag{nullptr}; // 8 + boost::optional tempFailureTTL; // 8 + ClientState* cs{nullptr}; // 8 + std::unique_ptr du; // 8 + uint32_t cacheKey{0}; // 4 + uint32_t cacheKeyNoECS{0}; // 4 + // DoH-only */ + uint32_t cacheKeyUDP{0}; // 4 + int backendFD{-1}; // 4 + int delayMsec{0}; + uint16_t qtype{0}; // 2 + uint16_t qclass{0}; // 2 + // origID is in network-byte order + uint16_t origID{0}; // 2 + uint16_t origFlags{0}; // 2 + uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2 + dnsdist::Protocol protocol; // 1 + boost::optional uniqueId{boost::none}; // 17 (placed here to reduce the space lost to padding) + bool ednsAdded{false}; + bool ecsAdded{false}; + bool skipCache{false}; + bool dnssecOK{false}; + bool useZeroScope{false}; + bool forwardedOverUDP{false}; +}; + struct IDState { - IDState() : - sentTime(true), tempFailureTTL(boost::none) { origDest.sin4.sin_family = 0; } + IDState() + { + } + IDState(const IDState& orig) = delete; - IDState(IDState&& rhs) : - subnet(rhs.subnet), origRemote(rhs.origRemote), origDest(rhs.origDest), hopRemote(rhs.hopRemote), hopLocal(rhs.hopLocal), qname(std::move(rhs.qname)), sentTime(rhs.sentTime), packetCache(std::move(rhs.packetCache)), dnsCryptQuery(std::move(rhs.dnsCryptQuery)), qTag(std::move(rhs.qTag)), tempFailureTTL(rhs.tempFailureTTL), cs(rhs.cs), du(std::move(rhs.du)), cacheKey(rhs.cacheKey), cacheKeyNoECS(rhs.cacheKeyNoECS), cacheKeyUDP(rhs.cacheKeyUDP), backendFD(rhs.backendFD), delayMsec(rhs.delayMsec), qtype(rhs.qtype), qclass(rhs.qclass), origID(rhs.origID), origFlags(rhs.origFlags), cacheFlags(rhs.cacheFlags), protocol(rhs.protocol), ednsAdded(rhs.ednsAdded), ecsAdded(rhs.ecsAdded), skipCache(rhs.skipCache), dnssecOK(rhs.dnssecOK), useZeroScope(rhs.useZeroScope) + IDState(IDState&& rhs) { if (rhs.isInUse()) { throw std::runtime_error("Trying to move an in-use IDState"); } - uniqueId = std::move(rhs.uniqueId); #ifdef __SANITIZE_THREAD__ age.store(rhs.age.load()); #else age = rhs.age; #endif + internal = std::move(rhs.internal); } IDState& operator=(IDState&& rhs) @@ -122,42 +177,13 @@ struct IDState if (rhs.isInUse()) { throw std::runtime_error("Trying to move an in-use IDState"); } - - subnet = std::move(rhs.subnet); - origRemote = rhs.origRemote; - origDest = rhs.origDest; - hopRemote = rhs.hopRemote; - hopLocal = rhs.hopLocal; - qname = std::move(rhs.qname); - sentTime = rhs.sentTime; - dnsCryptQuery = std::move(rhs.dnsCryptQuery); - packetCache = std::move(rhs.packetCache); - qTag = std::move(rhs.qTag); - tempFailureTTL = std::move(rhs.tempFailureTTL); - cs = rhs.cs; - du = std::move(rhs.du); - cacheKey = rhs.cacheKey; - cacheKeyNoECS = rhs.cacheKeyNoECS; - cacheKeyUDP = rhs.cacheKeyUDP; - backendFD = rhs.backendFD; - delayMsec = rhs.delayMsec; #ifdef __SANITIZE_THREAD__ age.store(rhs.age.load()); #else age = rhs.age; #endif - qtype = rhs.qtype; - qclass = rhs.qclass; - origID = rhs.origID; - origFlags = rhs.origFlags; - cacheFlags = rhs.cacheFlags; - protocol = rhs.protocol; - uniqueId = std::move(rhs.uniqueId); - ednsAdded = rhs.ednsAdded; - ecsAdded = rhs.ecsAdded; - skipCache = rhs.skipCache; - dnssecOK = rhs.dnssecOK; - useZeroScope = rhs.useZeroScope; + + internal = std::move(rhs.internal); return *this; } @@ -228,43 +254,12 @@ struct IDState wrapping around if necessary, and we set an atomic signed 64-bit value, so that we still have -1 when the state is unused and the value of our counter otherwise. */ - boost::optional subnet{boost::none}; // 40 - ComboAddress origRemote; // 28 - ComboAddress origDest; // 28 - ComboAddress hopRemote; - ComboAddress hopLocal; - DNSName qname; // 24 - StopWatch sentTime; // 16 - std::shared_ptr packetCache{nullptr}; // 16 - std::unique_ptr dnsCryptQuery{nullptr}; // 8 - std::unique_ptr qTag{nullptr}; // 8 - boost::optional tempFailureTTL; // 8 - ClientState* cs{nullptr}; // 8 - DOHUnit* du{nullptr}; // 8 (not a unique_ptr because we currently need to be able to peek at it without taking ownership until later) + InternalQueryState internal; std::atomic usageIndicator{unusedIndicator}; // set to unusedIndicator to indicate this state is empty // 8 std::atomic generation{0}; // increased every time a state is used, to be able to detect an ABA issue // 4 - uint32_t cacheKey{0}; // 4 - uint32_t cacheKeyNoECS{0}; // 4 - // DoH-only */ - uint32_t cacheKeyUDP{0}; // 4 - int backendFD{-1}; // 4 - int delayMsec{0}; #ifdef __SANITIZE_THREAD__ std::atomic age{0}; #else uint16_t age{0}; // 2 #endif - uint16_t qtype{0}; // 2 - uint16_t qclass{0}; // 2 - // origID is in network-byte order - uint16_t origID{0}; // 2 - uint16_t origFlags{0}; // 2 - uint16_t cacheFlags{0}; // DNS flags as sent to the backend // 2 - dnsdist::Protocol protocol; // 1 - boost::optional uniqueId{boost::none}; // 17 (placed here to reduce the space lost to padding) - bool ednsAdded{false}; - bool ecsAdded{false}; - bool skipCache{false}; - bool dnssecOK{false}; - bool useZeroScope{false}; }; diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index d0df07c791..12ce2793a6 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -198,9 +198,9 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult bool ecsAdded = false; std::string newECSOption; - generateECSOption(dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, newECSOption, dq->ecsSet ? dq->ecs.getBits() : dq->ecsPrefixLength); + generateECSOption(dq->ecs ? dq->ecs->getNetwork() : dq->ids.origRemote, newECSOption, dq->ecs ? dq->ecs->getBits() : dq->ecsPrefixLength); - if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->qname->wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) { + if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->ids.qname.wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) { return DNSAction::Action::None; } @@ -288,7 +288,7 @@ public: return Action::Pool; } else { - dq->poolname = d_pool; + dq->ids.poolName = d_pool; return Action::None; } } @@ -317,7 +317,7 @@ public: return Action::Pool; } else { - dq->poolname = d_pool; + dq->ids.poolName = d_pool; return Action::None; } } @@ -409,23 +409,23 @@ public: { /* it will likely be a bit bigger than that because of additionals */ uint16_t numberOfRecords = d_payloads.size(); - const auto qnameWireLength = dq->qname->wirelength(); + const auto qnameWireLength = dq->ids.qname.wirelength(); if (dq->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + d_totalPayloadsSize)) { return Action::None; } PacketBuffer newPacket; newPacket.reserve(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + d_totalPayloadsSize); - GenericDNSPacketWriter pw(newPacket, *dq->qname, dq->qtype); + GenericDNSPacketWriter pw(newPacket, dq->ids.qname, dq->ids.qtype); for (const auto& payload : d_payloads) { - pw.startRecord(*dq->qname, dq->qtype, d_responseConfig.ttl); + pw.startRecord(dq->ids.qname, dq->ids.qtype, d_responseConfig.ttl); pw.xfrBlob(payload); pw.commit(); } if (newPacket.size() < dq->getMaximumSize()) { for (const auto& additional : d_additionals4) { - pw.startRecord(additional.first.isRoot() ? *dq->qname : additional.first, QType::A, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); + pw.startRecord(additional.first.isRoot() ? dq->ids.qname : additional.first, QType::A, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); pw.xfrCAWithoutPort(4, additional.second); pw.commit(); } @@ -433,7 +433,7 @@ public: if (newPacket.size() < dq->getMaximumSize()) { for (const auto& additional : d_additionals6) { - pw.startRecord(additional.first.isRoot() ? *dq->qname : additional.first, QType::AAAA, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); + pw.startRecord(additional.first.isRoot() ? dq->ids.qname : additional.first, QType::AAAA, d_responseConfig.ttl, QClass::IN, DNSResourceRecord::ADDITIONAL); pw.xfrCAWithoutPort(6, additional.second); pw.commit(); } @@ -783,7 +783,7 @@ thread_local std::default_random_engine SpoofAction::t_randomEngine; DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresult) const { - uint16_t qtype = dq->qtype; + uint16_t qtype = dq->ids.qtype; // do we even have a response? if (d_cname.empty() && d_rawResponses.empty() && @@ -934,7 +934,7 @@ public: DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { dnsdist::MacAddress mac; - int res = dnsdist::MacAddressesCache::get(*dq->remote, mac.data(), mac.size()); + int res = dnsdist::MacAddressesCache::get(dq->ids.origRemote, mac.data(), mac.size()); if (res != 0) { return Action::None; } @@ -957,8 +957,8 @@ public: } dq->getMutableData() = std::move(newContent); - if (!dq->ednsAdded && ednsAdded) { - dq->ednsAdded = true; + if (!dq->ids.ednsAdded && ednsAdded) { + dq->ids.ednsAdded = true; } return Action::None; @@ -968,7 +968,7 @@ public: if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) { dq->getHeader()->arcount = htons(1); // make sure that any EDNS sent by the backend is removed before forwarding the response to the client - dq->ednsAdded = true; + dq->ids.ednsAdded = true; } return Action::None; @@ -1045,41 +1045,41 @@ public: if (!fp) { if (!d_verboseOnly || g_verbose) { if (d_includeTimestamp) { - infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast(dq->queryTime->tv_sec), static_cast(dq->queryTime->tv_nsec), dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).toString(), dq->getHeader()->id); + infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast(dq->queryTime.tv_sec), static_cast(dq->queryTime.tv_nsec), dq->ids.origRemote.toStringWithPort(), dq->ids.qname.toString(), QType(dq->ids.qtype).toString(), dq->getHeader()->id); } else { - infolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).toString(), dq->getHeader()->id); + infolog("Packet from %s for %s %s with id %d", dq->ids.origRemote.toStringWithPort(), dq->ids.qname.toString(), QType(dq->ids.qtype).toString(), dq->getHeader()->id); } } } else { if (d_binary) { - const auto& out = dq->qname->getStorage(); + const auto& out = dq->ids.qname.getStorage(); if (d_includeTimestamp) { - uint64_t tv_sec = static_cast(dq->queryTime->tv_sec); - uint32_t tv_nsec = static_cast(dq->queryTime->tv_nsec); + uint64_t tv_sec = static_cast(dq->queryTime.tv_sec); + uint32_t tv_nsec = static_cast(dq->queryTime.tv_nsec); fwrite(&tv_sec, sizeof(tv_sec), 1, fp.get()); fwrite(&tv_nsec, sizeof(tv_nsec), 1, fp.get()); } uint16_t id = dq->getHeader()->id; fwrite(&id, sizeof(id), 1, fp.get()); fwrite(out.c_str(), 1, out.size(), fp.get()); - fwrite(&dq->qtype, sizeof(dq->qtype), 1, fp.get()); - fwrite(&dq->remote->sin4.sin_family, sizeof(dq->remote->sin4.sin_family), 1, fp.get()); - if (dq->remote->sin4.sin_family == AF_INET) { - fwrite(&dq->remote->sin4.sin_addr.s_addr, sizeof(dq->remote->sin4.sin_addr.s_addr), 1, fp.get()); + fwrite(&dq->ids.qtype, sizeof(dq->ids.qtype), 1, fp.get()); + fwrite(&dq->ids.origRemote.sin4.sin_family, sizeof(dq->ids.origRemote.sin4.sin_family), 1, fp.get()); + if (dq->ids.origRemote.sin4.sin_family == AF_INET) { + fwrite(&dq->ids.origRemote.sin4.sin_addr.s_addr, sizeof(dq->ids.origRemote.sin4.sin_addr.s_addr), 1, fp.get()); } - else if (dq->remote->sin4.sin_family == AF_INET6) { - fwrite(&dq->remote->sin6.sin6_addr.s6_addr, sizeof(dq->remote->sin6.sin6_addr.s6_addr), 1, fp.get()); + else if (dq->ids.origRemote.sin4.sin_family == AF_INET6) { + fwrite(&dq->ids.origRemote.sin6.sin6_addr.s6_addr, sizeof(dq->ids.origRemote.sin6.sin6_addr.s6_addr), 1, fp.get()); } - fwrite(&dq->remote->sin4.sin_port, sizeof(dq->remote->sin4.sin_port), 1, fp.get()); + fwrite(&dq->ids.origRemote.sin4.sin_port, sizeof(dq->ids.origRemote.sin4.sin_port), 1, fp.get()); } else { if (d_includeTimestamp) { - fprintf(fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %d\n", static_cast(dq->queryTime->tv_sec), static_cast(dq->queryTime->tv_nsec), dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).toString().c_str(), dq->getHeader()->id); + fprintf(fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %d\n", static_cast(dq->queryTime.tv_sec), static_cast(dq->queryTime.tv_nsec), dq->ids.origRemote.toStringWithPort().c_str(), dq->ids.qname.toString().c_str(), QType(dq->ids.qtype).toString().c_str(), dq->getHeader()->id); } else { - fprintf(fp.get(), "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).toString().c_str(), dq->getHeader()->id); + fprintf(fp.get(), "Packet from %s for %s %s with id %d\n", dq->ids.origRemote.toStringWithPort().c_str(), dq->ids.qname.toString().c_str(), QType(dq->ids.qtype).toString().c_str(), dq->getHeader()->id); } } } @@ -1157,19 +1157,19 @@ public: if (!fp) { if (!d_verboseOnly || g_verbose) { if (d_includeTimestamp) { - infolog("[%u.%u] Answer to %s for %s %s (%s) with id %d", static_cast(dr->queryTime->tv_sec), static_cast(dr->queryTime->tv_nsec), dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); + infolog("[%u.%u] Answer to %s for %s %s (%s) with id %d", static_cast(dr->queryTime.tv_sec), static_cast(dr->queryTime.tv_nsec), dr->ids.origRemote.toStringWithPort(), dr->ids.qname.toString(), QType(dr->ids.qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); } else { - infolog("Answer to %s for %s %s (%s) with id %d", dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); + infolog("Answer to %s for %s %s (%s) with id %d", dr->ids.origRemote.toStringWithPort(), dr->ids.qname.toString(), QType(dr->ids.qtype).toString(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); } } } else { if (d_includeTimestamp) { - fprintf(fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %d\n", static_cast(dr->queryTime->tv_sec), static_cast(dr->queryTime->tv_nsec), dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); + fprintf(fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %d\n", static_cast(dr->queryTime.tv_sec), static_cast(dr->queryTime.tv_nsec), dr->ids.origRemote.toStringWithPort().c_str(), dr->ids.qname.toString().c_str(), QType(dr->ids.qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); } else { - fprintf(fp.get(), "Answer to %s for %s %s (%s) with id %d\n", dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); + fprintf(fp.get(), "Answer to %s for %s %s (%s) with id %d\n", dr->ids.origRemote.toStringWithPort().c_str(), dr->ids.qname.toString().c_str(), QType(dr->ids.qtype).toString().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); } } return Action::None; @@ -1242,7 +1242,7 @@ public: // this action does not stop the processing DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->skipCache = true; + dq->ids.skipCache = true; return Action::None; } std::string toString() const override @@ -1256,7 +1256,7 @@ class SetSkipCacheResponseAction : public DNSResponseAction public: DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override { - dr->skipCache = true; + dr->ids.skipCache = true; return Action::None; } std::string toString() const override @@ -1274,7 +1274,7 @@ public: } DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->tempFailureTTL = d_ttl; + dq->ids.tempFailureTTL = d_ttl; return Action::None; } std::string toString() const override @@ -1294,7 +1294,7 @@ public: } DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->ecsPrefixLength = dq->remote->sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; + dq->ecsPrefixLength = dq->ids.origRemote.sin4.sin_family == AF_INET ? d_v4PrefixLength : d_v6PrefixLength; return Action::None; } std::string toString() const override @@ -1356,13 +1356,11 @@ public: DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->ecsSet = true; - if (d_hasV6) { - dq->ecs = dq->remote->isIPv4() ? d_v4 : d_v6; + dq->ecs = std::make_unique(dq->ids.origRemote.isIPv4() ? d_v4 : d_v6); } else { - dq->ecs = d_v4; + dq->ecs = std::make_unique(d_v4);; } return Action::None; @@ -1440,7 +1438,7 @@ public: data.clear(); DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dq->getProtocol()); - DnstapMessage message(data, !dq->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, dq->remote, dq->local, protocol, reinterpret_cast(dq->getData().data()), dq->getData().size(), dq->queryTime, nullptr); + DnstapMessage message(data, !dq->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, &dq->ids.origRemote, &dq->ids.origDest, protocol, reinterpret_cast(dq->getData().data()), dq->getData().size(), &dq->queryTime, nullptr); { if (d_alterFunc) { auto lock = g_lua.lock(); @@ -1471,8 +1469,8 @@ public: } DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - if (!dq->uniqueId) { - dq->uniqueId = getUniqueID(); + if (!dq->ids.uniqueId) { + dq->ids.uniqueId = getUniqueID(); } DNSDistProtoBufMessage message(*dq); @@ -1483,7 +1481,7 @@ public: #if HAVE_IPCIPHER if (!d_ipEncryptKey.empty()) { - message.setRequestor(encryptCA(*dq->remote, d_ipEncryptKey)); + message.setRequestor(encryptCA(dq->ids.origRemote, d_ipEncryptKey)); } #endif /* HAVE_IPCIPHER */ @@ -1573,7 +1571,7 @@ public: data.clear(); DnstapMessage::ProtocolType protocol = ProtocolToDNSTap(dr->getProtocol()); - DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, dr->remote, dr->local, protocol, reinterpret_cast(dr->getData().data()), dr->getData().size(), dr->queryTime, &now); + DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, &dr->ids.origRemote, &dr->ids.origDest, protocol, reinterpret_cast(dr->getData().data()), dr->getData().size(), &dr->queryTime, &now); { if (d_alterFunc) { auto lock = g_lua.lock(); @@ -1604,8 +1602,8 @@ public: } DNSResponseAction::Action operator()(DNSResponse* dr, std::string* ruleresult) const override { - if (!dr->uniqueId) { - dr->uniqueId = getUniqueID(); + if (!dr->ids.uniqueId) { + dr->ids.uniqueId = getUniqueID(); } DNSDistProtoBufMessage message(*dr, d_includeCNAME); @@ -1616,7 +1614,7 @@ public: #if HAVE_IPCIPHER if (!d_ipEncryptKey.empty()) { - message.setRequestor(encryptCA(*dr->remote, d_ipEncryptKey)); + message.setRequestor(encryptCA(dr->ids.origRemote, d_ipEncryptKey)); } #endif /* HAVE_IPCIPHER */ @@ -1810,11 +1808,11 @@ public: DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - if (!dq->du) { + if (!dq->ids.du) { return Action::None; } - dq->du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); + dq->ids.du->setHTTPResponse(d_code, PacketBuffer(d_body), d_contentType); dq->getHeader()->qr = true; // for good measure setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); return Action::HeaderModify; diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index f58047f5d7..e53637ef59 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -29,27 +29,27 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) #ifndef DISABLE_NON_FFI_DQ_BINDINGS /* DNSQuestion */ /* PowerDNS DNSQuestion compat */ - luaCtx.registerMember("localaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.local; }, [](DNSQuestion& dq, const ComboAddress newLocal) { (void) newLocal; }); - luaCtx.registerMember("qname", [](const DNSQuestion& dq) -> const DNSName { return *dq.qname; }, [](DNSQuestion& dq, const DNSName newName) { (void) newName; }); - luaCtx.registerMember("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; }); - luaCtx.registerMember("qclass", [](const DNSQuestion& dq) -> uint16_t { return dq.qclass; }, [](DNSQuestion& dq, uint16_t newClass) { (void) newClass; }); + luaCtx.registerMember("localaddr", [](const DNSQuestion& dq) -> const ComboAddress { return dq.ids.origDest; }, [](DNSQuestion& dq, const ComboAddress newLocal) { (void) newLocal; }); + luaCtx.registerMember("qname", [](const DNSQuestion& dq) -> const DNSName { return dq.ids.qname; }, [](DNSQuestion& dq, const DNSName newName) { (void) newName; }); + luaCtx.registerMember("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; }); + luaCtx.registerMember("qclass", [](const DNSQuestion& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSQuestion& dq, uint16_t newClass) { (void) newClass; }); luaCtx.registerMember("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.getHeader()->rcode = newRCode; }); - luaCtx.registerMember("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.remote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; }); + luaCtx.registerMember("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; }); /* DNSDist DNSQuestion */ luaCtx.registerMember("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; }); luaCtx.registerMember("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); }); luaCtx.registerMember("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; }); luaCtx.registerMember("tcp", [](const DNSQuestion& dq) -> bool { return dq.overTCP(); }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; }); - luaCtx.registerMember("skipCache", [](const DNSQuestion& dq) -> bool { return dq.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.skipCache = newSkipCache; }); + luaCtx.registerMember("skipCache", [](const DNSQuestion& dq) -> bool { return dq.ids.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; }); luaCtx.registerMember("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; }); luaCtx.registerMember("ecsOverride", [](const DNSQuestion& dq) -> bool { return dq.ecsOverride; }, [](DNSQuestion& dq, bool ecsOverride) { dq.ecsOverride = ecsOverride; }); luaCtx.registerMember("ecsPrefixLength", [](const DNSQuestion& dq) -> uint16_t { return dq.ecsPrefixLength; }, [](DNSQuestion& dq, uint16_t newPrefixLength) { dq.ecsPrefixLength = newPrefixLength; }); luaCtx.registerMember (DNSQuestion::*)>("tempFailureTTL", [](const DNSQuestion& dq) -> boost::optional { - return dq.tempFailureTTL; + return dq.ids.tempFailureTTL; }, [](DNSQuestion& dq, boost::optional newValue) { - dq.tempFailureTTL = newValue; + dq.ids.tempFailureTTL = newValue; } ); luaCtx.registerFunction("getDO", [](const DNSQuestion& dq) { @@ -97,24 +97,24 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) } }); luaCtx.registerFunction("getTag", [](const DNSQuestion& dq, const std::string& strLabel) { - if (!dq.qTag) { + if (!dq.ids.qTag) { return string(); } std::string strValue; - const auto it = dq.qTag->find(strLabel); - if (it == dq.qTag->cend()) { + const auto it = dq.ids.qTag->find(strLabel); + if (it == dq.ids.qTag->cend()) { return string(); } return it->second; }); luaCtx.registerFunction("getTagArray", [](const DNSQuestion& dq) { - if (!dq.qTag) { + if (!dq.ids.qTag) { QTag empty; return empty; } - return *dq.qTag; + return *dq.ids.qTag; }); luaCtx.registerFunction)>("setProxyProtocolValues", [](DNSQuestion& dq, const LuaArray& values) { @@ -185,17 +185,17 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) }); /* LuaWrapper doesn't support inheritance */ - luaCtx.registerMember("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return *dq.local; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; }); - luaCtx.registerMember("qname", [](const DNSResponse& dq) -> const DNSName { return *dq.qname; }, [](DNSResponse& dq, const DNSName newName) { (void) newName; }); - luaCtx.registerMember("qtype", [](const DNSResponse& dq) -> uint16_t { return dq.qtype; }, [](DNSResponse& dq, uint16_t newType) { (void) newType; }); - luaCtx.registerMember("qclass", [](const DNSResponse& dq) -> uint16_t { return dq.qclass; }, [](DNSResponse& dq, uint16_t newClass) { (void) newClass; }); + luaCtx.registerMember("localaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origDest; }, [](DNSResponse& dq, const ComboAddress newLocal) { (void) newLocal; }); + luaCtx.registerMember("qname", [](const DNSResponse& dq) -> const DNSName { return dq.ids.qname; }, [](DNSResponse& dq, const DNSName newName) { (void) newName; }); + luaCtx.registerMember("qtype", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qtype; }, [](DNSResponse& dq, uint16_t newType) { (void) newType; }); + luaCtx.registerMember("qclass", [](const DNSResponse& dq) -> uint16_t { return dq.ids.qclass; }, [](DNSResponse& dq, uint16_t newClass) { (void) newClass; }); luaCtx.registerMember("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) { dq.getHeader()->rcode = newRCode; }); - luaCtx.registerMember("remoteaddr", [](const DNSResponse& dq) -> const ComboAddress { return *dq.remote; }, [](DNSResponse& dq, const ComboAddress newRemote) { (void) newRemote; }); + luaCtx.registerMember("remoteaddr", [](const DNSResponse& dq) -> const ComboAddress { return dq.ids.origRemote; }, [](DNSResponse& dq, const ComboAddress newRemote) { (void) newRemote; }); luaCtx.registerMember("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; }); luaCtx.registerMember("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); }); luaCtx.registerMember("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; }); luaCtx.registerMember("tcp", [](const DNSResponse& dq) -> bool { return dq.overTCP(); }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; }); - luaCtx.registerMember("skipCache", [](const DNSResponse& dq) -> bool { return dq.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.skipCache = newSkipCache; }); + luaCtx.registerMember("skipCache", [](const DNSResponse& dq) -> bool { return dq.ids.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.ids.skipCache = newSkipCache; }); luaCtx.registerFunction editFunc)>("editTTLs", [](DNSResponse& dr, std::function editFunc) { editDNSPacketTTL(reinterpret_cast(dr.getMutableData().data()), dr.getData().size(), editFunc); }); @@ -229,24 +229,24 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) } }); luaCtx.registerFunction("getTag", [](const DNSResponse& dr, const std::string& strLabel) { - if (!dr.qTag) { + if (!dr.ids.qTag) { return string(); } std::string strValue; - const auto it = dr.qTag->find(strLabel); - if (it == dr.qTag->cend()) { + const auto it = dr.ids.qTag->find(strLabel); + if (it == dr.ids.qTag->cend()) { return string(); } return it->second; }); luaCtx.registerFunction("getTagArray", [](const DNSResponse& dr) { - if (!dr.qTag) { + if (!dr.ids.qTag) { QTag empty; return empty; } - return *dr.qTag; + return *dr.ids.qTag; }); luaCtx.registerFunction("getProtocol", [](const DNSResponse& dr) { @@ -263,47 +263,47 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) #ifdef HAVE_DNS_OVER_HTTPS luaCtx.registerFunction("getHTTPPath", [](const DNSQuestion& dq) { - if (dq.du == nullptr) { + if (dq.ids.du == nullptr) { return std::string(); } - return dq.du->getHTTPPath(); + return dq.ids.du->getHTTPPath(); }); luaCtx.registerFunction("getHTTPQueryString", [](const DNSQuestion& dq) { - if (dq.du == nullptr) { + if (dq.ids.du == nullptr) { return std::string(); } - return dq.du->getHTTPQueryString(); + return dq.ids.du->getHTTPQueryString(); }); luaCtx.registerFunction("getHTTPHost", [](const DNSQuestion& dq) { - if (dq.du == nullptr) { + if (dq.ids.du == nullptr) { return std::string(); } - return dq.du->getHTTPHost(); + return dq.ids.du->getHTTPHost(); }); luaCtx.registerFunction("getHTTPScheme", [](const DNSQuestion& dq) { - if (dq.du == nullptr) { + if (dq.ids.du == nullptr) { return std::string(); } - return dq.du->getHTTPScheme(); + return dq.ids.du->getHTTPScheme(); }); luaCtx.registerFunction(DNSQuestion::*)(void)const>("getHTTPHeaders", [](const DNSQuestion& dq) { - if (dq.du == nullptr) { + if (dq.ids.du == nullptr) { return LuaAssociativeTable(); } - return dq.du->getHTTPHeaders(); + return dq.ids.du->getHTTPHeaders(); }); luaCtx.registerFunction contentType)>("setHTTPResponse", [](DNSQuestion& dq, uint64_t statusCode, const std::string& body, const boost::optional contentType) { - if (dq.du == nullptr) { + if (dq.ids.du == nullptr) { return; } checkParameterBound("DNSQuestion::setHTTPResponse", statusCode, std::numeric_limits::max()); PacketBuffer vect(body.begin(), body.end()); - dq.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : ""); + dq.ids.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : ""); }); #endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/dnsdist-lua-rules.cc b/pdns/dnsdist-lua-rules.cc index a234a3fd61..07bb3efd32 100644 --- a/pdns/dnsdist-lua-rules.cc +++ b/pdns/dnsdist-lua-rules.cc @@ -448,31 +448,31 @@ void setupLuaRules(LuaContext& luaCtx) DNSName suffix(suffix_.get_value_or("powerdns.com")); struct item { PacketBuffer packet; - ComboAddress rem; - DNSName qname; - uint16_t qtype, qclass; + InternalQueryState ids; }; vector items; items.reserve(1000); - for(int n=0; n < 1000; ++n) { + for (int n = 0; n < 1000; ++n) { struct item i; - i.qname=DNSName(std::to_string(random())); - i.qname += suffix; - i.qtype = random() % 0xff; - i.qclass = 1; - i.rem=ComboAddress("127.0.0.1"); - i.rem.sin4.sin_addr.s_addr = random(); - GenericDNSPacketWriter pw(i.packet, i.qname, i.qtype); - items.push_back(i); + i.ids.qname = DNSName(std::to_string(random())); + i.ids.qname += suffix; + i.ids.qtype = random() % 0xff; + i.ids.qclass = QClass::IN; + i.ids.protocol = dnsdist::Protocol::DoUDP; + i.ids.origRemote = ComboAddress("127.0.0.1"); + i.ids.origRemote.sin4.sin_addr.s_addr = random(); + GenericDNSPacketWriter pw(i.packet, i.ids.qname, i.ids.qtype); + items.push_back(std::move(i)); } - int matches=0; + int matches = 0; ComboAddress dummy("127.0.0.1"); StopWatch sw; sw.start(); - for(unsigned int n=0; n < times; ++n) { + for (unsigned int n = 0; n < times; ++n) { item& i = items[n % items.size()]; - DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, dnsdist::Protocol::DoUDP, &sw.d_start); + DNSQuestion dq(i.ids, i.packet, sw.d_start); + if (rule->matches(&dq)) { matches++; } diff --git a/pdns/dnsdist-protobuf.cc b/pdns/dnsdist-protobuf.cc index 7dfd0648cb..be73fb9081 100644 --- a/pdns/dnsdist-protobuf.cc +++ b/pdns/dnsdist-protobuf.cc @@ -145,7 +145,7 @@ void DNSDistProtoBufMessage::serialize(std::string& data) const protocol = pdns::ProtoZero::Message::TransportProtocol::DNSCryptTCP; } - m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.getHeader()->id, protocol, d_bytes ? *d_bytes : d_dq.getData().size()); + m.setRequest(d_dq.ids.uniqueId ? *d_dq.ids.uniqueId : getUniqueID(), d_requestor ? *d_requestor : d_dq.ids.origRemote, d_responder ? *d_responder : d_dq.ids.origDest, d_question ? d_question->d_name : d_dq.ids.qname, d_question ? d_question->d_type : d_dq.ids.qtype, d_question ? d_question->d_class : d_dq.ids.qclass, d_dq.getHeader()->id, protocol, d_bytes ? *d_bytes : d_dq.getData().size()); if (d_serverIdentity) { m.setServerIdentity(*d_serverIdentity); @@ -164,7 +164,7 @@ void DNSDistProtoBufMessage::serialize(std::string& data) const m.setQueryTime(d_queryTime->first, d_queryTime->second); } else { - m.setQueryTime(d_dq.queryTime->tv_sec, d_dq.queryTime->tv_nsec / 1000); + m.setQueryTime(d_dq.queryTime.tv_sec, d_dq.queryTime.tv_nsec / 1000); } if (d_dr != nullptr) { diff --git a/pdns/dnsdist-snmp.cc b/pdns/dnsdist-snmp.cc index 43f425f7ff..784af371ab 100644 --- a/pdns/dnsdist-snmp.cc +++ b/pdns/dnsdist-snmp.cc @@ -445,16 +445,16 @@ bool DNSDistSNMPAgent::sendCustomTrap(const std::string& reason) bool DNSDistSNMPAgent::sendDNSTrap(const DNSQuestion& dq, const std::string& reason) { #ifdef HAVE_NET_SNMP - std::string local = dq.local->toString(); - std::string remote = dq.remote->toString(); - std::string qname = dq.qname->toStringNoDot(); - const uint32_t socketFamily = dq.remote->isIPv4() ? 1 : 2; + std::string local = dq.ids.origDest.toString(); + std::string remote = dq.ids.origRemote.toString(); + std::string qname = dq.ids.qname.toStringNoDot(); + const uint32_t socketFamily = dq.ids.origRemote.isIPv4() ? 1 : 2; const uint32_t socketProtocol = dq.overTCP() ? 2 : 1; const uint32_t queryType = dq.getHeader()->qr ? 2 : 1; const uint32_t querySize = (uint32_t) dq.getData().size(); const uint32_t queryID = (uint32_t) ntohs(dq.getHeader()->id); - const uint32_t qType = (uint32_t) dq.qtype; - const uint32_t qClass = (uint32_t) dq.qclass; + const uint32_t qType = (uint32_t) dq.ids.qtype; + const uint32_t qClass = (uint32_t) dq.ids.qclass; netsnmp_variable_list* varList = nullptr; diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index c1c220f162..150c8de657 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -539,11 +539,11 @@ void IncomingTCPConnectionState::handleResponse(const struct timeval& now, TCPRe ++response.d_connection->getDS()->responses; } - DNSResponse dr = makeDNSResponseFromIDState(ids, response.d_buffer); + DNSResponse dr(ids, response.d_buffer, ids.sentTime.d_start, response.d_connection->getDS()); memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); - if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false, false)) { + if (!processResponse(response.d_buffer, *state->d_threadData.localRespRuleActions, *state->d_threadData.localCacheInsertedRespRuleActions, dr, false)) { state->terminateClientConnection(); return; } @@ -574,7 +574,7 @@ struct TCPCrossProtocolResponse class TCPCrossProtocolQuery : public CrossProtocolQuery { public: - TCPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) + TCPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr& ds, std::shared_ptr sender): CrossProtocolQuery(InternalQuery(std::move(buffer), std::move(ids)), ds), d_sender(std::move(sender)) { proxyProtocolPayloadSize = 0; } @@ -652,9 +652,12 @@ static void handleQuery(std::shared_ptr& state, cons rings for example */ struct timespec queryRealTime; gettime(&queryRealTime, true); + InternalQueryState ids; + ids.origDest = state->d_proxiedDestination; + ids.origRemote = state->d_proxiedRemote; + ids.cs = state->d_ci.cs; - std::unique_ptr dnsCryptQuery{nullptr}; - auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, dnsCryptQuery, queryRealTime.tv_sec, true); + auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, ids.dnsCryptQuery, queryRealTime.tv_sec, true); if (dnsCryptResponse) { TCPResponse response; state->d_state = IncomingTCPConnectionState::State::idle; @@ -684,19 +687,19 @@ static void handleQuery(std::shared_ptr& state, cons } } - uint16_t qtype, qclass; - unsigned int qnameWireLength = 0; - DNSName qname(reinterpret_cast(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); - dnsdist::Protocol protocol = dnsdist::Protocol::DoTCP; - if (dnsCryptQuery) { - protocol = dnsdist::Protocol::DNSCryptTCP; + ids.qname = DNSName(reinterpret_cast(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + ids.protocol = dnsdist::Protocol::DoTCP; + if (ids.dnsCryptQuery) { + ids.protocol = dnsdist::Protocol::DNSCryptTCP; } else if (state->d_handler.isTLS()) { - protocol = dnsdist::Protocol::DoT; + ids.protocol = dnsdist::Protocol::DoT; } - DNSQuestion dq(&qname, qtype, qclass, &state->d_proxiedDestination, &state->d_proxiedRemote, state->d_buffer, protocol, &queryRealTime); - dq.dnsCryptQuery = std::move(dnsCryptQuery); + DNSQuestion dq(ids, state->d_buffer, queryRealTime); + const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); + ids.origFlags = *flags; + dq.sni = state->d_handler.getServerNameIndication(); if (state->d_proxyProtocolValues) { /* we need to copy them, because the next queries received on that connection will @@ -704,8 +707,8 @@ static void handleQuery(std::shared_ptr& state, cons dq.proxyProtocolValues = make_unique>(*state->d_proxyProtocolValues); } - if (dq.qtype == QType::AXFR || dq.qtype == QType::IXFR) { - dq.skipCache = true; + if (dq.ids.qtype == QType::AXFR || dq.ids.qtype == QType::IXFR) { + dq.ids.skipCache = true; } std::shared_ptr ds; @@ -720,12 +723,10 @@ static void handleQuery(std::shared_ptr& state, cons const dnsheader* dh = dq.getHeader(); if (result == ProcessQueryResult::SendAnswer) { TCPResponse response; - response.d_selfGenerated = true; + memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH)); + response.d_idstate = std::move(ids); response.d_idstate.origID = dh->id; response.d_idstate.cs = state->d_ci.cs; - setIDStateFromDNSQuestion(response.d_idstate, dq, std::move(qname)); - - memcpy(&response.d_cleartextDH, dh, sizeof(response.d_cleartextDH)); response.d_buffer = std::move(state->d_buffer); state->d_state = IncomingTCPConnectionState::State::idle; @@ -739,16 +740,13 @@ static void handleQuery(std::shared_ptr& state, cons return; } - IDState ids; - setIDStateFromDNSQuestion(ids, dq, std::move(qname)); - ids.origID = dh->id; - ids.cs = state->d_ci.cs; + dq.ids.origID = dh->id; ++state->d_currentQueriesCount; std::string proxyProtocolPayload; if (ds->isDoH()) { - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getName()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", ids.qname.toLogString(), QType(ids.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), state->d_buffer.size(), ds->getNameWithAddr()); /* we need to do this _before_ creating the cross protocol query because after that the buffer will have been moved */ @@ -783,7 +781,7 @@ static void handleQuery(std::shared_ptr& state, cons TCPQuery query(std::move(state->d_buffer), std::move(ids)); query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); - vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getName()); + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), state->d_proxiedRemote.toStringWithPort(), (state->d_handler.isTLS() ? "DoT" : "TCP"), query.d_buffer.size(), ds->getNameWithAddr()); std::shared_ptr incoming = state; downstreamConnection->queueQuery(incoming, std::move(query)); } @@ -1055,7 +1053,7 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptrd_lastIOBlocked); } -void IncomingTCPConnectionState::notifyIOError(IDState&& query, const struct timeval& now) +void IncomingTCPConnectionState::notifyIOError(InternalQueryState&& query, const struct timeval& now) { std::shared_ptr state = shared_from_this(); @@ -1182,6 +1180,9 @@ static void handleCrossProtocolQuery(int pipefd, FDMultiplexer::funcparam_t& par prependSizeToTCPQuery(query.d_buffer, proxyProtocolPayloadSize); query.d_proxyProtocolPayloadAddedSize = proxyProtocolPayloadSize; + + vinfolog("Got query for %s|%s from %s (%s, %d bytes), relayed to %s", query.d_idstate.qname.toLogString(), QType(query.d_idstate.qtype).toString(), query.d_idstate.origRemote.toStringWithPort(), query.d_idstate.protocol.toString(), query.d_buffer.size(), downstreamServer->getNameWithAddr()); + downstream->queueQuery(tqs, std::move(query)); } catch (...) { diff --git a/pdns/dnsdist-xpf.cc b/pdns/dnsdist-xpf.cc index b02ad62176..6f4cba5315 100644 --- a/pdns/dnsdist-xpf.cc +++ b/pdns/dnsdist-xpf.cc @@ -27,7 +27,7 @@ bool addXPF(DNSQuestion& dq, uint16_t optionCode) { - std::string payload = generateXPFPayload(dq.overTCP(), *dq.remote, *dq.local); + std::string payload = generateXPFPayload(dq.overTCP(), dq.ids.origRemote, dq.ids.origDest); uint8_t root = '\0'; dnsrecordheader drh; drh.d_type = htons(optionCode); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index 99f5522b94..d122976a4c 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -249,7 +249,7 @@ bool DNSQuestion::setTrailingData(const std::string& tail) return true; } -void doLatencyStats(dnsdist::Protocol protocol, double udiff) +static void doLatencyStats(dnsdist::Protocol protocol, double udiff) { constexpr auto doAvg = [](double& var, double n, double weight) { var = (weight -1) * var/weight + n/weight; @@ -502,7 +502,7 @@ static bool applyRulesToResponse(const std::vector& r break; /* non-terminal actions follow */ case DNSResponseAction::Action::Delay: - pdns::checked_stoi_into(dr.delayMsec, ruleresult); // sorry + pdns::checked_stoi_into(dr.ids.delayMsec, ruleresult); // sorry break; case DNSResponseAction::Action::None: break; @@ -513,21 +513,19 @@ static bool applyRulesToResponse(const std::vector& r return true; } -// whether the query was received over TCP or not (for rules, dnstap, protobuf, ...) will be taken from the DNSResponse, but receivedOverUDP is used to insert into the cache, -// so that answers received over UDP for DoH are still cached with UDP answers. -bool processResponse(PacketBuffer& response, const vector& respRuleActions, const vector& insertedRespRuleActions, DNSResponse& dr, bool muted, bool receivedOverUDP) +bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& insertedRespRuleActions, DNSResponse& dr, bool muted) { if (!applyRulesToResponse(respRuleActions, dr)) { return false; } bool zeroScope = false; - if (!fixUpResponse(response, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, dr.useZeroScope ? &zeroScope : nullptr)) { + if (!fixUpResponse(response, dr.ids.qname, dr.ids.origFlags, dr.ids.ednsAdded, dr.ids.ecsAdded, dr.ids.useZeroScope ? &zeroScope : nullptr)) { return false; } - if (dr.packetCache && !dr.skipCache && response.size() <= s_maxPacketCacheEntrySize) { - if (!dr.useZeroScope) { + if (dr.ids.packetCache && !dr.ids.skipCache && response.size() <= s_maxPacketCacheEntrySize) { + if (!dr.ids.useZeroScope) { /* if the query was not suitable for zero-scope, for example because it had an existing ECS entry so the hash is not really 'no ECS', so just insert it for the existing subnet @@ -538,16 +536,16 @@ bool processResponse(PacketBuffer& response, const vectorinsert(cacheKey, zeroScope ? boost::none : dr.subnet, dr.cacheFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, receivedOverUDP, dr.getHeader()->rcode, dr.tempFailureTTL); + dr.ids.packetCache->insert(cacheKey, zeroScope ? boost::none : dr.ids.subnet, dr.ids.cacheFlags, dr.ids.dnssecOK, dr.ids.qname, dr.ids.qtype, dr.ids.qclass, response, dr.ids.forwardedOverUDP, dr.getHeader()->rcode, dr.ids.tempFailureTTL); if (!applyRulesToResponse(insertedRespRuleActions, dr)) { return false; @@ -556,7 +554,7 @@ bool processResponse(PacketBuffer& response, const vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& ds, bool selfGenerated, std::optional queryId) +static void handleResponseForUDPClient(InternalQueryState& ids, PacketBuffer& response, uint16_t maxPayloadSize, const std::vector& respRuleActions, const std::vector& cacheInsertedRespRuleActions, const std::shared_ptr& ds, bool selfGenerated, std::optional queryId) { - DNSResponse dr = makeDNSResponseFromIDState(ids, response); + DNSResponse dr(ids, response, ids.sentTime.d_start, ds); if (maxPayloadSize > 0 && response.size() > maxPayloadSize) { vinfolog("Got a response of size %d while the initial UDP payload size was %d, truncating", response.size(), maxPayloadSize); - truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.qname->wirelength()); + truncateTC(dr.getMutableData(), dr.getMaximumSize(), dr.ids.qname.wirelength()); dr.getHeader()->tc = true; } else if (dr.getHeader()->tc && g_truncateTC) { - truncateTC(response, dr.getMaximumSize(), dr.qname->wirelength()); + truncateTC(response, dr.getMaximumSize(), dr.ids.qname.wirelength()); } /* when the answer is encrypted in place, we need to get a copy @@ -654,7 +652,7 @@ static void handleResponseForUDPClient(IDState& ids, PacketBuffer& response, uin dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted, true)) { + if (!processResponse(response, respRuleActions, cacheInsertedRespRuleActions, dr, ids.cs && ids.cs->muted)) { if (queryId) { ds->releaseState(*queryId); } @@ -670,7 +668,7 @@ static void handleResponseForUDPClient(IDState& ids, PacketBuffer& response, uin if (ids.cs && !ids.cs->muted) { ComboAddress empty; empty.sin4.sin_family = 0; - sendUDPResponse(ids.cs->udpFD, response, dr.delayMsec, ids.hopLocal, ids.hopRemote); + sendUDPResponse(ids.cs->udpFD, response, dr.ids.delayMsec, ids.hopLocal, ids.hopRemote); muted = false; } @@ -683,10 +681,10 @@ static void handleResponseForUDPClient(IDState& ids, PacketBuffer& response, uin vinfolog("Got answer from %s, NOT relayed to %s (UDP) since that frontend is muted, took %f usec", ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); } - handleResponseSent(ids, udiff, *dr.remote, ds->d_config.remote, response.size(), cleartextDH, ds->getProtocol()); + handleResponseSent(ids, udiff, dr.ids.origRemote, ds->d_config.remote, response.size(), cleartextDH, ds->getProtocol()); } else { - handleResponseSent(ids, 0., *dr.remote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP); + handleResponseSent(ids, 0., dr.ids.origRemote, ComboAddress(), response.size(), cleartextDH, dnsdist::Protocol::DoUDP); } if (queryId) { @@ -746,26 +744,24 @@ void responderThread(std::shared_ptr dss) continue; } - /* read the potential DOHUnit state as soon as possible, but don't use it - until we have confirmed that we own this state by updating usageIndicator */ - auto du = DOHUnitUniquePtr(ids->du, DOHUnit::release); /* setting age to 0 to prevent the maintainer thread from cleaning this IDS while we process the response. */ ids->age = 0; unsigned int qnameWireLength = 0; - if (fd != ids->backendFD || !responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss, qnameWireLength)) { + if (fd != ids->internal.backendFD || !responseContentMatches(response, ids->internal.qname, ids->internal.qtype, ids->internal.qclass, dss, qnameWireLength)) { continue; } + DOHUnitUniquePtr du(nullptr, DOHUnit::release); /* atomically mark the state as available, but only if it has not been altered in the meantime */ if (ids->tryMarkUnused(usageIndicator)) { /* clear the potential DOHUnit asap, it's ours now and since we just marked the state as unused, someone could overwrite it. */ - ids->du = nullptr; + du = std::move(ids->internal.du); /* we only decrement the outstanding counter if the value was not altered in the meantime, which would mean that the state has been actively reused and the other thread has not incremented the outstanding counter, so we don't @@ -779,10 +775,10 @@ void responderThread(std::shared_ptr dss) continue; } - dh->id = ids->origID; + dh->id = ids->internal.origID; ++dss->responses; - double udiff = ids->sentTime.udiff(); + double udiff = ids->internal.sentTime.udiff(); // do that _before_ the processing, otherwise it's not fair to the backend dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff / 128.0; dss->reportResponse(dh->rcode); @@ -791,30 +787,27 @@ void responderThread(std::shared_ptr dss) if (du) { #ifdef HAVE_DNS_OVER_HTTPS // DoH query, we cannot touch du after that - handleUDPResponseForDoH(std::move(du), std::move(response), std::move(*ids)); + handleUDPResponseForDoH(std::move(du), std::move(response), std::move(ids->internal)); #endif dss->releaseState(queryId); continue; } - handleResponseForUDPClient(*ids, response, 0, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId); + handleResponseForUDPClient(ids->internal, response, 0, *localRespRuleActions, *localCacheInsertedRespRuleActions, dss, false, queryId); } } - catch (const std::exception& e){ + catch (const std::exception& e) { vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->d_config.remote.toStringWithPort(), queryId, e.what()); } } } -catch (const std::exception& e) -{ +catch (const std::exception& e) { errlog("UDP responder thread died because of exception: %s", e.what()); } -catch (const PDNSException& e) -{ +catch (const PDNSException& e) { errlog("UDP responder thread died because of PowerDNS exception: %s", e.reason); } -catch (...) -{ +catch (...) { errlog("UDP responder thread died because of an exception: %s", "unknown"); } } @@ -929,7 +922,7 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s case DNSAction::Action::Pool: /* we need to keep this because a custom Lua action can return DNSAction.Spoof, 'poolname' */ - dq.poolname = ruleresult; + dq.ids.poolName = ruleresult; return true; break; case DNSAction::Action::NoRecurse: @@ -938,7 +931,7 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s break; /* non-terminal actions follow */ case DNSAction::Action::Delay: - pdns::checked_stoi_into(dq.delayMsec, ruleresult); // sorry + pdns::checked_stoi_into(dq.ids.delayMsec, ruleresult); // sorry break; case DNSAction::Action::None: /* fall-through */ @@ -954,11 +947,11 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now) { if (g_rings.shouldRecordQueries()) { - g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.getData().size(), *dq.getHeader(), dq.getProtocol()); + g_rings.insertQuery(now, dq.ids.origRemote, dq.ids.qname, dq.ids.qtype, dq.getData().size(), *dq.getHeader(), dq.getProtocol()); } if (g_qcount.enabled) { - string qname = (*dq.qname).toLogString(); + string qname = dq.ids.qname.toLogString(); bool countQuery{true}; if (g_qcount.filter) { auto lock = g_lua.lock(); @@ -976,7 +969,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru #ifndef DISABLE_DYNBLOCKS /* the Dynamic Block mechanism supports address and port ranges, so we need to pass the full address and port */ - if (auto got = holders.dynNMGBlock->lookup(AddressAndPortRange(*dq.remote, dq.remote->isIPv4() ? 32 : 128, 16))) { + if (auto got = holders.dynNMGBlock->lookup(AddressAndPortRange(dq.ids.origRemote, dq.ids.origRemote.isIPv4() ? 32 : 128, 16))) { auto updateBlockStats = [&got]() { ++g_stats.dynBlocked; got->second.blocks++; @@ -993,7 +986,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru break; case DNSAction::Action::Nxdomain: - vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort()); + vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort()); updateBlockStats(); dq.getHeader()->rcode = RCode::NXDomain; @@ -1001,7 +994,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru return true; case DNSAction::Action::Refused: - vinfolog("Query from %s refused because of dynamic block", dq.remote->toStringWithPort()); + vinfolog("Query from %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort()); updateBlockStats(); dq.getHeader()->rcode = RCode::Refused; @@ -1011,7 +1004,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::Truncate: if (!dq.overTCP()) { updateBlockStats(); - vinfolog("Query from %s truncated because of dynamic block", dq.remote->toStringWithPort()); + vinfolog("Query from %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort()); dq.getHeader()->tc = true; dq.getHeader()->qr = true; dq.getHeader()->ra = dq.getHeader()->rd; @@ -1020,23 +1013,23 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru return true; } else { - vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); + vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); } break; case DNSAction::Action::NoRecurse: updateBlockStats(); - vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort()); + vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort()); dq.getHeader()->rd = false; return true; default: updateBlockStats(); - vinfolog("Query from %s dropped because of dynamic block", dq.remote->toStringWithPort()); + vinfolog("Query from %s dropped because of dynamic block", dq.ids.origRemote.toStringWithPort()); return false; } } } - if (auto got = holders.dynSMTBlock->lookup(*dq.qname)) { + if (auto got = holders.dynSMTBlock->lookup(dq.ids.qname)) { auto updateBlockStats = [&got]() { ++g_stats.dynBlocked; got->blocks++; @@ -1052,14 +1045,14 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru /* do nothing */ break; case DNSAction::Action::Nxdomain: - vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); + vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); updateBlockStats(); dq.getHeader()->rcode = RCode::NXDomain; dq.getHeader()->qr=true; return true; case DNSAction::Action::Refused: - vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); + vinfolog("Query from %s for %s refused because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); updateBlockStats(); dq.getHeader()->rcode = RCode::Refused; @@ -1069,7 +1062,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru if (!dq.overTCP()) { updateBlockStats(); - vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); + vinfolog("Query from %s for %s truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); dq.getHeader()->tc = true; dq.getHeader()->qr = true; dq.getHeader()->ra = dq.getHeader()->rd; @@ -1078,17 +1071,17 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru return true; } else { - vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); + vinfolog("Query from %s for %s over TCP *not* truncated because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); } break; case DNSAction::Action::NoRecurse: updateBlockStats(); - vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort()); + vinfolog("Query from %s setting rd=0 because of dynamic block", dq.ids.origRemote.toStringWithPort()); dq.getHeader()->rd = false; return true; default: updateBlockStats(); - vinfolog("Query from %s for %s dropped because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); + vinfolog("Query from %s for %s dropped because of dynamic block", dq.ids.origRemote.toStringWithPort(), dq.ids.qname.toLogString()); return false; } } @@ -1260,22 +1253,16 @@ static void queueResponse(const ClientState& cs, const PacketBuffer& response, c /* self-generated responses or cache hits */ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit) { - DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, dq.getMutableData(), dq.protocol, dq.queryTime); - - dr.uniqueId = dq.uniqueId; - dr.qTag = std::move(dq.qTag); - dr.delayMsec = dq.delayMsec; + std::shared_ptr ds{nullptr}; + DNSResponse dr(dq.ids, dq.getMutableData(), dq.queryTime, ds); if (!applyRulesToResponse(cacheHit ? *holders.cacheHitRespRuleactions : *holders.selfAnsweredRespRuleactions, dr)) { return false; } - /* in case a rule changed it */ - dq.delayMsec = dr.delayMsec; - #ifdef HAVE_DNSCRYPT if (!cs.muted) { - if (!encryptResponse(dq.getMutableData(), dq.getMaximumSize(), dq.overTCP(), dq.dnsCryptQuery)) { + if (!encryptResponse(dq.getMutableData(), dq.getMaximumSize(), dq.overTCP(), dq.ids.dnsCryptQuery)) { return false; } } @@ -1304,60 +1291,69 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } if (dq.getHeader()->qr) { // something turned it into a response - fixUpQueryTurnedResponse(dq, dq.origFlags); + fixUpQueryTurnedResponse(dq, dq.ids.origFlags); - if (!prepareOutgoingResponse(holders, cs, dq, false)) { + if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, false)) { return ProcessQueryResult::Drop; } ++g_stats.selfAnswered; - ++cs.responses; + ++dq.ids.cs->responses; return ProcessQueryResult::SendAnswer; } - std::shared_ptr serverPool = getPool(*holders.pools, dq.poolname); + std::shared_ptr serverPool = getPool(*holders.pools, dq.ids.poolName); std::shared_ptr poolPolicy = serverPool->policy; - dq.packetCache = serverPool->packetCache; + dq.ids.packetCache = serverPool->packetCache; const auto& policy = poolPolicy != nullptr ? *poolPolicy : *(holders.policy); const auto servers = serverPool->getServers(); selectedBackend = policy.getSelectedBackend(*servers, dq); uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL; - if (dq.packetCache && !dq.skipCache) { - dq.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO); + if (dq.ids.packetCache && !dq.ids.skipCache) { + dq.ids.dnssecOK = (getEDNSZ(dq) & EDNS_HEADER_FLAG_DO); } if (dq.useECS && ((selectedBackend && selectedBackend->d_config.useECS) || (!selectedBackend && serverPool->getECS()))) { // we special case our cache in case a downstream explicitly gave us a universally valid response with a 0 scope // we need ECS parsing (parseECS) to be true so we can be sure that the initial incoming query did not have an existing // ECS option, which would make it unsuitable for the zero-scope feature. - if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->d_config.disableZeroScope) && dq.packetCache->isECSParsingEnabled()) { - if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, !dq.overTCP(), allowExpired)) { + if (dq.ids.packetCache && !dq.ids.skipCache && (!selectedBackend || !selectedBackend->d_config.disableZeroScope) && dq.ids.packetCache->isECSParsingEnabled()) { + if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKeyNoECS, dq.ids.subnet, dq.ids.dnssecOK, !dq.overTCP(), allowExpired)) { - if (!prepareOutgoingResponse(holders, cs, dq, true)) { + vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); + + if (!prepareOutgoingResponse(holders, *dq.ids.cs, dq, true)) { return ProcessQueryResult::Drop; } return ProcessQueryResult::SendAnswer; } - if (!dq.subnet) { + if (!dq.ids.subnet) { /* there was no existing ECS on the query, enable the zero-scope feature */ - dq.useZeroScope = true; + dq.ids.useZeroScope = true; } } - if (!handleEDNSClientSubnet(dq, dq.ednsAdded, dq.ecsAdded)) { - vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort()); + if (!handleEDNSClientSubnet(dq, dq.ids.ednsAdded, dq.ids.ecsAdded)) { + vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.ids.origRemote.toStringWithPort()); return ProcessQueryResult::Drop; } } - if (dq.packetCache && !dq.skipCache) { - if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, !dq.overTCP(), allowExpired)) { + if (dq.ids.packetCache && !dq.ids.skipCache) { + bool forwardedOverUDP = !dq.overTCP(); + if (selectedBackend && selectedBackend->isTCPOnly()) { + forwardedOverUDP = false; + } - restoreFlags(dq.getHeader(), dq.origFlags); + if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKey, dq.ids.subnet, dq.ids.dnssecOK, forwardedOverUDP, allowExpired)) { + + restoreFlags(dq.getHeader(), dq.ids.origFlags); + + vinfolog("Packet cache hit for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); if (!prepareOutgoingResponse(holders, cs, dq, true)) { return ProcessQueryResult::Drop; @@ -1365,10 +1361,9 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& return ProcessQueryResult::SendAnswer; } - else if (dq.protocol == dnsdist::Protocol::DoH) { + else if (dq.ids.protocol == dnsdist::Protocol::DoH && !forwardedOverUDP) { /* do a second-lookup for UDP responses, but we do not want TC=1 answers */ - PacketBuffer initialQuery(dq.getData()); - if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyUDP, dq.subnet, dq.dnssecOK, true, allowExpired, false)) { + if (dq.ids.packetCache->get(dq, dq.getHeader()->id, &dq.ids.cacheKeyUDP, dq.ids.subnet, dq.ids.dnssecOK, true, allowExpired, false)) { if (!prepareOutgoingResponse(holders, cs, dq, true)) { return ProcessQueryResult::Drop; } @@ -1377,18 +1372,20 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } } + vinfolog("Packet cache miss for query for %s|%s from %s (%s, %d bytes)", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort(), dq.ids.protocol.toString(), dq.getData().size()); + ++g_stats.cacheMisses; } if (!selectedBackend) { ++g_stats.noPolicy; - vinfolog("%s query for %s|%s from %s, no downstream server available", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toLogString(), QType(dq.qtype).toString(), dq.remote->toStringWithPort()); + vinfolog("%s query for %s|%s from %s, no downstream server available", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.ids.qname.toLogString(), QType(dq.ids.qtype).toString(), dq.ids.origRemote.toStringWithPort()); if (g_servFailOnNoPolicy) { dq.getHeader()->rcode = RCode::ServFail; dq.getHeader()->qr = true; - fixUpQueryTurnedResponse(dq, dq.origFlags); + fixUpQueryTurnedResponse(dq, dq.ids.origFlags); if (!prepareOutgoingResponse(holders, cs, dq, false)) { return ProcessQueryResult::Drop; @@ -1401,7 +1398,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } /* save the DNS flags as sent to the backend so we can cache the answer with the right flags later */ - dq.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader()); + dq.ids.cacheFlags = *getFlagsFromDNSHeader(dq.getHeader()); if (dq.addXPF && selectedBackend->d_config.xpfRRCode != 0) { addXPF(dq, selectedBackend->d_config.xpfRRCode); @@ -1411,7 +1408,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& return ProcessQueryResult::PassToBackend; } catch (const std::exception& e){ - vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what()); + vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.overTCP() ? "TCP" : "UDP"), dq.ids.origRemote.toStringWithPort(), queryId, e.what()); } return ProcessQueryResult::Drop; } @@ -1448,11 +1445,6 @@ public: static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - double udiff = ids.sentTime.udiff(); - if (d_ds && !response.d_selfGenerated) { - vinfolog("Got answer from %s, relayed to %s (UDP), took %f usec", d_ds->d_config.remote.toStringWithPort(), ids.origRemote.toStringWithPort(), udiff); - } - handleResponseForUDPClient(ids, response.d_buffer, d_payloadSize, *localRespRuleActions, *localCacheInsertedRespRuleActions, d_ds, response.d_selfGenerated, std::nullopt); } @@ -1461,7 +1453,7 @@ public: return handleResponse(now, std::move(response)); } - void notifyIOError(IDState&& query, const struct timeval& now) override + void notifyIOError(InternalQueryState&& query, const struct timeval& now) override { // nothing to do } @@ -1474,7 +1466,7 @@ private: class UDPCrossProtocolQuery : public CrossProtocolQuery { public: - UDPCrossProtocolQuery(PacketBuffer&& buffer, IDState&& ids, std::shared_ptr& ds): d_cs(*ids.cs) + UDPCrossProtocolQuery(PacketBuffer&& buffer, InternalQueryState&& ids, std::shared_ptr& ds): d_cs(*ids.cs) { uint16_t z = 0; getEDNSUDPPayloadSizeAndZ(reinterpret_cast(buffer.data()), buffer.size(), &d_payloadSize, &z); @@ -1501,12 +1493,86 @@ private: uint16_t d_payloadSize{0}; }; +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest) +{ + bool doh = dq.ids.du != nullptr; + unsigned int idOffset = 0; + int64_t generation; + IDState* ids = ds->getIDState(idOffset, generation); + + dq.getHeader()->id = idOffset; + + bool failed = false; + if (ds->d_config.useProxyProtocol) { + try { + size_t payloadSize = 0; + if (addProxyProtocol(dq, &payloadSize)) { + if (dq.ids.du) { + dq.ids.du->proxyProtocolPayloadSize = payloadSize; + } + } + } + catch (const std::exception& e) { + vinfolog("Adding proxy protocol payload to %squery from %s failed: %s", (dq.ids.du ? "DoH" : ""), dq.ids.origDest.toStringWithPort(), e.what()); + failed = true; + } + } + + try { + if (!failed) { + int fd = ds->pickSocketForSending(); + dq.ids.backendFD = fd; + dq.ids.origID = queryID; + dq.ids.forwardedOverUDP = true; + ids->internal = std::move(dq.ids); + + vinfolog("Got query for %s|%s from %s%s, relayed to %s", ids->internal.qname.toLogString(), QType(ids->internal.qtype).toString(), ids->internal.origRemote.toStringWithPort(), (doh ? " (https)" : ""), ds->getNameWithAddr()); + /* you can't touch du after this line, unless the call returned a non-negative value, + because it might already have been freed */ + ssize_t ret = udpClientSendRequestToBackend(ds, fd, query); + + if (ret < 0) { + failed = true; + } + } + else { + ids->internal = std::move(dq.ids); + } + + if (failed) { + /* we are about to handle the error, make sure that + this pointer is not accessed when the state is cleaned, + but first check that it still belongs to us */ + if (ids->tryMarkUnused(generation) && ids->internal.du) { + dq.ids.du = std::move(ids->internal.du); + --ds->outstanding; + } + if (dq.ids.du) { + dq.ids.du->status_code = 502; + } + ++g_stats.downstreamSendErrors; + ++ds->sendErrors; + return false; + } + } + catch (const std::exception& e) { + throw; + } + + return true; +} + static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, PacketBuffer& query, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, cmsgbuf_aligned* respCBuf) { assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr)); uint16_t queryId = 0; - ComboAddress proxiedRemote = remote; - ComboAddress proxiedDestination = dest; + InternalQueryState ids; + ids.cs = &cs; + ids.origRemote = remote; + ids.hopRemote = remote; + ids.origDest = dest; + ids.hopLocal = dest; + ids.protocol = dnsdist::Protocol::DoUDP; try { bool expectProxyProtocol = false; @@ -1514,10 +1580,10 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } /* dest might have been updated, if we managed to harvest the destination address */ - proxiedDestination = dest; + ids.origDest = dest; std::vector proxyProtocolValues; - if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, proxiedRemote, proxiedDestination, proxyProtocolValues)) { + if (expectProxyProtocol && !handleProxyProtocol(remote, false, *holders.acl, query, ids.origRemote, ids.origDest, proxyProtocolValues)) { return; } @@ -1527,8 +1593,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct struct timespec queryRealTime; gettime(&queryRealTime, true); - std::unique_ptr dnsCryptQuery = nullptr; - auto dnsCryptResponse = checkDNSCryptQuery(cs, query, dnsCryptQuery, queryRealTime.tv_sec, false); + auto dnsCryptResponse = checkDNSCryptQuery(cs, query, ids.dnsCryptQuery, queryRealTime.tv_sec, false); if (dnsCryptResponse) { sendUDPResponse(cs.udpFD, query, 0, dest, remote); return; @@ -1551,16 +1616,21 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } } - uint16_t qtype, qclass; - unsigned int qnameWireLength = 0; - DNSName qname(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); - DNSQuestion dq(&qname, qtype, qclass, proxiedDestination.sin4.sin_family != 0 ? &proxiedDestination : &cs.local, &proxiedRemote, query, dnsCryptQuery ? dnsdist::Protocol::DNSCryptUDP : dnsdist::Protocol::DoUDP, &queryRealTime); - dq.dnsCryptQuery = std::move(dnsCryptQuery); + ids.qname = DNSName(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + if (ids.origDest.sin4.sin_family == 0) { + ids.origDest = cs.local; + } + if (ids.dnsCryptQuery) { + ids.protocol = dnsdist::Protocol::DNSCryptUDP; + } + DNSQuestion dq(ids, query, queryRealTime); + const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); + ids.origFlags = *flags; + if (!proxyProtocolValues.empty()) { dq.proxyProtocolValues = make_unique>(std::move(proxyProtocolValues)); } - dq.hopRemote = &remote; - dq.hopLocal = &dest; + std::shared_ptr ss{nullptr}; auto result = processQuery(dq, cs, holders, ss); @@ -1573,7 +1643,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct if (result == ProcessQueryResult::SendAnswer) { #ifndef DISABLE_RECVMMSG #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) - if (dq.delayMsec == 0 && responsesVect != nullptr) { + if (dq.ids.delayMsec == 0 && responsesVect != nullptr) { queueResponse(cs, query, dest, remote, responsesVect[*queuedResponses], respIOV, respCBuf); (*queuedResponses)++; return; @@ -1581,9 +1651,9 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ #endif /* DISABLE_RECVMMSG */ /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */ - sendUDPResponse(cs.udpFD, query, dq.delayMsec, dest, remote); + sendUDPResponse(cs.udpFD, query, dq.ids.delayMsec, dest, remote); - handleResponseSent(qname, qtype, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP, dnsdist::Protocol::DoUDP); + handleResponseSent(ids, 0., remote, ComboAddress(), query.size(), *dh, dnsdist::Protocol::DoUDP); return; } @@ -1599,16 +1669,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct proxyProtocolPayload = getProxyProtocolPayload(dq); } - IDState ids; - ids.cs = &cs; ids.origID = dh->id; - setIDStateFromDNSQuestion(ids, dq, std::move(qname)); - if (dest.sin4.sin_family != 0) { - ids.origDest = dest; - } - else { - ids.origDest = cs.local; - } auto cpq = std::make_unique(std::move(query), std::move(ids), ss); cpq->query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); @@ -1616,41 +1677,10 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } - unsigned int idOffset = 0; - int64_t generation; - IDState* ids = ss->getIDState(idOffset, generation); - - ids->cs = &cs; - ids->origID = dh->id; - setIDStateFromDNSQuestion(*ids, dq, std::move(qname)); - - if (dest.sin4.sin_family != 0) { - ids->origDest = dest; - } - else { - ids->origDest = cs.local; - } - - dh = dq.getHeader(); - dh->id = idOffset; - - if (ss->d_config.useProxyProtocol) { - addProxyProtocol(dq); - } - - int fd = ss->pickSocketForSending(); - ids->backendFD = fd; - ssize_t ret = udpClientSendRequestToBackend(ss, fd, query); - - if(ret < 0) { - ++ss->sendErrors; - ++g_stats.downstreamSendErrors; - } - - vinfolog("Got query for %s|%s from %s, relayed to %s", ids->qname.toLogString(), QType(ids->qtype).toString(), proxiedRemote.toStringWithPort(), ss->getNameWithAddr()); + assignOutgoingUDPQueryToBackend(ss, dh->id, dq, std::move(query), dest); } catch(const std::exception& e){ - vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", proxiedRemote.toStringWithPort(), queryId, e.what()); + vinfolog("Got an error in UDP question thread while parsing a query from %s, id %d: %s", ids.origRemote.toStringWithPort(), queryId, e.what()); } } diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index f93ee063a6..203f1b2a70 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -37,6 +37,7 @@ #include "dnscrypt.hh" #include "dnsdist-cache.hh" #include "dnsdist-dynbpf.hh" +#include "dnsdist-idstate.hh" #include "dnsdist-lbpolicies.hh" #include "dnsdist-protocols.hh" #include "dnsname.hh" @@ -62,10 +63,8 @@ using QTag = std::unordered_map; struct DNSQuestion { - DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, dnsdist::Protocol proto, const struct timespec* queryTime_): - data(data_), qname(name), local(lc), remote(rem), queryTime(queryTime_), tempFailureTTL(boost::none), qtype(type), qclass(class_), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), protocol(proto), ecsOverride(g_ECSOverride) { - const uint16_t* flags = getFlagsFromDNSHeader(getHeader()); - origFlags = *flags; + DNSQuestion(InternalQueryState& ids_, PacketBuffer& data_, const struct timespec& queryTime_): + data(data_), ids(ids_), queryTime(queryTime_), ecsPrefixLength(ids.origRemote.sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), ecsOverride(g_ECSOverride) { } DNSQuestion(const DNSQuestion&) = delete; DNSQuestion& operator=(const DNSQuestion&) = delete; @@ -113,83 +112,54 @@ struct DNSQuestion dnsdist::Protocol getProtocol() const { - return protocol; + return ids.protocol; } bool overTCP() const { - return !(protocol == dnsdist::Protocol::DoUDP || protocol == dnsdist::Protocol::DNSCryptUDP); + return !(ids.protocol == dnsdist::Protocol::DoUDP || ids.protocol == dnsdist::Protocol::DNSCryptUDP); } void setTag(std::string&& key, std::string&& value) { - if (!qTag) { - qTag = std::make_unique(); + if (!ids.qTag) { + ids.qTag = std::make_unique(); } - qTag->insert_or_assign(std::move(key), std::move(value)); + ids.qTag->insert_or_assign(std::move(key), std::move(value)); } void setTag(const std::string& key, const std::string& value) { - if (!qTag) { - qTag = std::make_unique(); + if (!ids.qTag) { + ids.qTag = std::make_unique(); } - qTag->insert_or_assign(key, value); + ids.qTag->insert_or_assign(key, value); } protected: PacketBuffer& data; public: - boost::optional uniqueId; - Netmask ecs; - boost::optional subnet; + InternalQueryState& ids; + std::unique_ptr ecs{nullptr}; std::string sni; /* Server Name Indication, if any (DoT or DoH) */ - std::string poolname; mutable std::shared_ptr > ednsOptions; - std::shared_ptr packetCache{nullptr}; - const DNSName* qname{nullptr}; - const ComboAddress* local{nullptr}; - const ComboAddress* remote{nullptr}; - /* this is the address dnsdist received the packet on, - which might not match local when support for incoming proxy protocol - is enabled */ - const ComboAddress* hopLocal{nullptr}; /* the address dnsdist received the packet from, see above */ - const ComboAddress* hopRemote{nullptr}; - std::unique_ptr qTag{nullptr}; std::unique_ptr> proxyProtocolValues{nullptr}; - std::unique_ptr dnsCryptQuery{nullptr}; - const struct timespec* queryTime{nullptr}; - struct DOHUnit* du{nullptr}; - int delayMsec{0}; - boost::optional tempFailureTTL; - uint32_t cacheKeyNoECS{0}; - uint32_t cacheKey{0}; - /* for DoH */ - uint32_t cacheKeyUDP{0}; - const uint16_t qtype; - const uint16_t qclass; + const struct timespec& queryTime; uint16_t ecsPrefixLength; - uint16_t origFlags; - uint16_t cacheFlags{0}; /* DNS flags as sent to the backend */ - const dnsdist::Protocol protocol; uint8_t ednsRCode{0}; - bool skipCache{false}; bool ecsOverride; bool useECS{true}; bool addXPF{true}; - bool ecsSet{false}; - bool ecsAdded{false}; - bool ednsAdded{false}; - bool useZeroScope{false}; - bool dnssecOK{false}; }; struct DNSResponse : DNSQuestion { - DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, PacketBuffer& data_, dnsdist::Protocol proto, const struct timespec* queryTime_): - DNSQuestion(name, type, class_, lc, rem, data_, proto, queryTime_) { } + DNSResponse(InternalQueryState& ids_, PacketBuffer& data_, const struct timespec& queryTime_, const std::shared_ptr& downstream): + DNSQuestion(ids_, data_, queryTime_), d_downstream(downstream) { } DNSResponse(const DNSResponse&) = delete; DNSResponse& operator=(const DNSResponse&) = delete; DNSResponse(DNSResponse&&) = default; + + const std::shared_ptr& d_downstream; }; /* so what could you do: @@ -457,9 +427,6 @@ struct DNSDistStats }; extern struct DNSDistStats g_stats; -void doLatencyStats(dnsdist::Protocol protocol, double udiff); - -#include "dnsdist-idstate.hh" class BasicQPSLimiter { @@ -1212,7 +1179,7 @@ bool getLuaNoSideEffect(); // set if there were only explicit declarations of _n void resetLuaSideEffect(); // reset to indeterminate state bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const std::shared_ptr& remote, unsigned int& qnameWireLength); -bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& insertedRespRuleActions, DNSResponse& dr, bool muted, bool receivedOverUDP); +bool processResponse(PacketBuffer& response, const std::vector& respRuleActions, const std::vector& insertedRespRuleActions, DNSResponse& dr, bool muted); bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop); bool checkQueryHeaders(const struct dnsheader* dh, ClientState& cs); @@ -1235,9 +1202,8 @@ static const size_t s_maxPacketCacheEntrySize{4096}; // don't cache responses la enum class ProcessQueryResult : uint8_t { Drop, SendAnswer, PassToBackend }; ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend); -DNSResponse makeDNSResponseFromIDState(IDState& ids, PacketBuffer& data); -void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname); +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest); ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const PacketBuffer& request, bool healthCheck = false); void handleResponseSent(const DNSName& qname, const QType& qtype, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol, dnsdist::Protocol incomingProtocol); -void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol); +void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol outgoingProtocol); diff --git a/pdns/dnsdistdist/Makefile.am b/pdns/dnsdistdist/Makefile.am index c3ed1f1580..ef7ac181a6 100644 --- a/pdns/dnsdistdist/Makefile.am +++ b/pdns/dnsdistdist/Makefile.am @@ -146,7 +146,7 @@ dnsdist_SOURCES = \ dnsdist-dynbpf.cc dnsdist-dynbpf.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ dnsdist-healthchecks.cc dnsdist-healthchecks.hh \ - dnsdist-idstate.cc dnsdist-idstate.hh \ + dnsdist-idstate.hh \ dnsdist-kvs.hh dnsdist-kvs.cc \ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \ dnsdist-lua-actions.cc \ @@ -249,7 +249,7 @@ testrunner_SOURCES = \ dnsdist-dynblocks.cc dnsdist-dynblocks.hh \ dnsdist-dynbpf.cc dnsdist-dynbpf.hh \ dnsdist-ecs.cc dnsdist-ecs.hh \ - dnsdist-idstate.cc dnsdist-idstate.hh \ + dnsdist-idstate.hh \ dnsdist-kvs.cc dnsdist-kvs.hh \ dnsdist-lbpolicies.cc dnsdist-lbpolicies.hh \ dnsdist-lua-bindings-dnsquestion.cc \ diff --git a/pdns/dnsdistdist/dnsdist-backend.cc b/pdns/dnsdistdist/dnsdist-backend.cc index 79fd3dc7a7..daff23d6aa 100644 --- a/pdns/dnsdistdist/dnsdist-backend.cc +++ b/pdns/dnsdistdist/dnsdist-backend.cc @@ -333,18 +333,14 @@ void DownstreamState::handleUDPTimeout(IDState& ids) to limit the risk of racing with the responder thread. */ - auto oldDU = ids.du; - - ids.du = nullptr; - handleDOHTimeout(DOHUnitUniquePtr(oldDU, DOHUnit::release)); - oldDU = nullptr; ids.age = 0; + handleDOHTimeout(std::move(ids.internal.du)); reuseds++; --outstanding; ++g_stats.downstreamTimeouts; // this is an 'actively' discovered timeout vinfolog("Had a downstream timeout from %s (%s) for query for %s|%s from %s", d_config.remote.toStringWithPort(), getName(), - ids.qname.toLogString(), QType(ids.qtype).toString(), ids.origRemote.toStringWithPort()); + ids.internal.qname.toLogString(), QType(ids.internal.qtype).toString(), ids.internal.origRemote.toStringWithPort()); if (g_rings.shouldRecordResponses()) { struct timespec ts; @@ -352,9 +348,9 @@ void DownstreamState::handleUDPTimeout(IDState& ids) struct dnsheader fake; memset(&fake, 0, sizeof(fake)); - fake.id = ids.origID; + fake.id = ids.internal.origID; - g_rings.insertResponse(ts, ids.origRemote, ids.qname, ids.qtype, std::numeric_limits::max(), 0, fake, d_config.remote, getProtocol()); + g_rings.insertResponse(ts, ids.internal.origRemote, ids.internal.qname, ids.internal.qtype, std::numeric_limits::max(), 0, fake, d_config.remote, getProtocol()); } reportTimeoutOrError(); @@ -476,13 +472,6 @@ IDState* DownstreamState::getIDState(unsigned int& selectedID, int64_t& generati ids->age = 0; - /* that means that the state was in use, possibly with an allocated - DOHUnit that we will need to handle, but we can't touch it before - confirming that we now own this state */ - if (ids->isInUse()) { - du = DOHUnitUniquePtr(ids->du, DOHUnit::release); - } - /* we atomically replace the value, we now own this state */ generation = ids->generation++; if (!ids->markAsUsed(generation)) { @@ -494,10 +483,10 @@ IDState* DownstreamState::getIDState(unsigned int& selectedID, int64_t& generati else { /* we are reusing a state, no change in outstanding but if there was an existing DOHUnit we need to handle it because it's about to be overwritten. */ - ids->du = nullptr; + auto oldDU = std::move(ids->internal.du); ++reuseds; ++g_stats.downstreamTimeouts; - handleDOHTimeout(std::move(du)); + handleDOHTimeout(std::move(oldDU)); } return ids; diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index 1815d8bf1b..a6328f168c 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -156,7 +156,7 @@ public: throw std::runtime_error("Unexpected XFR reponse to a health check query"); } - void notifyIOError(IDState&& query, const struct timeval& now) override + void notifyIOError(InternalQueryState&& query, const struct timeval& now) override { d_data->d_ds->submitHealthCheckResult(d_data->d_initial, false); } @@ -359,7 +359,7 @@ bool queueHealthCheck(std::unique_ptr& mplexer, const std::shared mplexer->addReadFD(data->d_udpSocket.getHandle(), &healthCheckUDPCallback, data, &data->d_ttd); } else if (ds->isDoH()) { - InternalQuery query(std::move(packet), IDState()); + InternalQuery query(std::move(packet), InternalQueryState()); query.d_proxyProtocolPayload = std::move(proxyProtocolPayload); auto sender = std::shared_ptr(new HealthCheckQuerySender(data)); if (!sendH2Query(ds, mplexer, sender, std::move(query), true)) { diff --git a/pdns/dnsdistdist/dnsdist-idstate.cc b/pdns/dnsdistdist/dnsdist-idstate.cc deleted file mode 100644 index 286808c765..0000000000 --- a/pdns/dnsdistdist/dnsdist-idstate.cc +++ /dev/null @@ -1,75 +0,0 @@ - -#include "dnsdist.hh" - -DNSResponse makeDNSResponseFromIDState(IDState& ids, PacketBuffer& data) -{ - DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, &ids.origDest, &ids.origRemote, data, ids.protocol, &ids.sentTime.d_start); - dr.origFlags = ids.origFlags; - dr.cacheFlags = ids.cacheFlags; - dr.ecsAdded = ids.ecsAdded; - dr.ednsAdded = ids.ednsAdded; - dr.useZeroScope = ids.useZeroScope; - dr.packetCache = std::move(ids.packetCache); - dr.delayMsec = ids.delayMsec; - dr.skipCache = ids.skipCache; - dr.cacheKey = ids.cacheKey; - dr.cacheKeyNoECS = ids.cacheKeyNoECS; - dr.cacheKeyUDP = ids.cacheKeyUDP; - dr.dnssecOK = ids.dnssecOK; - dr.tempFailureTTL = ids.tempFailureTTL; - dr.qTag = std::move(ids.qTag); - dr.subnet = std::move(ids.subnet); - dr.uniqueId = std::move(ids.uniqueId); - - if (ids.dnsCryptQuery) { - dr.dnsCryptQuery = std::move(ids.dnsCryptQuery); - } - - dr.hopRemote = &ids.hopRemote; - dr.hopLocal = &ids.hopLocal; - - return dr; -} - -void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname) -{ - ids.origRemote = *dq.remote; - ids.origDest = *dq.local; - ids.sentTime.set(*dq.queryTime); - ids.qname = std::move(qname); - ids.qtype = dq.qtype; - ids.qclass = dq.qclass; - ids.protocol = dq.protocol; - ids.delayMsec = dq.delayMsec; - ids.tempFailureTTL = dq.tempFailureTTL; - ids.origFlags = dq.origFlags; - ids.cacheFlags = dq.cacheFlags; - ids.cacheKey = dq.cacheKey; - ids.cacheKeyNoECS = dq.cacheKeyNoECS; - ids.cacheKeyUDP = dq.cacheKeyUDP; - ids.subnet = dq.subnet; - ids.skipCache = dq.skipCache; - ids.packetCache = dq.packetCache; - ids.ednsAdded = dq.ednsAdded; - ids.ecsAdded = dq.ecsAdded; - ids.useZeroScope = dq.useZeroScope; - ids.qTag = std::move(dq.qTag); - ids.dnssecOK = dq.dnssecOK; - ids.uniqueId = std::move(dq.uniqueId); - - if (dq.hopRemote) { - ids.hopRemote = *dq.hopRemote; - } - else { - ids.hopRemote.sin4.sin_family = 0; - } - - if (dq.hopLocal) { - ids.hopLocal = *dq.hopLocal; - } - else { - ids.hopLocal.sin4.sin_family = 0; - } - - ids.dnsCryptQuery = std::move(dq.dnsCryptQuery); -} diff --git a/pdns/dnsdistdist/dnsdist-kvs.hh b/pdns/dnsdistdist/dnsdist-kvs.hh index 9f9392cdc1..764a45c35c 100644 --- a/pdns/dnsdistdist/dnsdist-kvs.hh +++ b/pdns/dnsdistdist/dnsdist-kvs.hh @@ -44,7 +44,7 @@ public: std::vector getKeys(const DNSQuestion& dq) override { - return getKeys(*dq.remote); + return getKeys(dq.ids.origRemote); } std::string toString() const override @@ -75,7 +75,7 @@ public: std::vector getKeys(const DNSQuestion& dq) override { - return getKeys(*dq.qname); + return getKeys(dq.ids.qname); } std::string toString() const override @@ -101,7 +101,7 @@ public: std::vector getKeys(const DNSQuestion& dq) override { - return getKeys(*dq.qname); + return getKeys(dq.ids.qname); } std::string toString() const override @@ -126,9 +126,9 @@ public: std::vector getKeys(const DNSQuestion& dq) override { - if (dq.qTag) { - const auto& it = dq.qTag->find(d_tag); - if (it != dq.qTag->end()) { + if (dq.ids.qTag) { + const auto& it = dq.ids.qTag->find(d_tag); + if (it != dq.ids.qTag->end()) { return { it->second }; } } diff --git a/pdns/dnsdistdist/dnsdist-lbpolicies.cc b/pdns/dnsdistdist/dnsdist-lbpolicies.cc index 1e80eb756a..cec6769dd0 100644 --- a/pdns/dnsdistdist/dnsdist-lbpolicies.cc +++ b/pdns/dnsdistdist/dnsdist-lbpolicies.cc @@ -166,7 +166,7 @@ shared_ptr whashedFromHash(const ServerPolicy::NumberedServerVe shared_ptr whashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { - return whashedFromHash(servers, dq->qname->hash(g_hashperturb)); + return whashedFromHash(servers, dq->ids.qname.hash(g_hashperturb)); } shared_ptr chashedFromHash(const ServerPolicy::NumberedServerVector& servers, size_t qhash) @@ -228,7 +228,7 @@ shared_ptr chashedFromHash(const ServerPolicy::NumberedServerVe shared_ptr chashed(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) { - return chashedFromHash(servers, dq->qname->hash(g_hashperturb)); + return chashedFromHash(servers, dq->ids.qname.hash(g_hashperturb)); } shared_ptr roundrobin(const ServerPolicy::NumberedServerVector& servers, const DNSQuestion* dq) diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 9ff6a2d429..d27c5dd6cd 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -31,12 +31,12 @@ uint16_t dnsdist_ffi_dnsquestion_get_qtype(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->qtype; + return dq->dq->ids.qtype; } uint16_t dnsdist_ffi_dnsquestion_get_qclass(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->qclass; + return dq->dq->ids.qclass; } static void dnsdist_ffi_comboaddress_to_raw(const ComboAddress& ca, const void** addr, size_t* addrSize) @@ -53,12 +53,12 @@ static void dnsdist_ffi_comboaddress_to_raw(const ComboAddress& ca, const void** void dnsdist_ffi_dnsquestion_get_localaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) { - dnsdist_ffi_comboaddress_to_raw(*dq->dq->local, addr, addrSize); + dnsdist_ffi_comboaddress_to_raw(dq->dq->ids.origDest, addr, addrSize); } void dnsdist_ffi_dnsquestion_get_remoteaddr(const dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize) { - dnsdist_ffi_comboaddress_to_raw(*dq->dq->remote, addr, addrSize); + dnsdist_ffi_comboaddress_to_raw(dq->dq->ids.origRemote, addr, addrSize); } size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, void* buffer, size_t bufferSize) @@ -67,7 +67,7 @@ size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, return 0; } - auto ret = dnsdist::MacAddressesCache::get(*dq->dq->remote, reinterpret_cast(buffer), bufferSize); + auto ret = dnsdist::MacAddressesCache::get(dq->dq->ids.origRemote, reinterpret_cast(buffer), bufferSize); if (ret != 0) { return 0; } @@ -77,30 +77,30 @@ size_t dnsdist_ffi_dnsquestion_get_mac_addr(const dnsdist_ffi_dnsquestion_t* dq, void dnsdist_ffi_dnsquestion_get_masked_remoteaddr(dnsdist_ffi_dnsquestion_t* dq, const void** addr, size_t* addrSize, uint8_t bits) { - dq->maskedRemote = Netmask(*dq->dq->remote, bits).getMaskedNetwork(); + dq->maskedRemote = Netmask(dq->dq->ids.origRemote, bits).getMaskedNetwork(); dnsdist_ffi_comboaddress_to_raw(dq->maskedRemote, addr, addrSize); } uint16_t dnsdist_ffi_dnsquestion_get_local_port(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->local->getPort(); + return dq->dq->ids.origDest.getPort(); } uint16_t dnsdist_ffi_dnsquestion_get_remote_port(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->remote->getPort(); + return dq->dq->ids.origRemote.getPort(); } void dnsdist_ffi_dnsquestion_get_qname_raw(const dnsdist_ffi_dnsquestion_t* dq, const char** qname, size_t* qnameSize) { - const auto& storage = dq->dq->qname->getStorage(); + const auto& storage = dq->dq->ids.qname.getStorage(); *qname = storage.data(); *qnameSize = storage.size(); } size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* dq, size_t init) { - return dq->dq->qname->hash(init); + return dq->dq->ids.qname.hash(init); } int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) @@ -172,7 +172,7 @@ dnsdist_ffi_protocol_type dnsdist_ffi_dnsquestion_get_protocol(const dnsdist_ffi bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->skipCache; + return dq->dq->ids.skipCache; } bool dnsdist_ffi_dnsquestion_get_use_ecs(const dnsdist_ffi_dnsquestion_t* dq) @@ -197,13 +197,13 @@ uint16_t dnsdist_ffi_dnsquestion_get_ecs_prefix_length(const dnsdist_ffi_dnsques bool dnsdist_ffi_dnsquestion_is_temp_failure_ttl_set(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->tempFailureTTL != boost::none; + return dq->dq->ids.tempFailureTTL != boost::none; } uint32_t dnsdist_ffi_dnsquestion_get_temp_failure_ttl(const dnsdist_ffi_dnsquestion_t* dq) { - if (dq->dq->tempFailureTTL) { - return *dq->dq->tempFailureTTL; + if (dq->dq->ids.tempFailureTTL) { + return *dq->dq->ids.tempFailureTTL; } return 0; } @@ -223,9 +223,9 @@ const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, { const char * result = nullptr; - if (dq->dq->qTag != nullptr) { - const auto it = dq->dq->qTag->find(label); - if (it != dq->dq->qTag->cend()) { + if (dq->dq->ids.qTag != nullptr) { + const auto it = dq->dq->ids.qTag->find(label); + if (it != dq->dq->ids.qTag->cend()) { result = it->second.c_str(); } } @@ -236,11 +236,11 @@ const char* dnsdist_ffi_dnsquestion_get_tag(const dnsdist_ffi_dnsquestion_t* dq, const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpPath) { - if (dq->dq->du == nullptr) { + if (dq->dq->ids.du == nullptr) { return nullptr; } #ifdef HAVE_DNS_OVER_HTTPS - dq->httpPath = dq->dq->du->getHTTPPath(); + dq->httpPath = dq->dq->ids.du->getHTTPPath(); #endif /* HAVE_DNS_OVER_HTTPS */ } if (dq->httpPath) { @@ -252,11 +252,11 @@ const char* dnsdist_ffi_dnsquestion_get_http_path(dnsdist_ffi_dnsquestion_t* dq) const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpQueryString) { - if (dq->dq->du == nullptr) { + if (dq->dq->ids.du == nullptr) { return nullptr; } #ifdef HAVE_DNS_OVER_HTTPS - dq->httpQueryString = dq->dq->du->getHTTPQueryString(); + dq->httpQueryString = dq->dq->ids.du->getHTTPQueryString(); #endif /* HAVE_DNS_OVER_HTTPS */ } if (dq->httpQueryString) { @@ -268,11 +268,11 @@ const char* dnsdist_ffi_dnsquestion_get_http_query_string(dnsdist_ffi_dnsquestio const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpHost) { - if (dq->dq->du == nullptr) { + if (dq->dq->ids.du == nullptr) { return nullptr; } #ifdef HAVE_DNS_OVER_HTTPS - dq->httpHost = dq->dq->du->getHTTPHost(); + dq->httpHost = dq->dq->ids.du->getHTTPHost(); #endif /* HAVE_DNS_OVER_HTTPS */ } if (dq->httpHost) { @@ -284,11 +284,11 @@ const char* dnsdist_ffi_dnsquestion_get_http_host(dnsdist_ffi_dnsquestion_t* dq) const char* dnsdist_ffi_dnsquestion_get_http_scheme(dnsdist_ffi_dnsquestion_t* dq) { if (!dq->httpScheme) { - if (dq->dq->du == nullptr) { + if (dq->dq->ids.du == nullptr) { return nullptr; } #ifdef HAVE_DNS_OVER_HTTPS - dq->httpScheme = dq->dq->du->getHTTPScheme(); + dq->httpScheme = dq->dq->ids.du->getHTTPScheme(); #endif /* HAVE_DNS_OVER_HTTPS */ } if (dq->httpScheme) { @@ -346,12 +346,12 @@ size_t dnsdist_ffi_dnsquestion_get_edns_options(dnsdist_ffi_dnsquestion_t* dq, c size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_http_header_t** out) { - if (dq->dq->du == nullptr) { + if (dq->dq->ids.du == nullptr) { return 0; } #ifdef HAVE_DNS_OVER_HTTPS - auto headers = dq->dq->du->getHTTPHeaders(); + auto headers = dq->dq->ids.du->getHTTPHeaders(); if (headers.size() == 0) { return 0; } @@ -380,7 +380,7 @@ size_t dnsdist_ffi_dnsquestion_get_http_headers(dnsdist_ffi_dnsquestion_t* dq, c size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, const dnsdist_ffi_tag_t** out) { - if (dq->dq->qTag == nullptr || dq->dq->qTag->size() == 0) { + if (dq->dq->ids.qTag == nullptr || dq->dq->ids.qTag->size() == 0) { return 0; } @@ -388,10 +388,10 @@ size_t dnsdist_ffi_dnsquestion_get_tag_array(dnsdist_ffi_dnsquestion_t* dq, cons dq->tagsVect = std::make_unique>(); } dq->tagsVect->clear(); - dq->tagsVect->resize(dq->dq->qTag->size()); + dq->tagsVect->resize(dq->dq->ids.qTag->size()); size_t pos = 0; - for (const auto& tag : *dq->dq->qTag) { + for (const auto& tag : *dq->dq->ids.qTag) { auto& entry = dq->tagsVect->at(pos); entry.name = tag.first.c_str(); entry.value = tag.second.c_str(); @@ -413,13 +413,13 @@ void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const cha void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType) { - if (dq->dq->du == nullptr) { + if (dq->dq->ids.du == nullptr) { return; } #ifdef HAVE_DNS_OVER_HTTPS PacketBuffer bodyVect(body, body + bodyLen); - dq->dq->du->setHTTPResponse(statusCode, std::move(bodyVect), contentType); + dq->dq->ids.du->setHTTPResponse(statusCode, std::move(bodyVect), contentType); dq->dq->getHeader()->qr = true; #endif } @@ -437,7 +437,7 @@ void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len void dnsdist_ffi_dnsquestion_set_skip_cache(dnsdist_ffi_dnsquestion_t* dq, bool skipCache) { - dq->dq->skipCache = skipCache; + dq->dq->ids.skipCache = skipCache; } void dnsdist_ffi_dnsquestion_set_use_ecs(dnsdist_ffi_dnsquestion_t* dq, bool useECS) @@ -457,12 +457,12 @@ void dnsdist_ffi_dnsquestion_set_ecs_prefix_length(dnsdist_ffi_dnsquestion_t* dq void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, uint32_t tempFailureTTL) { - dq->dq->tempFailureTTL = tempFailureTTL; + dq->dq->ids.tempFailureTTL = tempFailureTTL; } void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq) { - dq->dq->tempFailureTTL = boost::none; + dq->dq->ids.tempFailureTTL = boost::none; } void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) @@ -723,7 +723,7 @@ size_t dnsdist_ffi_dnsquestion_generate_proxy_protocol_payload(const dnsdist_ffi } } - std::string payload = makeProxyHeader(dq->dq->overTCP(), *dq->dq->remote, *dq->dq->local, valuesVect); + std::string payload = makeProxyHeader(dq->dq->overTCP(), dq->dq->ids.origRemote, dq->dq->ids.origDest, valuesVect); if (payload.size() > outSize) { return 0; } diff --git a/pdns/dnsdistdist/dnsdist-proxy-protocol.cc b/pdns/dnsdistdist/dnsdist-proxy-protocol.cc index 6a3e978977..aefcafba54 100644 --- a/pdns/dnsdistdist/dnsdist-proxy-protocol.cc +++ b/pdns/dnsdistdist/dnsdist-proxy-protocol.cc @@ -29,7 +29,7 @@ bool g_applyACLToProxiedClients = false; std::string getProxyProtocolPayload(const DNSQuestion& dq) { - return makeProxyHeader(dq.overTCP(), *dq.remote, *dq.local, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector()); + return makeProxyHeader(dq.overTCP(), dq.ids.origRemote, dq.ids.origDest, dq.proxyProtocolValues ? *dq.proxyProtocolValues : std::vector()); } bool addProxyProtocol(DNSQuestion& dq, const std::string& payload) diff --git a/pdns/dnsdistdist/dnsdist-rules.hh b/pdns/dnsdistdist/dnsdist-rules.hh index 885101957b..e21582fb20 100644 --- a/pdns/dnsdistdist/dnsdist-rules.hh +++ b/pdns/dnsdistdist/dnsdist-rules.hh @@ -96,9 +96,9 @@ public: bool matches(const DNSQuestion* dq) const override { - cleanupIfNeeded(*dq->queryTime); + cleanupIfNeeded(dq->queryTime); - ComboAddress zeroport(*dq->remote); + ComboAddress zeroport(dq->ids.origRemote); zeroport.sin4.sin_port=0; zeroport.truncate(zeroport.sin4.sin_family == AF_INET ? d_ipv4trunc : d_ipv6trunc); { @@ -196,9 +196,9 @@ public: bool matches(const DNSQuestion* dq) const override { if(!d_src) { - return d_nmg.match(*dq->local); + return d_nmg.match(dq->ids.origDest); } - return d_nmg.match(*dq->remote); + return d_nmg.match(dq->ids.origRemote); } string toString() const override @@ -242,16 +242,16 @@ public: } bool matches(const DNSQuestion* dq) const override { - if (dq->remote->sin4.sin_family == AF_INET) { + if (dq->ids.origRemote.sin4.sin_family == AF_INET) { auto ip4s = d_ip4s.read_lock(); - auto fnd = ip4s->find(dq->remote->sin4.sin_addr.s_addr); + auto fnd = ip4s->find(dq->ids.origRemote.sin4.sin_addr.s_addr); if (fnd == ip4s->end()) { return false; } return time(nullptr) < fnd->second; } else { auto ip6s = d_ip6s.read_lock(); - auto fnd = ip6s->find({*dq->remote}); + auto fnd = ip6s->find({dq->ids.origRemote}); if (fnd == ip6s->end()) { return false; } @@ -473,7 +473,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return d_regex.match(dq->qname->toStringNoDot()); + return d_regex.match(dq->ids.qname.toStringNoDot()); } string toString() const override @@ -496,7 +496,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return RE2::FullMatch(dq->qname->toStringNoDot(), d_re2); + return RE2::FullMatch(dq->ids.qname.toStringNoDot(), d_re2); } string toString() const override @@ -570,7 +570,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return d_smn.check(*dq->qname); + return d_smn.check(dq->ids.qname); } string toString() const override { @@ -593,7 +593,7 @@ public: bool matches(const DNSQuestion* dq) const override { - return d_qname==*dq->qname; + return d_qname==dq->ids.qname; } string toString() const override { @@ -608,7 +608,7 @@ public: QNameSetRule(const DNSNameSet& names) : qname_idx(names) {} bool matches(const DNSQuestion* dq) const override { - return qname_idx.find(*dq->qname) != qname_idx.end(); + return qname_idx.find(dq->ids.qname) != qname_idx.end(); } string toString() const override { @@ -628,7 +628,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return d_qtype == dq->qtype; + return d_qtype == dq->ids.qtype; } string toString() const override { @@ -647,7 +647,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return d_qclass == dq->qclass; + return d_qclass == dq->ids.qclass; } string toString() const override { @@ -683,7 +683,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return htons(d_port) == dq->local->sin4.sin_port; + return htons(d_port) == dq->ids.origDest.sin4.sin_port; } string toString() const override { @@ -860,7 +860,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - unsigned int count = dq->qname->countLabels(); + unsigned int count = dq->ids.qname.countLabels(); return count < d_min || count > d_max; } string toString() const override @@ -880,7 +880,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - size_t const wirelength = dq->qname->wirelength(); + size_t const wirelength = dq->ids.qname.wirelength(); return wirelength < d_min || wirelength > d_max; } string toString() const override @@ -1043,12 +1043,12 @@ public: } bool matches(const DNSQuestion* dq) const override { - if (!dq->qTag) { + if (!dq->ids.qTag) { return false; } - const auto it = dq->qTag->find(d_tag); - if (it == dq->qTag->cend()) { + const auto it = dq->ids.qTag->find(d_tag); + if (it == dq->ids.qTag->cend()) { return false; } diff --git a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc index a6ba8f44d7..2c5a4f3811 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-downstream.cc +++ b/pdns/dnsdistdist/dnsdist-tcp-downstream.cc @@ -38,7 +38,7 @@ ConnectionToBackend::~ConnectionToBackend() bool ConnectionToBackend::reconnect() { std::unique_ptr tlsSession{nullptr}; - if (d_handler) { + if (d_handler) { DEBUGLOG("closing socket "<getDescriptor()); if (d_handler->isTLS()) { if (d_handler->hasTLSSessionBeenResumed()) { diff --git a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh index 23bfd1a63f..e48d52317d 100644 --- a/pdns/dnsdistdist/dnsdist-tcp-upstream.hh +++ b/pdns/dnsdistdist/dnsdist-tcp-upstream.hh @@ -125,7 +125,7 @@ static void handleTimeout(std::shared_ptr& state, bo /* we take a copy of a shared pointer, not a reference, because the initial shared pointer might be released during the handling of the response */ void handleResponse(const struct timeval& now, TCPResponse&& response) override; void handleXFRResponse(const struct timeval& now, TCPResponse&& response) override; - void notifyIOError(IDState&& query, const struct timeval& now) override; + void notifyIOError(InternalQueryState&& query, const struct timeval& now) override; void handleCrossProtocolResponse(const struct timeval& now, TCPResponse&& response); diff --git a/pdns/dnsdistdist/dnsdist-tcp.hh b/pdns/dnsdistdist/dnsdist-tcp.hh index 6fb1b827a2..a3811e9989 100644 --- a/pdns/dnsdistdist/dnsdist-tcp.hh +++ b/pdns/dnsdistdist/dnsdist-tcp.hh @@ -78,7 +78,7 @@ struct InternalQuery { } - InternalQuery(PacketBuffer&& buffer, IDState&& state) : + InternalQuery(PacketBuffer&& buffer, InternalQueryState&& state) : d_idstate(std::move(state)), d_buffer(std::move(buffer)) { } @@ -94,7 +94,7 @@ struct InternalQuery return d_idstate.qtype == QType::AXFR || d_idstate.qtype == QType::IXFR; } - IDState d_idstate; + InternalQueryState d_idstate; std::string d_proxyProtocolPayload; PacketBuffer d_buffer; uint32_t d_proxyProtocolPayloadAddedSize{0}; @@ -118,7 +118,7 @@ struct TCPResponse : public TCPQuery memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); } - TCPResponse(PacketBuffer&& buffer, IDState&& state, std::shared_ptr conn) : + TCPResponse(PacketBuffer&& buffer, InternalQueryState&& state, std::shared_ptr conn) : TCPQuery(std::move(buffer), std::move(state)), d_connection(conn) { memset(&d_cleartextDH, 0, sizeof(d_cleartextDH)); @@ -140,7 +140,7 @@ public: virtual const ClientState* getClientState() const = 0; virtual void handleResponse(const struct timeval& now, TCPResponse&& response) = 0; virtual void handleXFRResponse(const struct timeval& now, TCPResponse&& response) = 0; - virtual void notifyIOError(IDState&& query, const struct timeval& now) = 0; + virtual void notifyIOError(InternalQueryState&& query, const struct timeval& now) = 0; /* whether the connection should be automatically released to the pool after handleResponse() has been called */ @@ -158,7 +158,6 @@ struct CrossProtocolQuery CrossProtocolQuery() { } - CrossProtocolQuery(InternalQuery&& query_, std::shared_ptr& downstream_) : query(std::move(query_)), downstream(downstream_) { diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index 5cdfe5e649..ff3bea223b 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -265,7 +265,7 @@ void handleDOHTimeout(DOHUnitUniquePtr&& oldDU) return; } -/* we are about to erase an existing DU */ + /* we are about to erase an existing DU */ oldDU->status_code = 502; sendDoHUnitToTheMainThread(std::move(oldDU), "DoH timeout"); @@ -427,7 +427,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo class DoHTCPCrossQuerySender : public TCPQuerySender { public: - DoHTCPCrossQuerySender(DOHUnitUniquePtr&& du_): du(std::move(du_)) + DoHTCPCrossQuerySender(const ClientState& cs): d_cs(cs) { } @@ -438,37 +438,42 @@ public: const ClientState* getClientState() const override { - if (!du || !du->dsc || !du->dsc->cs) { - throw std::runtime_error("No query associated to this DoHTCPCrossQuerySender"); - } - - return du->dsc->cs; + return &d_cs; } void handleResponse(const struct timeval& now, TCPResponse&& response) override { - if (!du) { + if (!response.d_idstate.du) { return; } + auto du = std::move(response.d_idstate.du); if (du->rsock == -1) { return; } du->response = std::move(response.d_buffer); du->ids = std::move(response.d_idstate); + DNSResponse dr(du->ids, du->response, du->ids.sentTime.d_start, du->downstream); static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder> localCacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - DNSResponse dr = makeDNSResponseFromIDState(du->ids, du->response); + dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - if (!processResponse(du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false, false)) { - du.reset(); + dr.ids.du = std::move(du); + + if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localCacheInsertedRespRuleActions, dr, false)) { + if (dr.ids.du) { + dr.ids.du->status_code = 503; + sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + } return; } + du = std::move(dr.ids.du); + double udiff = du->ids.sentTime.udiff(); vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); @@ -476,7 +481,7 @@ public: if (backendProtocol == dnsdist::Protocol::DoUDP && du->tcp) { backendProtocol = dnsdist::Protocol::DoTCP; } - handleResponseSent(du->ids, udiff, *dr.remote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol); + handleResponseSent(du->ids, udiff, du->ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, backendProtocol); ++g_stats.responses; if (du->ids.cs) { @@ -491,65 +496,71 @@ public: return handleResponse(now, std::move(response)); } - void notifyIOError(IDState&& query, const struct timeval& now) override + void notifyIOError(InternalQueryState&& query, const struct timeval& now) override { - if (!du) { + if (!query.du) { return; } - if (du->rsock == -1) { + if (query.du->rsock == -1) { return; } + auto du = std::move(query.du); du->ids = std::move(query); du->status_code = 502; sendDoHUnitToTheMainThread(std::move(du), "cross-protocol error response"); } - -private: - DOHUnitUniquePtr du; +protected: + const ClientState& d_cs; }; class DoHCrossProtocolQuery : public CrossProtocolQuery { public: - DoHCrossProtocolQuery(DOHUnitUniquePtr&& du_): du(std::move(du_)) + DoHCrossProtocolQuery(DOHUnitUniquePtr&& du) { query = InternalQuery(std::move(du->query), std::move(du->ids)); + /* it might have been moved when we moved du->ids */ + if (du) { + query.d_idstate.du = std::move(du); + } + /* we _could_ remove it from the query buffer and put in query's d_proxyProtocolPayload, clearing query.d_proxyProtocolPayloadAdded and du->proxyProtocolPayloadSize. Leave it for now because we know that the onky case where the payload has been added is when we tried over UDP, got a TC=1 answer and retried over TCP/DoT, and we know the TCP/DoT code can handle it. */ - query.d_proxyProtocolPayloadAdded = du->proxyProtocolPayloadSize > 0; - downstream = du->downstream; - proxyProtocolPayloadSize = du->proxyProtocolPayloadSize; + query.d_proxyProtocolPayloadAdded = query.d_idstate.du->proxyProtocolPayloadSize > 0; + downstream = query.d_idstate.du->downstream; + proxyProtocolPayloadSize = query.d_idstate.du->proxyProtocolPayloadSize; } void handleInternalError() { - du->status_code = 502; - sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); + query.d_idstate.du->status_code = 502; + sendDoHUnitToTheMainThread(std::move(query.d_idstate.du), "DoH internal error"); } std::shared_ptr getTCPQuerySender() override { - auto sender = std::make_shared(std::move(du)); + query.d_idstate.du->downstream = downstream; + auto sender = std::make_shared(*query.d_idstate.cs); return sender; } - -private: - DOHUnitUniquePtr du; }; /* We are not in the main DoH thread but in the DoH 'client' thread. */ -static void processDOHQuery(DOHUnitUniquePtr&& du) +static void processDOHQuery(DOHUnitUniquePtr&& unit) { + auto& ids = unit->ids; + ids.du = std::move(unit); + auto& du = ids.du; uint16_t queryId = 0; ComboAddress remote; - bool duRefCountIncremented = false; + try { if (!du->req) { // we got closed meanwhile. XXX small race condition here @@ -603,13 +614,11 @@ static void processDOHQuery(DOHUnitUniquePtr&& du) queryId = ntohs(dh->id); } - uint16_t qtype, qclass; - unsigned int qnameWireLength = 0; - DNSName qname(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); - DNSQuestion dq(&qname, qtype, qclass, &du->ids.origDest, &du->ids.origRemote, du->query, dnsdist::Protocol::DoH, &queryRealTime); - dq.ednsAdded = du->ids.ednsAdded; - /* store the raw pointer */ - dq.du = du.get(); + du->ids.qname = DNSName(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &du->ids.qtype, &du->ids.qclass); + DNSQuestion dq(du->ids, du->query, queryRealTime); + const uint16_t* flags = getFlagsFromDNSHeader(dq.getHeader()); + ids.origFlags = *flags; + du->ids.cs = &cs; dq.sni = std::move(du->sni); auto result = processQuery(dq, cs, holders, du->downstream); @@ -627,7 +636,7 @@ static void processDOHQuery(DOHUnitUniquePtr&& du) if (du->response.size() >= sizeof(dnsheader) && du->contentType.empty()) { auto dh = reinterpret_cast(du->response.data()); - handleResponseSent(qname, QType(qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH); + handleResponseSent(ids.qname, QType(ids.qtype), 0., du->ids.origDest, ComboAddress(), du->response.size(), *dh, dnsdist::Protocol::DoH, dnsdist::Protocol::DoH); } sendDoHUnitToTheMainThread(std::move(du), "DoH self-answered response"); return; @@ -639,26 +648,25 @@ static void processDOHQuery(DOHUnitUniquePtr&& du) return; } - if (du->downstream == nullptr) { + auto downstream = du->downstream; + if (downstream == nullptr) { du->status_code = 502; sendDoHUnitToTheMainThread(std::move(du), "DoH no backend available"); return; } - if (du->downstream->isTCPOnly()) { + du->downstream = downstream; + + if (downstream->isTCPOnly()) { std::string proxyProtocolPayload; /* we need to do this _before_ creating the cross protocol query because after that the buffer will have been moved */ - if (du->downstream->d_config.useProxyProtocol) { + if (downstream->d_config.useProxyProtocol) { proxyProtocolPayload = getProxyProtocolPayload(dq); } du->ids.origID = htons(queryId); - du->ids.cs = &cs; - setIDStateFromDNSQuestion(du->ids, dq, std::move(qname)); - du->tcp = true; - std::shared_ptr& downstream = du->downstream; /* this moves du->ids, careful! */ auto cpq = std::make_unique(std::move(du)); @@ -673,88 +681,11 @@ static void processDOHQuery(DOHUnitUniquePtr&& du) } } - ComboAddress dest = du->ids.origDest; - unsigned int idOffset = 0; - int64_t generation; - IDState* ids = du->downstream->getIDState(idOffset, generation); - - /* increase the ref count since we are about to store the pointer */ - du->get(); - duRefCountIncremented = true; - /* store the raw pointer */ - ids->du = du.get(); - - ids->cs = &cs; - ids->origID = htons(queryId); - setIDStateFromDNSQuestion(*ids, dq, std::move(qname)); - - dq.getHeader()->id = idOffset; - - /* If we couldn't harvest the real dest addr, still - write down the listening addr since it will be useful - (especially if it's not an 'any' one). - We need to keep track of which one it is since we may - want to use the real but not the listening addr to reply. - */ - if (dest.sin4.sin_family != 0) { - ids->origDest = dest; - } - else { - ids->origDest = cs.local; - } - - bool failed = false; - if (du->downstream->d_config.useProxyProtocol) { - try { - size_t payloadSize = 0; - if (addProxyProtocol(dq, &payloadSize)) { - du->proxyProtocolPayloadSize = payloadSize; - } - } - catch (const std::exception& e) { - vinfolog("Adding proxy protocol payload to DoH query from %s failed: %s", ids->origDest.toStringWithPort(), e.what()); - failed = true; - } - } - - try { - if (!failed) { - int fd = du->downstream->pickSocketForSending(); - ids->backendFD = fd; - /* you can't touch du after this line, unless the call returned a non-negative value, - because it might already have been freed */ - ssize_t ret = udpClientSendRequestToBackend(du->downstream, fd, du->query); - - if (ret < 0) { - failed = true; - } - } - - if (failed) { - /* we are about to handle the error, make sure that - this pointer is not accessed when the state is cleaned, - but first check that it still belongs to us */ - if (ids->tryMarkUnused(generation)) { - ids->du = nullptr; - du->release(); - duRefCountIncremented = false; - --du->downstream->outstanding; - } - ++du->downstream->sendErrors; - ++g_stats.downstreamSendErrors; - du->status_code = 502; - sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); - return; - } - } - catch (const std::exception& e) { - if (duRefCountIncremented) { - du->release(); - } - throw; + ComboAddress dest = dq.ids.origDest; + if (!assignOutgoingUDPQueryToBackend(downstream, htons(queryId), dq, std::move(du->query), dest)) { + sendDoHUnitToTheMainThread(std::move(du), "DoH internal error"); + return; } - - vinfolog("Got query for %s|%s from %s (https), relayed to %s", ids->qname.toString(), QType(ids->qtype).toString(), remote.toStringWithPort(), du->downstream->getName()); } catch (const std::exception& e) { vinfolog("Got an error in DOH question thread while parsing a query from %s, id %d: %s", remote.toStringWithPort(), queryId, e.what()); @@ -839,6 +770,7 @@ static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_re du->req = req; du->ids.origDest = local; du->ids.origRemote = remote; + du->ids.protocol = dnsdist::Protocol::DoH; du->rsock = dsc->dohresponsepair[0]; if (req->scheme != nullptr) { du->scheme = std::string(req->scheme->name.base, req->scheme->name.len); @@ -1113,11 +1045,11 @@ HTTPHeaderRule::HTTPHeaderRule(const std::string& header, const std::string& reg bool HTTPHeaderRule::matches(const DNSQuestion* dq) const { - if (!dq->du || !dq->du->headers) { + if (!dq->ids.du || !dq->ids.du->headers) { return false; } - for (const auto& header : *dq->du->headers) { + for (const auto& header : *dq->ids.du->headers) { if (header.first == d_header) { return d_regex.match(header.second); } @@ -1138,15 +1070,15 @@ HTTPPathRule::HTTPPathRule(const std::string& path) bool HTTPPathRule::matches(const DNSQuestion* dq) const { - if (!dq->du) { + if (!dq->ids.du) { return false; } - if (dq->du->query_at == SIZE_MAX) { - return dq->du->path == d_path; + if (dq->ids.du->query_at == SIZE_MAX) { + return dq->ids.du->path == d_path; } else { - return d_path.compare(0, d_path.size(), dq->du->path, 0, dq->du->query_at) == 0; + return d_path.compare(0, d_path.size(), dq->ids.du->path, 0, dq->ids.du->query_at) == 0; } } @@ -1161,11 +1093,11 @@ HTTPPathRegexRule::HTTPPathRegexRule(const std::string& regex): d_regex(regex), bool HTTPPathRegexRule::matches(const DNSQuestion* dq) const { - if (!dq->du) { + if (!dq->ids.du) { return false; } - return d_regex.match(dq->du->getHTTPPath()); + return d_regex.match(dq->ids.du->getHTTPPath()); } string HTTPPathRegexRule::toString() const @@ -1330,9 +1262,11 @@ static void on_dnsdist(h2o_socket_t *listener, const char *err) /* restoring the original ID */ dnsheader* queryDH = reinterpret_cast(du->query.data() + du->proxyProtocolPayloadSize); queryDH->id = du->ids.origID; - + du->ids.forwardedOverUDP = false; du->tcp = true; du->truncated = false; + du->response.clear(); + auto cpq = std::make_unique(std::move(du)); if (g_tcpclientthreads && g_tcpclientthreads->passCrossProtocolQueryToThread(std::move(cpq))) { @@ -1657,7 +1591,7 @@ void dohThread(ClientState* cs) } } -void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, IDState&& state) +void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, InternalQueryState&& state) { du->response = std::move(udpResponse); du->ids = std::move(state); @@ -1666,18 +1600,25 @@ void handleUDPResponseForDoH(DOHUnitUniquePtr&& du, PacketBuffer&& udpResponse, if (!dh->tc) { static thread_local LocalStateHolder> localRespRuleActions = g_respruleactions.getLocal(); static thread_local LocalStateHolder> localcacheInsertedRespRuleActions = g_cacheInsertedRespRuleActions.getLocal(); - DNSResponse dr = makeDNSResponseFromIDState(du->ids, du->response); + + DNSResponse dr(du->ids, du->response, du->ids.sentTime.d_start, du->downstream); dnsheader cleartextDH; memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - if (!processResponse(du->response, *localRespRuleActions, *localcacheInsertedRespRuleActions, dr, false, true)) { + dr.ids.du = std::move(du); + if (!processResponse(dr.ids.du->response, *localRespRuleActions, *localcacheInsertedRespRuleActions, dr, false)) { + if (dr.ids.du) { + dr.ids.du->status_code = 503; + sendDoHUnitToTheMainThread(std::move(dr.ids.du), "Response dropped by rules"); + } return; } + du = std::move(dr.ids.du); double udiff = du->ids.sentTime.udiff(); vinfolog("Got answer from %s, relayed to %s (https), took %f usec", du->downstream->d_config.remote.toStringWithPort(), du->ids.origRemote.toStringWithPort(), udiff); - handleResponseSent(du->ids, udiff, *dr.remote, du->downstream->d_config.remote, du->response.size(), cleartextDH, du->downstream->getProtocol()); + handleResponseSent(du->ids, udiff, dr.ids.origRemote, du->downstream->d_config.remote, du->response.size(), cleartextDH, du->downstream->getProtocol()); ++g_stats.responses; if (du->ids.cs) { diff --git a/pdns/dnsdistdist/test-dnsdistkvs_cc.cc b/pdns/dnsdistdist/test-dnsdistkvs_cc.cc index 5ed25c2a10..2d2f62bfe8 100644 --- a/pdns/dnsdistdist/test-dnsdistkvs_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistkvs_cc.cc @@ -19,7 +19,7 @@ static void doKVSChecks(std::unique_ptr& kvs, const ComboAddress& /* local address is not in the db, remote is */ BOOST_CHECK_EQUAL(kvs->getValue(std::string(reinterpret_cast(&lc.sin4.sin_addr.s_addr), sizeof(lc.sin4.sin_addr.s_addr)), value), false); BOOST_CHECK_EQUAL(kvs->keyExists(std::string(reinterpret_cast(&lc.sin4.sin_addr.s_addr), sizeof(lc.sin4.sin_addr.s_addr))), false); - BOOST_CHECK(kvs->keyExists(std::string(reinterpret_cast(&dq.remote->sin4.sin_addr.s_addr), sizeof(dq.remote->sin4.sin_addr.s_addr)))); + BOOST_CHECK(kvs->keyExists(std::string(reinterpret_cast(&dq.ids.origRemote.sin4.sin_addr.s_addr), sizeof(dq.ids.origRemote.sin4.sin_addr.s_addr)))); auto keys = lookupKey->getKeys(dq); BOOST_CHECK_EQUAL(keys.size(), 1U); @@ -66,7 +66,7 @@ static void doKVSChecks(std::unique_ptr& kvs, const ComboAddress& } } - const DNSName subdomain = DNSName("sub") + *dq.qname; + const DNSName subdomain = DNSName("sub") + dq.ids.qname; const DNSName notPDNS("not-powerdns.com."); /* qname match, in wire format */ @@ -147,9 +147,9 @@ static void doKVSChecks(std::unique_ptr& kvs, const ComboAddress& { auto lookupKey = make_unique(0, true); auto keys = lookupKey->getKeys(dq); - BOOST_CHECK_EQUAL(keys.size(), dq.qname->countLabels()); + BOOST_CHECK_EQUAL(keys.size(), dq.ids.qname.countLabels()); BOOST_REQUIRE(!keys.empty()); - BOOST_CHECK_EQUAL(keys.at(0), dq.qname->toDNSStringLC()); + BOOST_CHECK_EQUAL(keys.at(0), dq.ids.qname.toDNSStringLC()); std::string value; BOOST_CHECK_EQUAL(kvs->getValue(keys.at(0), value), true); BOOST_CHECK_EQUAL(value, "this is the value for the qname"); @@ -184,9 +184,9 @@ static void doKVSChecks(std::unique_ptr& kvs, const ComboAddress& { auto lookupKey = make_unique(0, false); auto keys = lookupKey->getKeys(dq); - BOOST_CHECK_EQUAL(keys.size(), dq.qname->countLabels()); + BOOST_CHECK_EQUAL(keys.size(), dq.ids.qname.countLabels()); BOOST_REQUIRE(!keys.empty()); - BOOST_CHECK_EQUAL(keys.at(0), dq.qname->toStringRootDot()); + BOOST_CHECK_EQUAL(keys.at(0), dq.ids.qname.toStringRootDot()); std::string value; BOOST_CHECK_EQUAL(kvs->getValue(keys.at(0), value), false); value.clear(); @@ -221,7 +221,7 @@ static void doKVSChecks(std::unique_ptr& kvs, const ComboAddress& auto keys = lookupKey->getKeys(dq); BOOST_CHECK_EQUAL(keys.size(), 1U); BOOST_REQUIRE(!keys.empty()); - BOOST_CHECK_EQUAL(keys.at(0), dq.qname->toDNSStringLC()); + BOOST_CHECK_EQUAL(keys.at(0), dq.ids.qname.toDNSStringLC()); std::string value; BOOST_CHECK_EQUAL(kvs->getValue(keys.at(0), value), true); BOOST_CHECK_EQUAL(value, "this is the value for the qname"); @@ -300,21 +300,22 @@ BOOST_AUTO_TEST_SUITE(dnsdistkvs_cc) #ifdef HAVE_LMDB BOOST_AUTO_TEST_CASE(test_LMDB) { - DNSName qname("powerdns.com."); + InternalQueryState ids; + ids.qname = DNSName("powerdns.com."); DNSName plaintextDomain("powerdns.org."); - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; - ComboAddress lc("192.0.2.1:53"); - ComboAddress rem("192.0.2.128:42"); + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.origDest = ComboAddress("192.0.2.1:53"); + ids.origRemote = ComboAddress("192.0.2.128:42"); PacketBuffer packet(sizeof(dnsheader)); - auto proto = dnsdist::Protocol::DoUDP; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; /* the internal QPS limiter does not use the real time */ gettime(&expiredTime); - DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); + DNSQuestion dq(ids, packet, queryRealTime); ComboAddress v4Masked(v4ToMask); ComboAddress v6Masked(v6ToMask); v4Masked.truncate(25); @@ -330,11 +331,11 @@ BOOST_AUTO_TEST_CASE(test_LMDB) { MDBEnv env(dbPath.c_str(), MDB_NOSUBDIR, 0600, 50); auto transaction = env.getRWTransaction(); auto dbi = transaction->openDB("db-name", MDB_CREATE); - transaction->put(dbi, MDBInVal(std::string(reinterpret_cast(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr))), MDBInVal("this is the value for the remote addr")); - transaction->put(dbi, MDBInVal(std::string(reinterpret_cast(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast(&rem.sin4.sin_port), sizeof(rem.sin4.sin_port))), MDBInVal("this is the value for the remote addr + port")); + transaction->put(dbi, MDBInVal(std::string(reinterpret_cast(&ids.origRemote.sin4.sin_addr.s_addr), sizeof(ids.origRemote.sin4.sin_addr.s_addr))), MDBInVal("this is the value for the remote addr")); + transaction->put(dbi, MDBInVal(std::string(reinterpret_cast(&ids.origRemote.sin4.sin_addr.s_addr), sizeof(ids.origRemote.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast(&ids.origRemote.sin4.sin_port), sizeof(ids.origRemote.sin4.sin_port))), MDBInVal("this is the value for the remote addr + port")); transaction->put(dbi, MDBInVal(std::string(reinterpret_cast(&v4Masked.sin4.sin_addr.s_addr), sizeof(v4Masked.sin4.sin_addr.s_addr))), MDBInVal("this is the value for the masked v4 addr")); transaction->put(dbi, MDBInVal(std::string(reinterpret_cast(&v6Masked.sin6.sin6_addr.s6_addr), sizeof(v6Masked.sin6.sin6_addr.s6_addr))), MDBInVal("this is the value for the masked v6 addr")); - transaction->put(dbi, MDBInVal(qname.toDNSStringLC()), MDBInVal("this is the value for the qname")); + transaction->put(dbi, MDBInVal(dq.ids.qname.toDNSStringLC()), MDBInVal("this is the value for the qname")); transaction->put(dbi, MDBInVal(plaintextDomain.toStringRootDot()), MDBInVal("this is the value for the plaintext domain")); transaction->commit(); @@ -355,7 +356,7 @@ BOOST_AUTO_TEST_CASE(test_LMDB) { } std::unique_ptr lmdb = std::make_unique(dbPath, "db-name"); - doKVSChecks(lmdb, lc, rem, dq, plaintextDomain); + doKVSChecks(lmdb, ids.origDest, ids.origRemote, dq, plaintextDomain); lmdb.reset(); lmdb = std::make_unique(dbPath, "range-db-name"); @@ -385,21 +386,22 @@ BOOST_AUTO_TEST_CASE(test_LMDB) { #ifdef HAVE_CDB BOOST_AUTO_TEST_CASE(test_CDB) { - DNSName qname("powerdns.com."); + InternalQueryState ids; + ids.qname = DNSName("powerdns.com."); DNSName plaintextDomain("powerdns.org."); - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; - ComboAddress lc("192.0.2.1:53"); - ComboAddress rem("192.0.2.128:42"); + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.origDest = ComboAddress("192.0.2.1:53"); + ids.origRemote = ComboAddress("192.0.2.128:42"); PacketBuffer packet(sizeof(dnsheader)); - auto proto = dnsdist::Protocol::DoUDP; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; /* the internal QPS limiter does not use the real time */ gettime(&expiredTime); - DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); + DNSQuestion dq(ids, packet, queryRealTime); ComboAddress v4Masked(v4ToMask); ComboAddress v6Masked(v6ToMask); v4Masked.truncate(25); @@ -410,17 +412,17 @@ BOOST_AUTO_TEST_CASE(test_CDB) { int fd = mkstemp(db); BOOST_REQUIRE(fd >= 0); CDBWriter writer(fd); - BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)), "this is the value for the remote addr")); - BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast(&rem.sin4.sin_addr.s_addr), sizeof(rem.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast(&rem.sin4.sin_port), sizeof(rem.sin4.sin_port)), "this is the value for the remote addr + port")); + BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast(&ids.origRemote.sin4.sin_addr.s_addr), sizeof(ids.origRemote.sin4.sin_addr.s_addr)), "this is the value for the remote addr")); + BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast(&ids.origRemote.sin4.sin_addr.s_addr), sizeof(ids.origRemote.sin4.sin_addr.s_addr)) + std::string(reinterpret_cast(&ids.origRemote.sin4.sin_port), sizeof(ids.origRemote.sin4.sin_port)), "this is the value for the remote addr + port")); BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast(&v4Masked.sin4.sin_addr.s_addr), sizeof(v4Masked.sin4.sin_addr.s_addr)), "this is the value for the masked v4 addr")); BOOST_REQUIRE(writer.addEntry(std::string(reinterpret_cast(&v6Masked.sin6.sin6_addr.s6_addr), sizeof(v6Masked.sin6.sin6_addr.s6_addr)), "this is the value for the masked v6 addr")); - BOOST_REQUIRE(writer.addEntry(qname.toDNSStringLC(), "this is the value for the qname")); + BOOST_REQUIRE(writer.addEntry(dq.ids.qname.toDNSStringLC(), "this is the value for the qname")); BOOST_REQUIRE(writer.addEntry(plaintextDomain.toStringRootDot(), "this is the value for the plaintext domain")); writer.close(); } std::unique_ptr cdb = std::make_unique(db, 0); - doKVSChecks(cdb, lc, rem, dq, plaintextDomain); + doKVSChecks(cdb, ids.origDest, ids.origRemote, dq, plaintextDomain); unlink(db); diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index 9eee1cca51..fd3bd3615f 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -108,17 +108,18 @@ std::atomic g_configurationDone{false}; static DNSQuestion getDQ(const DNSName* providedName = nullptr) { static const DNSName qname("powerdns.com."); - static const ComboAddress lc("127.0.0.1:53"); - static const ComboAddress rem("192.0.2.1:42"); static struct timespec queryRealTime; static PacketBuffer packet(sizeof(dnsheader)); - - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; - auto proto = dnsdist::Protocol::DoUDP; + static InternalQueryState ids; + ids.origDest = ComboAddress("127.0.0.1:53"); + ids.origRemote = ComboAddress("192.0.2.1:42"); + ids.qname = providedName ? *providedName : qname; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; gettime(&queryRealTime, true); - DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); + DNSQuestion dq(ids, packet, queryRealTime); return dq; } diff --git a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc index 7bf4e2c8ef..f9abd334f8 100644 --- a/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistnghttp2_cc.cc @@ -635,7 +635,7 @@ public: { } - void notifyIOError(IDState&& query, const struct timeval& now) override + void notifyIOError(InternalQueryState&& query, const struct timeval& now) override { d_error = true; } @@ -730,7 +730,7 @@ BOOST_FIXTURE_TEST_CASE(test_SingleQuery, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); s_steps = { {ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done}, @@ -808,7 +808,7 @@ BOOST_FIXTURE_TEST_CASE(test_ConcurrentQueries, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -897,7 +897,7 @@ BOOST_FIXTURE_TEST_CASE(test_ConnectionReuse, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1011,7 +1011,7 @@ BOOST_FIXTURE_TEST_CASE(test_InvalidDNSAnswer, TestFixture) while TCP and DoT will first pass it back to the TCP worker thread */ throw std::runtime_error("Invalid response"); }; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); s_steps = { {ExpectedStep::ExpectedRequest::connectToBackend, IOState::Done}, @@ -1086,7 +1086,7 @@ BOOST_FIXTURE_TEST_CASE(test_TimeoutWhileWriting, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1173,7 +1173,7 @@ BOOST_FIXTURE_TEST_CASE(test_TimeoutWhileReading, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1260,7 +1260,7 @@ BOOST_FIXTURE_TEST_CASE(test_ShortWrite, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1347,7 +1347,7 @@ BOOST_FIXTURE_TEST_CASE(test_ShortRead, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1441,7 +1441,7 @@ BOOST_FIXTURE_TEST_CASE(test_ConnectionClosedWhileReading, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1527,7 +1527,7 @@ BOOST_FIXTURE_TEST_CASE(test_ConnectionClosedWhileWriting, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1623,7 +1623,7 @@ BOOST_FIXTURE_TEST_CASE(test_GoAwayFromServer, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1732,7 +1732,7 @@ BOOST_FIXTURE_TEST_CASE(test_HTTP500FromServer, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1825,7 +1825,7 @@ BOOST_FIXTURE_TEST_CASE(test_WrongStreamID, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); queries.push_back({std::move(sender), std::move(internalQuery)}); } @@ -1928,7 +1928,7 @@ BOOST_FIXTURE_TEST_CASE(test_ProxyProtocol, TestFixture) auto sender = std::make_shared(); sender->d_id = counter; std::string payload = makeProxyHeader(counter % 2, local, local, {}); - InternalQuery internalQuery(std::move(query), IDState()); + InternalQuery internalQuery(std::move(query), InternalQueryState()); internalQuery.d_proxyProtocolPayload = std::move(payload); queries.push_back({std::move(sender), std::move(internalQuery)}); } diff --git a/pdns/dnsdistdist/test-dnsdistrules_cc.cc b/pdns/dnsdistdist/test-dnsdistrules_cc.cc index e27e6f6afe..6ace4199ac 100644 --- a/pdns/dnsdistdist/test-dnsdistrules_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistrules_cc.cc @@ -18,17 +18,18 @@ void checkParameterBound(const std::string& parameter, uint64_t value, size_t ma static DNSQuestion getDQ(const DNSName* providedName = nullptr) { static const DNSName qname("powerdns.com."); - static const ComboAddress lc("127.0.0.1:53"); - static const ComboAddress rem("192.0.2.1:42"); static struct timespec queryRealTime; static PacketBuffer packet(sizeof(dnsheader)); - - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; - auto proto = dnsdist::Protocol::DoUDP; + static InternalQueryState ids; + ids.origDest = ComboAddress("127.0.0.1:53"); + ids.origRemote = ComboAddress("192.0.2.1:42"); + ids.qname = providedName ? *providedName : qname; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; gettime(&queryRealTime, true); - DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); + DNSQuestion dq(ids, packet, queryRealTime); return dq; } @@ -42,24 +43,25 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { unsigned int scanFraction = 10; MaxQPSIPRule rule(maxQPS, maxBurst, 32, 64, expiration, cleanupDelay, scanFraction); - DNSName qname("powerdns.com."); - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; - ComboAddress lc("127.0.0.1:53"); - ComboAddress rem("192.0.2.1:42"); + InternalQueryState ids; + ids.qname = DNSName("powerdns.com."); + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.origDest = ComboAddress("127.0.0.1:53"); + ids.origRemote = ComboAddress("192.0.2.1:42"); + ids.protocol = dnsdist::Protocol::DoUDP; PacketBuffer packet(sizeof(dnsheader)); - auto proto = dnsdist::Protocol::DoUDP; struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; /* the internal QPS limiter does not use the real time */ gettime(&expiredTime); - DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); + DNSQuestion dq(ids, packet, queryRealTime); for (size_t idx = 0; idx < maxQPS; idx++) { /* let's use different source ports, it shouldn't matter */ - rem = ComboAddress("192.0.2.1:" + std::to_string(idx)); + ids.origRemote = ComboAddress("192.0.2.1:" + std::to_string(idx)); BOOST_CHECK_EQUAL(rule.matches(&dq), false); BOOST_CHECK_EQUAL(rule.getEntriesCount(), 1U); } @@ -87,7 +89,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { /* Let's insert a lot of different sources now */ for (size_t idxByte3 = 0; idxByte3 < 256; idxByte3++) { for (size_t idxByte4 = 0; idxByte4 < 256; idxByte4++) { - rem = ComboAddress("10.0." + std::to_string(idxByte3) + "." + std::to_string(idxByte4)); + ids.origRemote = ComboAddress("10.0." + std::to_string(idxByte3) + "." + std::to_string(idxByte4)); BOOST_CHECK_EQUAL(rule.matches(&dq), false); } } diff --git a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc index 056c08ed12..e427bf7717 100644 --- a/pdns/dnsdistdist/test-dnsdisttcp_cc.cc +++ b/pdns/dnsdistdist/test-dnsdisttcp_cc.cc @@ -59,7 +59,7 @@ uint64_t uptimeOfProcess(const std::string& str) return 0; } -void handleResponseSent(const IDState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol) +void handleResponseSent(const InternalQueryState& ids, double udiff, const ComboAddress& client, const ComboAddress& backend, unsigned int size, const dnsheader& cleartextDH, dnsdist::Protocol protocol) { } @@ -81,7 +81,7 @@ bool responseContentMatches(const PacketBuffer& response, const DNSName& qname, static std::function s_processResponse; -bool processResponse(PacketBuffer& response, const std::vector& localRespRuleActions, const std::vector& localCacheInsertedRespRuleActions, DNSResponse& dr, bool muted, bool receivedOverUDP) +bool processResponse(PacketBuffer& response, const std::vector& localRespRuleActions, const std::vector& localCacheInsertedRespRuleActions, DNSResponse& dr, bool muted) { if (s_processResponse) { return s_processResponse(response, dr, muted); diff --git a/pdns/doh.hh b/pdns/doh.hh index d717f05e35..62e7f83d29 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -177,9 +177,19 @@ struct DOHFrontend #ifndef HAVE_DNS_OVER_HTTPS struct DOHUnit { - static void release(DOHUnit* ptr) + static void release(DOHUnit*) + { + } + + void get() { } + + void release() + { + } + size_t proxyProtocolPayloadSize{0}; + uint16_t status_code{200}; }; #else /* HAVE_DNS_OVER_HTTPS */ @@ -223,7 +233,7 @@ struct DOHUnit } } - IDState ids; + InternalQueryState ids; std::string sni; std::string path; std::string scheme; @@ -261,7 +271,7 @@ struct DOHUnit void setHTTPResponse(uint16_t statusCode, PacketBuffer&& body, const std::string& contentType=""); }; -void handleUDPResponseForDoH(std::unique_ptr&&, PacketBuffer&& response, IDState&& state); +void handleUDPResponseForDoH(std::unique_ptr&&, PacketBuffer&& response, InternalQueryState&& state); #endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index 8fe793060d..d0cc62c638 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -42,6 +42,11 @@ bool DNSDistSNMPAgent::sendBackendStatusChangeTrap(DownstreamState const&) return false; } +bool assignOutgoingUDPQueryToBackend(std::shared_ptr& ds, uint16_t queryID, DNSQuestion& dq, PacketBuffer&& query, ComboAddress& dest) +{ + return false; +} + BOOST_AUTO_TEST_SUITE(test_dnsdist_cc) static const uint16_t ECSSourcePrefixV4 = 24; @@ -62,12 +67,12 @@ static void validateQuery(const PacketBuffer& packet, bool hasEdns=true, bool ha static void validateECS(const PacketBuffer& packet, const ComboAddress& expected) { - ComboAddress rem("::1"); - unsigned int consumed = 0; - uint16_t qtype; - uint16_t qclass; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &rem, const_cast(packet), dnsdist::Protocol::DoUDP, nullptr); + struct timespec queryTime; + InternalQueryState ids; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.origRemote = ComboAddress("::1"); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + DNSQuestion dq(ids, const_cast(packet), queryTime); BOOST_CHECK(parseEDNSOptions(dq)); BOOST_REQUIRE(dq.ednsOptions != nullptr); BOOST_CHECK_EQUAL(dq.ednsOptions->size(), 1U); @@ -100,8 +105,11 @@ BOOST_AUTO_TEST_CASE(test_addXPF) struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests - ComboAddress remote; DNSName name("www.powerdns.com."); + InternalQueryState ids; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.origRemote = ComboAddress("::1"); + ids.origDest = ComboAddress("::1"); PacketBuffer query; GenericDNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); @@ -112,13 +120,10 @@ BOOST_AUTO_TEST_CASE(test_addXPF) PacketBuffer packet = query; /* large enough packet */ - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + DNSQuestion dq(ids, const_cast(packet), queryTime); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); BOOST_CHECK(addXPF(dq, xpfOptionCode)); BOOST_CHECK(packet.size() > query.size()); @@ -131,13 +136,10 @@ BOOST_AUTO_TEST_CASE(test_addXPF) /* packet is already too large for the 4096 limit over UDP */ packet.resize(4096); - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + DNSQuestion dq(ids, const_cast(packet), queryTime); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); BOOST_REQUIRE(!addXPF(dq, xpfOptionCode)); BOOST_CHECK_EQUAL(packet.size(), 4096U); @@ -149,13 +151,10 @@ BOOST_AUTO_TEST_CASE(test_addXPF) PacketBuffer packet = query; /* packet with trailing data (overriding it) */ - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + DNSQuestion dq(ids, const_cast(packet), queryTime); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); /* add trailing data */ const size_t trailingDataSize = 10; @@ -323,9 +322,12 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSButWithAnswer) BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) { + InternalQueryState ids; + ids.origRemote = ComboAddress("192.0.2.1"); + ids.protocol = dnsdist::Protocol::DoUDP; + struct timespec queryTime; bool ednsAdded = false; bool ecsAdded = false; - ComboAddress remote("192.0.2.1"); DNSName name("www.powerdns.com."); PacketBuffer query; @@ -334,15 +336,12 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) auto packet = query; - unsigned int consumed = 0; - uint16_t qtype; - uint16_t qclass; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(qclass == QClass::IN); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); + BOOST_CHECK(ids.qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + DNSQuestion dq(ids, packet, queryTime); /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(!parseEDNSOptions(dq)); @@ -352,7 +351,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet); - validateECS(packet, remote); + validateECS(packet, ids.origRemote); /* trailing data */ packet = query; @@ -360,12 +359,12 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) ednsAdded = false; ecsAdded = false; - consumed = 0; - qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); + BOOST_CHECK(ids.qclass == QClass::IN); + DNSQuestion dq2(ids, packet, queryTime); BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded)); BOOST_CHECK_GT(packet.size(), query.size()); @@ -373,7 +372,7 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet); - validateECS(packet, remote); + validateECS(packet, ids.origRemote); } BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { @@ -423,9 +422,12 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { } BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { + InternalQueryState ids; + ids.origRemote = ComboAddress("2001:DB8::1"); + ids.protocol = dnsdist::Protocol::DoUDP; + struct timespec queryTime; bool ednsAdded = false; bool ecsAdded = false; - ComboAddress remote("2001:DB8::1"); DNSName name("www.powerdns.com."); PacketBuffer query; @@ -436,15 +438,12 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { auto packet = query; - unsigned int consumed = 0; - uint16_t qtype; - uint16_t qclass; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(qclass == QClass::IN); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); + BOOST_CHECK(ids.qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + DNSQuestion dq(ids, packet, queryTime); /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(parseEDNSOptions(dq)); @@ -454,19 +453,18 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet); - validateECS(packet, remote); + validateECS(packet, ids.origRemote); /* trailing data */ packet = query; packet.resize(2048); - consumed = 0; ednsAdded = false; ecsAdded = false; - qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, &ids.qclass); + BOOST_CHECK_EQUAL(ids.qname, name); + BOOST_CHECK(ids.qtype == QType::A); + BOOST_CHECK(ids.qclass == QClass::IN); + DNSQuestion dq2(ids, packet, queryTime); BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded)); BOOST_CHECK_GT(packet.size(), query.size()); @@ -474,7 +472,7 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); validateQuery(packet); - validateECS(packet, remote); + validateECS(packet, ids.origRemote); } BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { @@ -518,11 +516,15 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) { bool ednsAdded = false; bool ecsAdded = false; ComboAddress remote("192.168.1.25"); - DNSName name("www.powerdns.com."); ComboAddress origRemote("127.0.0.1"); + InternalQueryState ids; + ids.origRemote = remote; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.qname = DNSName("www.powerdns.com."); + struct timespec queryTime; PacketBuffer query; - GenericDNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pw(query, ids.qname, QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; EDNSSubnetOpts ecsOpts; ecsOpts.source = Netmask(origRemote, ECSSourcePrefixV4); @@ -538,11 +540,11 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - BOOST_CHECK_EQUAL(qname, name); + BOOST_CHECK_EQUAL(qname, ids.qname); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + DNSQuestion dq(ids, packet, queryTime); dq.ecsOverride = true; /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ @@ -1421,18 +1423,13 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) { validateResponse(newResponse, true, 1); } -static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& realTime, PacketBuffer& query) -{ - return DNSQuestion(&qname, qtype, qclass, &lc, &rem, query, dnsdist::Protocol::DoUDP, &realTime); -} - -static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, PacketBuffer& query, bool resizeBuffer=true) +static DNSQuestion turnIntoResponse(InternalQueryState& ids, PacketBuffer& query, struct timespec& queryRealTime, bool resizeBuffer=true) { if (resizeBuffer) { query.resize(4096); } - auto dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = DNSQuestion(ids, query, queryRealTime); BOOST_CHECK(addEDNSToQueryTurnedResponse(dq)); @@ -1441,11 +1438,17 @@ static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, static int getZ(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, PacketBuffer& query) { - ComboAddress lc("127.0.0.1"); - ComboAddress rem("127.0.0.1"); + InternalQueryState ids; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.qname = qname; + ids.qtype = qtype; + ids.qclass = qclass; + ids.origDest = ComboAddress("127.0.0.1"); + ids.origRemote = ComboAddress("127.0.0.1"); struct timespec queryRealTime; gettime(&queryRealTime, true); - DNSQuestion dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query); + + auto dq = DNSQuestion(ids, query, queryRealTime); return getEDNSZ(dq); } @@ -1547,12 +1550,14 @@ BOOST_AUTO_TEST_CASE(test_getEDNSZ) { } BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { - + InternalQueryState ids; + ids.qname = DNSName("www.powerdns.com."); + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.origDest = ComboAddress("127.0.0.1"); + ids.origRemote = ComboAddress("127.0.0.1"); uint16_t z; uint16_t udpPayloadSize; - DNSName qname("www.powerdns.com."); - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; EDNSSubnetOpts ecsOpts; ecsOpts.source = Netmask(ComboAddress("127.0.0.1"), ECSSourcePrefixV4); string origECSOptionStr = makeEDNSSubnetOptsString(ecsOpts); @@ -1561,20 +1566,18 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { GenericDNSPacketWriter::optvect_t opts; opts.emplace_back(EDNSOptionCode::COOKIE, cookiesOptionStr); opts.emplace_back(EDNSOptionCode::ECS, origECSOptionStr); - ComboAddress lc("127.0.0.1"); - ComboAddress rem("127.0.0.1"); struct timespec queryRealTime; gettime(&queryRealTime, true); { /* no EDNS */ PacketBuffer query; - GenericDNSPacketWriter pw(query, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(query, ids.qname, ids.qtype, ids.qclass, 0); pw.getHeader()->qr = 1; pw.getHeader()->rcode = RCode::NXDomain; pw.commit(); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = turnIntoResponse(ids, query, queryRealTime); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), false); BOOST_CHECK_EQUAL(z, 0); @@ -1584,12 +1587,12 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { { /* truncated EDNS */ PacketBuffer query; - GenericDNSPacketWriter pw(query, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(query, ids.qname, ids.qtype, ids.qclass, 0); pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO); pw.commit(); query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1)); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query, false); + auto dq = turnIntoResponse(ids, query, queryRealTime, false); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), false); BOOST_CHECK_EQUAL(z, 0); @@ -1599,11 +1602,11 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { { /* valid EDNS, no options, DO not set */ PacketBuffer query; - GenericDNSPacketWriter pw(query, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(query, ids.qname, ids.qtype, ids.qclass, 0); pw.addOpt(512, 0, 0); pw.commit(); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = turnIntoResponse(ids, query, queryRealTime); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, 0); @@ -1613,11 +1616,11 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { { /* valid EDNS, no options, DO set */ PacketBuffer query; - GenericDNSPacketWriter pw(query, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(query, ids.qname, ids.qtype, ids.qclass, 0); pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO); pw.commit(); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = turnIntoResponse(ids, query, queryRealTime); BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO); BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO); @@ -1627,11 +1630,11 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { { /* valid EDNS, options, DO not set */ PacketBuffer query; - GenericDNSPacketWriter pw(query, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(query, ids.qname, ids.qtype, ids.qclass, 0); pw.addOpt(512, 0, 0, opts); pw.commit(); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = turnIntoResponse(ids, query, queryRealTime); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, 0); @@ -1641,11 +1644,11 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { { /* valid EDNS, options, DO set */ PacketBuffer query; - GenericDNSPacketWriter pw(query, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(query, ids.qname, ids.qtype, ids.qclass, 0); pw.addOpt(512, 0, EDNS_HEADER_FLAG_DO, opts); pw.commit(); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = turnIntoResponse(ids, query, queryRealTime); BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO); BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO); @@ -1898,6 +1901,10 @@ BOOST_AUTO_TEST_CASE(test_isEDNSOptionInOpt) { } BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { + InternalQueryState ids; + ids.origRemote = ComboAddress("192.0.2.1"); + ids.protocol = dnsdist::Protocol::DoUDP; + struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests ComboAddress remote; @@ -1917,10 +1924,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* no incoming EDNS */ auto packet = query; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, false)); BOOST_CHECK(packet.size() > query.size()); @@ -1941,10 +1946,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* now with incoming EDNS */ auto packet = queryWithEDNS; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, false)); BOOST_CHECK(packet.size() > queryWithEDNS.size()); @@ -1969,10 +1972,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* no incoming EDNS */ auto packet = query; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, false)); BOOST_CHECK(packet.size() > query.size()); @@ -1993,10 +1994,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* now with incoming EDNS */ auto packet = queryWithEDNS; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, false)); BOOST_CHECK(packet.size() > queryWithEDNS.size()); @@ -2023,12 +2022,11 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* no incoming EDNS */ auto packet = query; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); - BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, true)); + BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , + 5, true)); BOOST_CHECK(packet.size() > query.size()); MOADNSParser mdp(true, reinterpret_cast(packet.data()), packet.size()); @@ -2047,10 +2045,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* now with incoming EDNS */ auto packet = queryWithEDNS; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, true)); BOOST_CHECK(packet.size() > queryWithEDNS.size()); @@ -2075,10 +2071,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* no incoming EDNS */ auto packet = query; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, true)); BOOST_CHECK(packet.size() > query.size()); @@ -2099,10 +2093,8 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* now with incoming EDNS */ auto packet = queryWithEDNS; - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); - DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, dnsdist::Protocol::DoUDP, &queryTime); + ids.qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &ids.qtype, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5, true)); BOOST_CHECK(packet.size() > queryWithEDNS.size()); @@ -2124,9 +2116,12 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { } BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { - const ComboAddress remote("192.168.1.25"); + InternalQueryState ids; + ids.origRemote = ComboAddress("192.168.1.25"); + ids.protocol = dnsdist::Protocol::DoUDP; + struct timespec queryTime; + const DNSName name("www.powerdns.com."); - const ComboAddress origRemote("127.0.0.1"); const ComboAddress v4("192.0.2.1"); { @@ -2143,7 +2138,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2164,7 +2159,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2185,7 +2180,7 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { uint16_t qtype; uint16_t qclass; DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, dnsdist::Protocol::DoUDP, nullptr); + DNSQuestion dq(ids, packet, queryTime); BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2193,12 +2188,14 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { BOOST_AUTO_TEST_CASE(test_setEDNSOption) { - DNSName qname("powerdns.com."); - uint16_t qtype = QType::A; - uint16_t qclass = QClass::IN; - ComboAddress lc("127.0.0.1:53"); - ComboAddress rem("192.0.2.1:42"); - auto proto = dnsdist::Protocol::DoUDP; + InternalQueryState ids; + ids.origRemote = ComboAddress("192.0.2.1:42"); + ids.origDest = ComboAddress("127.0.0.1:53"); + ids.protocol = dnsdist::Protocol::DoUDP; + ids.qname = DNSName("powerdns.com."); + ids.qtype = QType::A; + ids.qclass = QClass::IN; + struct timespec queryRealTime; gettime(&queryRealTime, true); struct timespec expiredTime; @@ -2206,11 +2203,11 @@ BOOST_AUTO_TEST_CASE(test_setEDNSOption) gettime(&expiredTime); PacketBuffer packet; - GenericDNSPacketWriter pw(packet, qname, qtype, qclass, 0); + GenericDNSPacketWriter pw(packet, ids.qname, ids.qtype, ids.qclass, 0); pw.addOpt(4096, 0, EDNS_HEADER_FLAG_DO); pw.commit(); - DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, proto, &queryRealTime); + DNSQuestion dq(ids, packet, queryRealTime); std::string result; EDNSCookiesOpt cookiesOpt("deadbeefdeadbeef"); @@ -2221,7 +2218,7 @@ BOOST_AUTO_TEST_CASE(test_setEDNSOption) const auto& data = dq.getData(); MOADNSParser mdp(true, reinterpret_cast(data.data()), data.size()); - BOOST_CHECK_EQUAL(mdp.d_qname.toString(), qname.toString()); + BOOST_CHECK_EQUAL(mdp.d_qname.toString(), ids.qname.toString()); BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1U); BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0U); BOOST_CHECK_EQUAL(mdp.d_header.nscount, 0U); diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index a72ce76424..7eea14ee19 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -24,15 +24,19 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests - size_t counter=0; - size_t skipped=0; - ComboAddress remote; + size_t counter = 0; + size_t skipped = 0; bool dnssecOK = false; const time_t now = time(nullptr); + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; + try { for (counter = 0; counter < 100000; ++counter) { - DNSName a=DNSName(std::to_string(counter))+DNSName(" hello"); - BOOST_CHECK_EQUAL(DNSName(a.toString()), a); + auto a = DNSName(std::to_string(counter))+DNSName(" hello"); + ids.qname = a; PacketBuffer query; GenericDNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); @@ -50,7 +54,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -75,16 +79,16 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { size_t deleted=0; size_t delcounter=0; for (delcounter=0; delcounter < counter/1000; ++delcounter) { - DNSName a=DNSName(std::to_string(delcounter))+DNSName(" hello"); + ids.qname = DNSName(std::to_string(delcounter))+DNSName(" hello"); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); if (found == true) { - auto removed = PC.expungeByName(a); + auto removed = PC.expungeByName(ids.qname); BOOST_CHECK_EQUAL(removed, 1U); deleted += removed; } @@ -94,13 +98,13 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { size_t matches=0; size_t expected=counter-skipped-deleted; for (; delcounter < counter; ++delcounter) { - DNSName a(DNSName(std::to_string(delcounter))+DNSName(" hello")); + ids.qname = DNSName(std::to_string(delcounter))+DNSName(" hello"); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); if (PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP)) { matches++; } @@ -137,34 +141,38 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) { ComboAddress remote; bool dnssecOK = false; const time_t now = time(nullptr); + InternalQueryState ids; + ids.qtype = QType::AAAA; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; try { for (counter = 0; counter < 100000; ++counter) { - DNSName a(std::to_string(counter) + ".powerdns.com."); + ids.qname = DNSName(std::to_string(counter) + ".powerdns.com."); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, a, QType::AAAA, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::AAAA, QClass::IN, 0); pwQ.getHeader()->rd = 1; PacketBuffer response; - GenericDNSPacketWriter pwR(response, a, QType::AAAA, QClass::IN, 0); + GenericDNSPacketWriter pwR(response, ids.qname, QType::AAAA, QClass::IN, 0); pwR.getHeader()->rd = 1; pwR.getHeader()->ra = 1; pwR.getHeader()->qr = 1; pwR.getHeader()->id = pwQ.getHeader()->id; - pwR.startRecord(a, QType::AAAA, 7200, QClass::IN, DNSResourceRecord::ANSWER); + pwR.startRecord(ids.qname, QType::AAAA, 7200, QClass::IN, DNSResourceRecord::ANSWER); ComboAddress v6("2001:db8::1"); pwR.xfrIP6(std::string(reinterpret_cast(v6.sin6.sin6_addr.s6_addr), 16)); pwR.commit(); uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::AAAA, QClass::IN, response, receivedOverUDP, 0, boost::none); found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true); if (found == true) { @@ -183,14 +191,14 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) { size_t matches = 0; for (counter = 0; counter < 100000; ++counter) { - DNSName a(std::to_string(counter) + ".powerdns.com."); + ids.qname = DNSName(std::to_string(counter) + ".powerdns.com."); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, a, QType::AAAA, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::AAAA, QClass::IN, 0); pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); if (PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP)) { matches++; } @@ -226,14 +234,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSharded) { BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { const size_t maxEntries = 150000; DNSDistPacketCache PC(maxEntries, 86400, 1); + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests ComboAddress remote; bool dnssecOK = false; try { - DNSName a = DNSName("tcp"); - BOOST_CHECK_EQUAL(DNSName(a.toString()), a); + DNSName a("tcp"); + ids.qname = a; PacketBuffer query; GenericDNSPacketWriter pwQ(query, a, QType::AAAA, QClass::IN, 0); @@ -254,7 +266,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { /* UDP */ uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -269,7 +281,8 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { /* same but over TCP */ uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoTCP, &queryTime); + ids.protocol = dnsdist::Protocol::DoTCP; + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, !receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -289,6 +302,10 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTCP) { BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { const size_t maxEntries = 150000; DNSDistPacketCache PC(maxEntries, 86400, 1); + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests @@ -296,7 +313,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { bool dnssecOK = false; try { DNSName a = DNSName("servfail"); - BOOST_CHECK_EQUAL(DNSName(a.toString()), a); + ids.qname = a; PacketBuffer query; GenericDNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); @@ -313,7 +330,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -340,13 +357,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { const size_t maxEntries = 150000; DNSDistPacketCache PC(maxEntries, /* maxTTL */ 86400, /* minTTL */ 1, /* tempFailureTTL */ 60, /* maxNegativeTTL */ 1); + ComboAddress remote; + bool dnssecOK = false; + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests - ComboAddress remote; - bool dnssecOK = false; try { DNSName name("nodata"); + ids.qname = name; PacketBuffer query; GenericDNSPacketWriter pwQ(query, name, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; @@ -366,7 +388,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -392,6 +414,10 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { const size_t maxEntries = 150000; DNSDistPacketCache PC(maxEntries, /* maxTTL */ 86400, /* minTTL */ 1, /* tempFailureTTL */ 60, /* maxNegativeTTL */ 1); + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests @@ -399,6 +425,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { bool dnssecOK = false; try { DNSName name("nxdomain"); + ids.qname = name; PacketBuffer query; GenericDNSPacketWriter pwQ(query, name, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; @@ -418,7 +445,7 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); @@ -447,17 +474,21 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTruncated) { struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests - ComboAddress remote; + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; + ids.sentTime.start(); // does not have to be accurate ("realTime") in tests bool dnssecOK = false; try { - DNSName name("truncated"); + ids.qname = DNSName("truncated"); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, name, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; PacketBuffer response; - GenericDNSPacketWriter pwR(response, name, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwR(response, ids.qname, QType::A, QClass::IN, 0); pwR.getHeader()->rd = 1; pwR.getHeader()->ra = 0; pwR.getHeader()->qr = 1; @@ -465,18 +496,18 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheTruncated) { pwR.getHeader()->rcode = RCode::NoError; pwR.getHeader()->id = pwQ.getHeader()->id; pwR.commit(); - pwR.startRecord(name, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); + pwR.startRecord(ids.qname, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); pwR.xfr32BitInt(0x01020304); pwR.commit(); uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, RCode::NXDomain, boost::none); bool allowTruncated = true; found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, receivedOverUDP, 0, true, allowTruncated); @@ -497,33 +528,38 @@ static DNSDistPacketCache g_PC(500000); static void threadMangler(unsigned int offset) { + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests + try { ComboAddress remote; bool dnssecOK = false; for(unsigned int counter=0; counter < 100000; ++counter) { - DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset)); + ids.qname = DNSName("hello ")+DNSName(std::to_string(counter+offset)); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; PacketBuffer response; - GenericDNSPacketWriter pwR(response, a, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwR(response, ids.qname, QType::A, QClass::IN, 0); pwR.getHeader()->rd = 1; pwR.getHeader()->ra = 1; pwR.getHeader()->qr = 1; pwR.getHeader()->id = pwQ.getHeader()->id; - pwR.startRecord(a, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER); + pwR.startRecord(ids.qname, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER); pwR.xfr32BitInt(0x01020304); pwR.commit(); uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); - g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); + g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, ids.qname, QType::A, QClass::IN, response, receivedOverUDP, 0, boost::none); } } catch(PDNSException& e) { @@ -536,21 +572,26 @@ AtomicCounter g_missing; static void threadReader(unsigned int offset) { - bool dnssecOK = false; + InternalQueryState ids; + ids.qtype = QType::A; + ids.qclass = QClass::IN; + ids.qname = DNSName("www.powerdns.com."); + ids.protocol = dnsdist::Protocol::DoUDP; struct timespec queryTime; gettime(&queryTime); // does not have to be accurate ("realTime") in tests + bool dnssecOK = false; try { ComboAddress remote; for(unsigned int counter=0; counter < 100000; ++counter) { - DNSName a=DNSName("hello ")+DNSName(std::to_string(counter+offset)); + ids.qname = DNSName("hello ")+DNSName(std::to_string(counter+offset)); PacketBuffer query; - GenericDNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + DNSQuestion dq(ids, query, queryTime); bool found = g_PC.get(dq, 0, &key, subnet, dnssecOK, receivedOverUDP); if (!found) { g_missing++; @@ -601,19 +642,24 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { DNSDistPacketCache PC(maxEntries, 86400, 1, 60, 3600, 60, false, 1, true, true); BOOST_CHECK_EQUAL(PC.getSize(), 0U); - DNSName qname("www.powerdns.com."); - uint16_t qtype = QType::AAAA; + InternalQueryState ids; + ids.qtype = QType::AAAA; + ids.qclass = QClass::IN; + ids.qname = DNSName("www.powerdns.com."); + ids.protocol = dnsdist::Protocol::DoUDP; uint16_t qid = 0x42; uint32_t key; uint32_t secondKey; boost::optional subnetOut; bool dnssecOK = false; + struct timespec queryTime; + gettime(&queryTime); // does not have to be accurate ("realTime") in tests /* lookup for a query with a first ECS value, insert a corresponding response */ { PacketBuffer query; - GenericDNSPacketWriter pwQ(query, qname, qtype, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, ids.qtype, QClass::IN, 0); pwQ.getHeader()->rd = 1; pwQ.getHeader()->id = qid; GenericDNSPacketWriter::optvect_t ednsOptions; @@ -624,26 +670,25 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { pwQ.commit(); ComboAddress remote("192.0.2.1"); - struct timespec queryTime; - gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + ids.sentTime.start(); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnetOut, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_REQUIRE(subnetOut); BOOST_CHECK_EQUAL(subnetOut->toString(), opt.source.toString()); PacketBuffer response; - GenericDNSPacketWriter pwR(response, qname, qtype, QClass::IN, 0); + GenericDNSPacketWriter pwR(response, ids.qname, ids.qtype, QClass::IN, 0); pwR.getHeader()->rd = 1; pwR.getHeader()->id = qid; - pwR.startRecord(qname, qtype, 100, QClass::IN, DNSResourceRecord::ANSWER); + pwR.startRecord(ids.qname, ids.qtype, 100, QClass::IN, DNSResourceRecord::ANSWER); ComboAddress v6("::1"); pwR.xfrCAWithoutPort(6, v6); pwR.commit(); pwR.addOpt(512, 0, 0, ednsOptions); pwR.commit(); - PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), dnssecOK, qname, qtype, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); + PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), dnssecOK, ids.qname, ids.qtype, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); BOOST_CHECK_EQUAL(PC.getSize(), 1U); found = PC.get(dq, 0, &key, subnetOut, dnssecOK, receivedOverUDP); @@ -656,7 +701,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { we should get the same key (collision) but no match */ { PacketBuffer query; - GenericDNSPacketWriter pwQ(query, qname, qtype, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, ids.qtype, QClass::IN, 0); pwQ.getHeader()->rd = 1; pwQ.getHeader()->id = qid; GenericDNSPacketWriter::optvect_t ednsOptions; @@ -667,9 +712,8 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { pwQ.commit(); ComboAddress remote("192.0.2.1"); - struct timespec queryTime; - gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + ids.sentTime.start(); + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &secondKey, subnetOut, dnssecOK, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK_EQUAL(secondKey, key); @@ -693,7 +737,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { for (size_t idxB = 0; idxB < 256; idxB++) { for (size_t idxC = 0; idxC < 256; idxC++) { PacketBuffer secondQuery; - GenericDNSPacketWriter pwFQ(secondQuery, qname, QType::AAAA, QClass::IN, 0); + GenericDNSPacketWriter pwFQ(secondQuery, ids.qname, QType::AAAA, QClass::IN, 0); pwFQ.getHeader()->rd = 1; pwFQ.getHeader()->qr = false; pwFQ.getHeader()->id = 0x42; @@ -702,7 +746,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { ednsOptions.emplace_back(EDNSOptionCode::ECS, makeEDNSSubnetOptsString(opt)); pwFQ.addOpt(512, 0, 0, ednsOptions); pwFQ.commit(); - secondKey = pc.getKey(qname.toDNSString(), qname.wirelength(), secondQuery, false); + secondKey = pc.getKey(ids.qname.toDNSString(), ids.qname.wirelength(), secondQuery, false); auto pair = colMap.emplace(secondKey, opt.source); total++; if (!pair.second) { @@ -725,42 +769,47 @@ BOOST_AUTO_TEST_CASE(test_PCDNSSECCollision) { DNSDistPacketCache PC(maxEntries, 86400, 1, 60, 3600, 60, false, 1, true, true); BOOST_CHECK_EQUAL(PC.getSize(), 0U); - DNSName qname("www.powerdns.com."); - uint16_t qtype = QType::AAAA; + InternalQueryState ids; + ids.qtype = QType::AAAA; + ids.qclass = QClass::IN; + ids.qname = DNSName("www.powerdns.com."); + ids.protocol = dnsdist::Protocol::DoUDP; uint16_t qid = 0x42; uint32_t key; boost::optional subnetOut; + struct timespec queryTime; + gettime(&queryTime); // does not have to be accurate ("realTime") in tests /* lookup for a query with DNSSEC OK, insert a corresponding response with DO set, check that it doesn't match without DO, but does with it */ { PacketBuffer query; - GenericDNSPacketWriter pwQ(query, qname, qtype, QClass::IN, 0); + GenericDNSPacketWriter pwQ(query, ids.qname, ids.qtype, QClass::IN, 0); pwQ.getHeader()->rd = 1; pwQ.getHeader()->id = qid; pwQ.addOpt(512, 0, EDNS_HEADER_FLAG_DO); pwQ.commit(); ComboAddress remote("192.0.2.1"); - struct timespec queryTime; - gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, dnsdist::Protocol::DoUDP, &queryTime); + ids.sentTime.start(); + ids.origRemote = remote; + DNSQuestion dq(ids, query, queryTime); bool found = PC.get(dq, 0, &key, subnetOut, true, receivedOverUDP); BOOST_CHECK_EQUAL(found, false); PacketBuffer response; - GenericDNSPacketWriter pwR(response, qname, qtype, QClass::IN, 0); + GenericDNSPacketWriter pwR(response, ids.qname, ids.qtype, QClass::IN, 0); pwR.getHeader()->rd = 1; pwR.getHeader()->id = qid; - pwR.startRecord(qname, qtype, 100, QClass::IN, DNSResourceRecord::ANSWER); + pwR.startRecord(ids.qname, ids.qtype, 100, QClass::IN, DNSResourceRecord::ANSWER); ComboAddress v6("::1"); pwR.xfrCAWithoutPort(6, v6); pwR.commit(); pwR.addOpt(512, 0, EDNS_HEADER_FLAG_DO); pwR.commit(); - PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), /* DNSSEC OK is set */ true, qname, qtype, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); + PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), /* DNSSEC OK is set */ true, ids.qname, ids.qtype, QClass::IN, response, receivedOverUDP, RCode::NoError, boost::none); BOOST_CHECK_EQUAL(PC.getSize(), 1U); found = PC.get(dq, 0, &key, subnetOut, false, receivedOverUDP); @@ -776,8 +825,6 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheInspection) { const size_t maxEntries = 100; DNSDistPacketCache PC(maxEntries, 86400, 1); BOOST_CHECK_EQUAL(PC.getSize(), 0U); - struct timespec queryTime; - gettime(&queryTime); // does not have to be accurate ("realTime") in tests ComboAddress remote; bool dnssecOK = false; -- 2.47.2