From: Remi Gacogne Date: Tue, 6 Oct 2020 15:54:43 +0000 (+0200) Subject: dnsdist: Use vectors instead of C arrays as buffers X-Git-Tag: rec-4.5.0-alpha1~19^2~12 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=341d2553b74c579df9d9843959f3ca6f5c3dc954;p=thirdparty%2Fpdns.git dnsdist: Use vectors instead of C arrays as buffers --- diff --git a/pdns/dnscrypt.cc b/pdns/dnscrypt.cc index 17eb99666f..dca863977f 100644 --- a/pdns/dnscrypt.cc +++ b/pdns/dnscrypt.cc @@ -384,29 +384,32 @@ void DNSCryptContext::removeInactiveCertificate(uint32_t serial) throw std::runtime_error("No inactive certificate found with this serial"); } -bool DNSCryptQuery::parsePlaintextQuery(const char * packet, uint16_t packetSize) +bool DNSCryptQuery::parsePlaintextQuery(const std::vector& packet) { assert(d_ctx != nullptr); - if (packetSize < sizeof(dnsheader)) { + if (packet.size() < sizeof(dnsheader)) { return false; } - const struct dnsheader * dh = reinterpret_cast(packet); + const struct dnsheader * dh = reinterpret_cast(packet.data()); if (dh->qr || ntohs(dh->qdcount) != 1 || dh->ancount != 0 || dh->nscount != 0 || dh->opcode != Opcode::Query) return false; - unsigned int consumed; + unsigned int qnameWireLength; uint16_t qtype, qclass; - DNSName qname(packet, packetSize, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - if ((packetSize - sizeof(dnsheader)) < (consumed + sizeof(qtype) + sizeof(qclass))) + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); + if ((packet.size() - sizeof(dnsheader)) < (qnameWireLength + sizeof(qtype) + sizeof(qclass))) { return false; + } - if (qtype != QType::TXT || qclass != QClass::IN) + if (qtype != QType::TXT || qclass != QClass::IN) { return false; + } - if (qname != d_ctx->getProviderName()) + if (qname != d_ctx->getProviderName()) { return false; + } d_qname = qname; d_id = dh->id; @@ -455,21 +458,21 @@ bool DNSCryptContext::magicMatchesAPublicKey(DNSCryptQuery& query, time_t now) return false; } -bool DNSCryptQuery::isEncryptedQuery(const char * packet, uint16_t packetSize, bool tcp, time_t now) +bool DNSCryptQuery::isEncryptedQuery(const std::vector& packet, bool tcp, time_t now) { assert(d_ctx != nullptr); d_encrypted = false; - if (packetSize < sizeof(DNSCryptQueryHeader)) { + if (packet.size() < sizeof(DNSCryptQueryHeader)) { return false; } - if (!tcp && packetSize < DNSCryptQuery::s_minUDPLength) { + if (!tcp && packet.size() < DNSCryptQuery::s_minUDPLength) { return false; } - const struct DNSCryptQueryHeader* header = reinterpret_cast(packet); + const struct DNSCryptQueryHeader* header = reinterpret_cast(packet.data()); d_header = *header; @@ -482,16 +485,15 @@ bool DNSCryptQuery::isEncryptedQuery(const char * packet, uint16_t packetSize, b return true; } -void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen) +void DNSCryptQuery::getDecrypted(bool tcp, std::vector& packet) { - assert(decryptedQueryLen != nullptr); assert(d_encrypted); assert(d_pair != nullptr); assert(d_valid == false); #ifdef DNSCRYPT_STRICT_PADDING_LENGTH - if (tcp && ((packetSize - sizeof(DNSCryptQueryHeader)) % DNSCRYPT_PADDED_BLOCK_SIZE) != 0) { - vinfolog("Dropping encrypted query with invalid size of %d (should be a multiple of %d)", (packetSize - sizeof(DNSCryptQueryHeader)), DNSCRYPT_PADDED_BLOCK_SIZE); + if (tcp && ((packet.size() - sizeof(DNSCryptQueryHeader)) % DNSCRYPT_PADDED_BLOCK_SIZE) != 0) { + vinfolog("Dropping encrypted query with invalid size of %d (should be a multiple of %d)", (packet.size() - sizeof(DNSCryptQueryHeader)), DNSCRYPT_PADDED_BLOCK_SIZE); return; } #endif @@ -514,17 +516,17 @@ void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, ui const DNSCryptExchangeVersion version = getVersion(); if (version == DNSCryptExchangeVersion::VERSION1) { - res = crypto_box_open_easy_afternm(reinterpret_cast(packet), - reinterpret_cast(packet + sizeof(DNSCryptQueryHeader)), - packetSize - sizeof(DNSCryptQueryHeader), + res = crypto_box_open_easy_afternm(reinterpret_cast(packet.data()), + reinterpret_cast(&packet.at(sizeof(DNSCryptQueryHeader))), + packet.size() - sizeof(DNSCryptQueryHeader), nonce, d_sharedKey); } else if (version == DNSCryptExchangeVersion::VERSION2) { #ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY - res = crypto_box_curve25519xchacha20poly1305_open_easy_afternm(reinterpret_cast(packet), - reinterpret_cast(packet + sizeof(DNSCryptQueryHeader)), - packetSize - sizeof(DNSCryptQueryHeader), + res = crypto_box_curve25519xchacha20poly1305_open_easy_afternm(reinterpret_cast(packet.data()), + reinterpret_cast(&packet.at(sizeof(DNSCryptQueryHeader))), + packet.size() - sizeof(DNSCryptQueryHeader), nonce, d_sharedKey); #else /* HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY */ @@ -535,9 +537,9 @@ void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, ui } #else /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ - int res = crypto_box_open_easy(reinterpret_cast(packet), - reinterpret_cast(packet + sizeof(DNSCryptQueryHeader)), - packetSize - sizeof(DNSCryptQueryHeader), + int res = crypto_box_open_easy(reinterpret_cast(packet.data()), + reinterpret_cast(&packet.at(sizeof(DNSCryptQueryHeader))), + packet.size() - sizeof(DNSCryptQueryHeader), nonce, d_header.clientPK, d_pair->privateKey.key); @@ -548,22 +550,22 @@ void DNSCryptQuery::getDecrypted(bool tcp, char* packet, uint16_t packetSize, ui return; } - *decryptedQueryLen = packetSize - sizeof(DNSCryptQueryHeader) - DNSCRYPT_MAC_SIZE; - uint16_t pos = *decryptedQueryLen; - assert(pos < packetSize); - d_paddedLen = *decryptedQueryLen; + uint16_t decryptedQueryLen = packet.size() - sizeof(DNSCryptQueryHeader) - DNSCRYPT_MAC_SIZE; + uint16_t pos = decryptedQueryLen; + assert(pos < packet.size()); + d_paddedLen = decryptedQueryLen; - while(pos > 0 && packet[pos - 1] == 0) pos--; + while (pos > 0 && packet.at(pos - 1) == 0) pos--; - if (pos == 0 || static_cast(packet[pos - 1]) != 0x80) { + if (pos == 0 || packet.at(pos - 1) != 0x80) { vinfolog("Dropping encrypted query with invalid padding value"); return; } pos--; - size_t paddingLen = *decryptedQueryLen - pos; - *decryptedQueryLen = pos; + size_t paddingLen = decryptedQueryLen - pos; + packet.resize(pos); if (tcp && paddingLen > DNSCRYPT_MAX_TCP_PADDING_SIZE) { vinfolog("Dropping encrypted query with too long padding size"); @@ -580,19 +582,16 @@ void DNSCryptQuery::getCertificateResponse(time_t now, std::vector& res d_ctx->getCertificateResponse(now, d_qname, d_id, response); } -void DNSCryptQuery::parsePacket(char* packet, uint16_t packetSize, bool tcp, uint16_t* decryptedQueryLen, time_t now) +void DNSCryptQuery::parsePacket(std::vector& packet, bool tcp, time_t now) { - assert(packet != nullptr); - assert(decryptedQueryLen != nullptr); - d_valid = false; /* might be a plaintext certificate request or an authenticated request */ - if (isEncryptedQuery(packet, packetSize, tcp, now)) { - getDecrypted(tcp, packet, packetSize, decryptedQueryLen); + if (isEncryptedQuery(packet, tcp, now)) { + getDecrypted(tcp, packet); } else { - parsePlaintextQuery(packet, packetSize); + parsePlaintextQuery(packet); } } @@ -636,65 +635,69 @@ uint16_t DNSCryptQuery::computePaddingSize(uint16_t unpaddedLen, size_t maxLen) return result; } -int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, bool tcp, uint16_t* encryptedResponseLen) +int DNSCryptQuery::encryptResponse(std::vector& response, size_t maxResponseSize, bool tcp) { struct DNSCryptResponseHeader responseHeader; - assert(response != nullptr); - assert(responseLen > 0); - assert(responseSize >= responseLen); - assert(encryptedResponseLen != nullptr); + assert(response.size() > 0); + assert(maxResponseSize >= response.size()); assert(d_encrypted == true); assert(d_pair != nullptr); - if (!tcp && d_paddedLen < responseLen) { - struct dnsheader* dh = reinterpret_cast(response); + /* a DNSCrypt UDP response can't be larger than the (padded) DNSCrypt query */ + if (!tcp && d_paddedLen < response.size()) { + /* so we need to truncate it */ size_t questionSize = 0; - if (responseLen > sizeof(dnsheader)) { - unsigned int consumed = 0; - DNSName tempQName(response, responseLen, sizeof(dnsheader), false, 0, 0, &consumed); - if (consumed > 0) { - questionSize = consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE; + if (response.size() > sizeof(dnsheader)) { + unsigned int qnameWireLength = 0; + DNSName tempQName(reinterpret_cast(response.data()), response.size(), sizeof(dnsheader), false, 0, 0, &qnameWireLength); + if (qnameWireLength > 0) { + questionSize = qnameWireLength + DNS_TYPE_SIZE + DNS_CLASS_SIZE; } } - responseLen = sizeof(dnsheader) + questionSize; + response.resize(sizeof(dnsheader) + questionSize); - if (responseLen > d_paddedLen) { - responseLen = d_paddedLen; + if (response.size() > d_paddedLen) { + /* that does not seem right but let's truncate even more */ + response.resize(d_paddedLen); } + struct dnsheader* dh = reinterpret_cast(response.data()); dh->ancount = dh->arcount = dh->nscount = 0; dh->tc = 1; } - size_t requiredSize = sizeof(responseHeader) + DNSCRYPT_MAC_SIZE + responseLen; - size_t maxSize = (responseSize > (requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE)) ? (requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE) : responseSize; + size_t requiredSize = sizeof(responseHeader) + DNSCRYPT_MAC_SIZE + response.size(); + size_t maxSize = std::min(maxResponseSize, requiredSize + DNSCRYPT_MAX_RESPONSE_PADDING_SIZE); uint16_t paddingSize = computePaddingSize(requiredSize, maxSize); requiredSize += paddingSize; - if (requiredSize > responseSize) + if (requiredSize > maxResponseSize) { return ENOBUFS; + } memcpy(&responseHeader.nonce, &d_header.clientNonce, sizeof d_header.clientNonce); fillServerNonce(&(responseHeader.nonce[sizeof(d_header.clientNonce)])); + size_t responseLen = response.size(); /* moving the existing response after the header + MAC */ - memmove(response + sizeof(responseHeader) + DNSCRYPT_MAC_SIZE, response, responseLen); + response.resize(requiredSize); + std::copy_backward(response.begin(), response.begin() + responseLen, response.begin() + responseLen + sizeof(responseHeader) + DNSCRYPT_MAC_SIZE); uint16_t pos = 0; /* copying header */ - memcpy(response + pos, &responseHeader, sizeof(responseHeader)); + memcpy(&response.at(pos), &responseHeader, sizeof(responseHeader)); pos += sizeof(responseHeader); /* setting MAC bytes to 0 */ - memset(response + pos, 0, DNSCRYPT_MAC_SIZE); + memset(&response.at(pos), 0, DNSCRYPT_MAC_SIZE); pos += DNSCRYPT_MAC_SIZE; uint16_t toEncryptPos = pos; /* skipping response */ pos += responseLen; /* padding */ - response[pos] = static_cast(0x80); + response.at(pos) = static_cast(0x80); pos++; - memset(response + pos, 0, paddingSize - 1); + memset(&response.at(pos), 0, paddingSize - 1); pos += (paddingSize - 1); /* encrypting */ @@ -707,16 +710,16 @@ int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_ const DNSCryptExchangeVersion version = getVersion(); if (version == DNSCryptExchangeVersion::VERSION1) { - res = crypto_box_easy_afternm(reinterpret_cast(response + sizeof(responseHeader)), - reinterpret_cast(response + toEncryptPos), + res = crypto_box_easy_afternm(reinterpret_cast(&response.at(sizeof(responseHeader))), + reinterpret_cast(&response.at(toEncryptPos)), responseLen + paddingSize, responseHeader.nonce, d_sharedKey); } else if (version == DNSCryptExchangeVersion::VERSION2) { #ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY - res = crypto_box_curve25519xchacha20poly1305_easy_afternm(reinterpret_cast(response + sizeof(responseHeader)), - reinterpret_cast(response + toEncryptPos), + res = crypto_box_curve25519xchacha20poly1305_easy_afternm(reinterpret_cast(&response.at(sizeof(responseHeader))), + reinterpret_cast(&response.at(toEncryptPos)), responseLen + paddingSize, responseHeader.nonce, d_sharedKey); @@ -728,8 +731,8 @@ int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_ res = -1; } #else - int res = crypto_box_easy(reinterpret_cast(response + sizeof(responseHeader)), - reinterpret_cast(response + toEncryptPos), + int res = crypto_box_easy(reinterpret_cast(&response.at(sizeof(responseHeader))), + reinterpret_cast(&response.at(toEncryptPos)), responseLen + paddingSize, responseHeader.nonce, d_header.clientPK, @@ -738,20 +741,17 @@ int DNSCryptQuery::encryptResponse(char* response, uint16_t responseLen, uint16_ if (res == 0) { assert(pos == requiredSize); - *encryptedResponseLen = requiredSize; } return res; } -int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr& cert) const +int DNSCryptContext::encryptQuery(std::vector& packet, size_t maximumSize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, const std::shared_ptr& cert) const { - assert(query != nullptr); - assert(queryLen > 0); - assert(querySize >= queryLen); - assert(encryptedResponseLen != nullptr); + assert(packet.size() > 0); assert(cert != nullptr); + size_t queryLen = packet.size(); unsigned char nonce[DNSCRYPT_NONCE_SIZE]; size_t requiredSize = sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE + queryLen; /* this is not optimal, we should compute a random padding size, multiple of DNSCRYPT_PADDED_BLOCK_SIZE, @@ -764,37 +764,39 @@ int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query requiredSize = DNSCryptQuery::s_minUDPLength; } - if (requiredSize > querySize) + if (requiredSize > maximumSize) { return ENOBUFS; + } /* moving the existing query after the header + MAC */ - memmove(query + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE, query, queryLen); + packet.resize(requiredSize); + std::copy_backward(packet.begin(), packet.begin() + queryLen, packet.begin() + queryLen + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE); size_t pos = 0; /* client magic */ - memcpy(query + pos, cert->signedData.clientMagic, sizeof(cert->signedData.clientMagic)); + memcpy(&packet.at(pos), cert->signedData.clientMagic, sizeof(cert->signedData.clientMagic)); pos += sizeof(cert->signedData.clientMagic); /* client PK */ - memcpy(query + pos, clientPublicKey, DNSCRYPT_PUBLIC_KEY_SIZE); + memcpy(&packet.at(pos), clientPublicKey, DNSCRYPT_PUBLIC_KEY_SIZE); pos += DNSCRYPT_PUBLIC_KEY_SIZE; /* client nonce */ - memcpy(query + pos, clientNonce, DNSCRYPT_NONCE_SIZE / 2); + memcpy(&packet.at(pos), clientNonce, DNSCRYPT_NONCE_SIZE / 2); pos += DNSCRYPT_NONCE_SIZE / 2; size_t encryptedPos = pos; /* clear the MAC bytes */ - memset(query + pos, 0, DNSCRYPT_MAC_SIZE); + memset(&packet.at(pos), 0, DNSCRYPT_MAC_SIZE); pos += DNSCRYPT_MAC_SIZE; /* skipping data */ pos += queryLen; /* padding */ - query[pos] = static_cast(0x80); + packet.at(pos) = static_cast(0x80); pos++; - memset(query + pos, 0, paddingSize - 1); + memset(&packet.at(pos), 0, paddingSize - 1); pos += paddingSize - 1; memcpy(nonce, clientNonce, DNSCRYPT_NONCE_SIZE / 2); @@ -804,8 +806,8 @@ int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query int res = -1; if (version == DNSCryptExchangeVersion::VERSION1) { - res = crypto_box_easy(reinterpret_cast(query + encryptedPos), - reinterpret_cast(query + encryptedPos + DNSCRYPT_MAC_SIZE), + res = crypto_box_easy(reinterpret_cast(&packet.at(encryptedPos)), + reinterpret_cast(&packet.at(encryptedPos + DNSCRYPT_MAC_SIZE)), queryLen + paddingSize, nonce, cert->signedData.resolverPK, @@ -813,8 +815,8 @@ int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query } else if (version == DNSCryptExchangeVersion::VERSION2) { #ifdef HAVE_CRYPTO_BOX_CURVE25519XCHACHA20POLY1305_EASY - res = crypto_box_curve25519xchacha20poly1305_easy(reinterpret_cast(query + encryptedPos), - reinterpret_cast(query + encryptedPos + DNSCRYPT_MAC_SIZE), + res = crypto_box_curve25519xchacha20poly1305_easy(reinterpret_cast(&packet.at(encryptedPos)), + reinterpret_cast(&packet.at(encryptedPos + DNSCRYPT_MAC_SIZE)), queryLen + paddingSize, nonce, cert->signedData.resolverPK, @@ -827,7 +829,6 @@ int DNSCryptContext::encryptQuery(char* query, uint16_t queryLen, uint16_t query if (res == 0) { assert(pos == requiredSize); - *encryptedResponseLen = requiredSize; } return res; diff --git a/pdns/dnscrypt.hh b/pdns/dnscrypt.hh index a010ebdc5a..018e90025a 100644 --- a/pdns/dnscrypt.hh +++ b/pdns/dnscrypt.hh @@ -207,10 +207,10 @@ public: d_pair = pair; } - void parsePacket(char* packet, uint16_t packetSize, bool tcp, uint16_t* decryptedQueryLen, time_t now); - void getDecrypted(bool tcp, char* packet, uint16_t packetSize, uint16_t* decryptedQueryLen); + void parsePacket(std::vector& packet, bool tcp, time_t now); + void getDecrypted(bool tcp, std::vector& packet); void getCertificateResponse(time_t now, std::vector& response) const; - int encryptResponse(char* response, uint16_t responseLen, uint16_t responseSize, bool tcp, uint16_t* encryptedResponseLen); + int encryptResponse(std::vector& response, size_t maxResponseSize, bool tcp); static const size_t s_minUDPLength = 256; @@ -221,8 +221,8 @@ private: #endif /* HAVE_CRYPTO_BOX_EASY_AFTERNM */ void fillServerNonce(unsigned char* dest) const; uint16_t computePaddingSize(uint16_t unpaddedLen, size_t maxLen) const; - bool parsePlaintextQuery(const char * packet, uint16_t packetSize); - bool isEncryptedQuery(const char * packet, uint16_t packetSize, bool tcp, time_t now); + bool parsePlaintextQuery(const std::vector& packet); + bool isEncryptedQuery(const std::vector& packet, bool tcp, time_t now); DNSCryptQueryHeader d_header; #ifdef HAVE_CRYPTO_BOX_EASY_AFTERNM @@ -275,7 +275,7 @@ public: std::vector> getCertificates() { return d_certs; }; const DNSName& getProviderName() const { return providerName; } - int encryptQuery(char* query, uint16_t queryLen, uint16_t querySize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, uint16_t* encryptedResponseLen, const std::shared_ptr& cert) const; + int encryptQuery(std::vector& query, size_t maximumSize, const unsigned char clientPublicKey[DNSCRYPT_PUBLIC_KEY_SIZE], const DNSCryptPrivateKey& clientPrivateKey, const unsigned char clientNonce[DNSCRYPT_NONCE_SIZE / 2], bool tcp, const std::shared_ptr& cert) const; bool magicMatchesAPublicKey(DNSCryptQuery& query, time_t now); void getCertificateResponse(time_t now, const DNSName& qname, uint16_t qid, std::vector& response); diff --git a/pdns/dnsdist-cache.cc b/pdns/dnsdist-cache.cc index ddf00bfa94..f893a0621a 100644 --- a/pdns/dnsdist-cache.cc +++ b/pdns/dnsdist-cache.cc @@ -53,23 +53,23 @@ DNSDistPacketCache::~DNSDistPacketCache() } } -bool DNSDistPacketCache::getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional& subnet) +bool DNSDistPacketCache::getClientSubnet(const std::vector& packet, size_t qnameWireLength, boost::optional& subnet) { uint16_t optRDPosition; size_t remaining = 0; - int res = getEDNSOptionsStart(const_cast(packet), consumed, len, &optRDPosition, &remaining); + int res = getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining); if (res == 0) { - char * ecsOptionStart = nullptr; + size_t ecsOptionStartPosition = 0; size_t ecsOptionSize = 0; - res = getEDNSOption(const_cast(packet) + optRDPosition, remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize); + res = getEDNSOption(reinterpret_cast(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize); if (res == 0 && ecsOptionSize > (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE)) { EDNSSubnetOpts eso; - if (getEDNSSubnetOptsFromString(ecsOptionStart + (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), ecsOptionSize - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), &eso) == true) { + if (getEDNSSubnetOptsFromString(reinterpret_cast(&packet.at(optRDPosition + ecsOptionStartPosition + (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE))), ecsOptionSize - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), &eso) == true) { subnet = eso.source; return true; } @@ -127,9 +127,9 @@ void DNSDistPacketCache::insertLocked(CacheShard& shard, uint32_t key, CacheValu value = newValue; } -void DNSDistPacketCache::insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL) +void DNSDistPacketCache::insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const std::vector& response, bool tcp, uint8_t rcode, boost::optional tempFailureTTL) { - if (responseLen < sizeof(dnsheader)) { + if (response.size() < sizeof(dnsheader)) { return; } @@ -143,7 +143,7 @@ void DNSDistPacketCache::insert(uint32_t key, const boost::optional& su } else { bool seenAuthSOA = false; - minTTL = getMinTTL(response, responseLen, &seenAuthSOA); + minTTL = getMinTTL(reinterpret_cast(response.data()), response.size(), &seenAuthSOA); /* no TTL found, we don't want to cache this */ if (minTTL == std::numeric_limits::max()) { @@ -176,12 +176,12 @@ void DNSDistPacketCache::insert(uint32_t key, const boost::optional& su newValue.qtype = qtype; newValue.qclass = qclass; newValue.queryFlags = queryFlags; - newValue.len = responseLen; + newValue.len = response.size(); newValue.validity = newValidity; newValue.added = now; newValue.tcp = tcp; newValue.dnssecOK = dnssecOK; - newValue.value = std::string(response, responseLen); + newValue.value = std::string(response.begin(), response.end()); newValue.subnet = subnet; auto& shard = d_shards.at(shardIndex); @@ -202,23 +202,24 @@ void DNSDistPacketCache::insert(uint32_t key, const boost::optional& su } } -bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional& subnet, bool dnssecOK, uint32_t allowExpired, bool skipAging) +bool DNSDistPacketCache::get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut, boost::optional& subnet, bool dnssecOK, uint32_t allowExpired, bool skipAging) { const auto& dnsQName = dq.qname->getStorage(); - uint32_t key = getKey(dnsQName, consumed, reinterpret_cast(dq.dh), dq.len, dq.tcp); + uint32_t key = getKey(dnsQName, dq.qname->wirelength(), dq.getData(), dq.tcp); if (keyOut) { *keyOut = key; } if (d_parseECS) { - getClientSubnet(reinterpret_cast(dq.dh), consumed, dq.len, subnet); + getClientSubnet(dq.getData(), dq.qname->wirelength(), subnet); } uint32_t shardIndex = getShardIndex(key); time_t now = time(nullptr); time_t age; bool stale = false; + auto& response = dq.getMutableData(); auto& shard = d_shards.at(shardIndex); auto& map = shard.d_map; { @@ -245,22 +246,22 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t } } - if (*responseLen < value.len || value.len < sizeof(dnsheader)) { + if (value.len < sizeof(dnsheader)) { return false; } /* check for collision */ - if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.dh)), *dq.qname, dq.qtype, dq.qclass, dq.tcp, dnssecOK, subnet)) { + if (!cachedValueMatches(value, *(getFlagsFromDNSHeader(dq.getHeader())), *dq.qname, dq.qtype, dq.qclass, dq.tcp, dnssecOK, subnet)) { d_lookupCollisions++; return false; } - memcpy(response, &queryId, sizeof(queryId)); - memcpy(response + sizeof(queryId), value.value.c_str() + sizeof(queryId), sizeof(dnsheader) - sizeof(queryId)); + response.resize(value.len); + memcpy(&response.at(0), &queryId, sizeof(queryId)); + memcpy(&response.at(sizeof(queryId)), &value.value.at(sizeof(queryId)), sizeof(dnsheader) - sizeof(queryId)); if (value.len == sizeof(dnsheader)) { /* DNS header only, our work here is done */ - *responseLen = value.len; d_hits++; return true; } @@ -270,11 +271,11 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t return false; } - memcpy(response + sizeof(dnsheader), dnsQName.c_str(), dnsQNameLen); + memcpy(&response.at(sizeof(dnsheader)), dnsQName.c_str(), dnsQNameLen); if (value.len > (sizeof(dnsheader) + dnsQNameLen)) { - memcpy(response + sizeof(dnsheader) + dnsQNameLen, value.value.c_str() + sizeof(dnsheader) + dnsQNameLen, value.len - (sizeof(dnsheader) + dnsQNameLen)); + memcpy(&response.at(sizeof(dnsheader) + dnsQNameLen), &value.value.at(sizeof(dnsheader) + dnsQNameLen), value.len - (sizeof(dnsheader) + dnsQNameLen)); } - *responseLen = value.len; + if (!stale) { age = now - value.added; } @@ -284,7 +285,7 @@ bool DNSDistPacketCache::get(const DNSQuestion& dq, uint16_t consumed, uint16_t } if (!d_dontAge && !skipAging) { - ageDNSPacket(response, *responseLen, age); + ageDNSPacket(reinterpret_cast(&response[0]), response.size(), age); } d_hits++; @@ -413,26 +414,26 @@ uint32_t DNSDistPacketCache::getMinTTL(const char* packet, uint16_t length, bool return getDNSPacketMinTTL(packet, length, seenNoDataSOA); } -uint32_t DNSDistPacketCache::getKey(const DNSName::string_t& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp) +uint32_t DNSDistPacketCache::getKey(const DNSName::string_t& qname, size_t qnameWireLength, const std::vector& packet, bool tcp) { uint32_t result = 0; /* skip the query ID */ - if (packetLen < sizeof(dnsheader)) { - throw std::range_error("Computing packet cache key for an invalid packet size (" + std::to_string(packetLen) +")"); + if (packet.size() < sizeof(dnsheader)) { + throw std::range_error("Computing packet cache key for an invalid packet size (" + std::to_string(packet.size()) +")"); } - result = burtle(packet + 2, sizeof(dnsheader) - 2, result); + result = burtle(&packet.at(2), sizeof(dnsheader) - 2, result); result = burtleCI((const unsigned char*) qname.c_str(), qname.length(), result); - if (packetLen < sizeof(dnsheader) + consumed) { - throw std::range_error("Computing packet cache key for an invalid packet (" + std::to_string(packetLen) + " < " + std::to_string(sizeof(dnsheader) + consumed) + ")"); + if (packet.size() < sizeof(dnsheader) + qnameWireLength) { + throw std::range_error("Computing packet cache key for an invalid packet (" + std::to_string(packet.size()) + " < " + std::to_string(sizeof(dnsheader) + qnameWireLength) + ")"); } - if (packetLen > ((sizeof(dnsheader) + consumed))) { + if (packet.size() > ((sizeof(dnsheader) + qnameWireLength))) { if (!d_cookieHashing) { /* skip EDNS Cookie options if any */ - result = PacketCache::hashAfterQname(pdns_string_view(reinterpret_cast(packet), packetLen), result, sizeof(dnsheader) + consumed, false); + result = PacketCache::hashAfterQname(pdns_string_view(reinterpret_cast(packet.data()), packet.size()), result, sizeof(dnsheader) + qnameWireLength, false); } else { - result = burtle(packet + sizeof(dnsheader) + consumed, packetLen - (sizeof(dnsheader) + consumed), result); + result = burtle(&packet.at(sizeof(dnsheader) + qnameWireLength), packet.size() - (sizeof(dnsheader) + qnameWireLength), result); } } result = burtle((const unsigned char*) &tcp, sizeof(tcp), result); diff --git a/pdns/dnsdist-cache.hh b/pdns/dnsdist-cache.hh index 89ccbfec91..c899a93df5 100644 --- a/pdns/dnsdist-cache.hh +++ b/pdns/dnsdist-cache.hh @@ -35,8 +35,8 @@ public: DNSDistPacketCache(size_t maxEntries, uint32_t maxTTL=86400, uint32_t minTTL=0, uint32_t tempFailureTTL=60, uint32_t maxNegativeTTL=3600, uint32_t staleTTL=60, bool dontAge=false, uint32_t shards=1, bool deferrableInsertLock=true, bool parseECS=false); ~DNSDistPacketCache(); - void insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const char* response, uint16_t responseLen, bool tcp, uint8_t rcode, boost::optional tempFailureTTL); - bool get(const DNSQuestion& dq, uint16_t consumed, uint16_t queryId, char* response, uint16_t* responseLen, uint32_t* keyOut, boost::optional& subnetOut, bool dnssecOK, uint32_t allowExpired=0, bool skipAging=false); + void insert(uint32_t key, const boost::optional& subnet, uint16_t queryFlags, bool dnssecOK, const DNSName& qname, uint16_t qtype, uint16_t qclass, const std::vector& response, bool tcp, uint8_t rcode, boost::optional tempFailureTTL); + bool get(DNSQuestion& dq, uint16_t queryId, uint32_t* keyOut, boost::optional& subnet, bool dnssecOK, uint32_t allowExpired = 0, bool skipAging = false); size_t purgeExpired(size_t upTo=0); size_t expunge(size_t upTo=0); size_t expungeByName(const DNSName& name, uint16_t qtype=QType::ANY, bool suffixMatch=false); @@ -76,10 +76,10 @@ public: d_parseECS = enabled; } - uint32_t getKey(const DNSName::string_t& qname, uint16_t consumed, const unsigned char* packet, uint16_t packetLen, bool tcp); + uint32_t getKey(const DNSName::string_t& qname, size_t qnameWireLength, const std::vector& packet, bool tcp); static uint32_t getMinTTL(const char* packet, uint16_t length, bool* seenNoDataSOA); - static bool getClientSubnet(const char* packet, unsigned int consumed, uint16_t len, boost::optional& subnet); + static bool getClientSubnet(const std::vector& packet, size_t qnameWireLength, boost::optional& subnet); private: diff --git a/pdns/dnsdist-dnscrypt.cc b/pdns/dnsdist-dnscrypt.cc index 448efcd6ea..ac254fe9e5 100644 --- a/pdns/dnsdist-dnscrypt.cc +++ b/pdns/dnsdist-dnscrypt.cc @@ -24,9 +24,9 @@ #include "dnscrypt.hh" #ifdef HAVE_DNSCRYPT -int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector& response) +int handleDNSCryptQuery(std::vector& packet, std::shared_ptr& query, bool tcp, time_t now, std::vector& response) { - query->parsePacket(packet, len, tcp, decryptedQueryLen, now); + query->parsePacket(packet, tcp, now); if (query->isValid() == false) { vinfolog("Dropping DNSCrypt invalid query"); @@ -39,7 +39,7 @@ int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr(sizeof(struct dnsheader))) { + if (packet.size() < static_cast(sizeof(struct dnsheader))) { ++g_stats.nonCompliantQueries; return false; } diff --git a/pdns/dnsdist-ecs.cc b/pdns/dnsdist-ecs.cc index 000ad409c4..7a3a28c85c 100644 --- a/pdns/dnsdist-ecs.cc +++ b/pdns/dnsdist-ecs.cc @@ -261,12 +261,15 @@ static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, v return true; } -static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::shared_ptr >& options) +static bool slowParseEDNSOptions(const std::vector& packet, std::shared_ptr >& options) { - const struct dnsheader* dh = reinterpret_cast(packet); + if (packet.size() < sizeof(dnsheader)) { + return false; + } + + const struct dnsheader* dh = reinterpret_cast(packet.data()); - if (len < sizeof(dnsheader) || ntohs(dh->qdcount) == 0) - { + if (ntohs(dh->qdcount) == 0) { return false; } @@ -276,7 +279,7 @@ static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::sh try { uint64_t numrecords = ntohs(dh->ancount) + ntohs(dh->nscount) + ntohs(dh->arcount); - DNSPacketMangler dpm(const_cast(packet), len); + DNSPacketMangler dpm(const_cast(reinterpret_cast(&packet.at(0))), packet.size()); uint64_t n; for(n=0; n < ntohs(dh->qdcount) ; ++n) { dpm.skipDomainName(); @@ -294,12 +297,12 @@ static bool slowParseEDNSOptions(const char* packet, uint16_t const len, std::sh if(section == 3 && dnstype == QType::OPT) { uint32_t offset = dpm.getOffset(); - if (offset >= len) { + if (offset >= packet.size()) { return false; } /* if we survive this call, we can parse it safely */ dpm.skipRData(); - return getEDNSOptions(packet + offset, len - offset, *options) == 0; + return getEDNSOptions(reinterpret_cast(&packet.at(offset)), packet.size() - offset, *options) == 0; } else { dpm.skipRData(); @@ -380,14 +383,13 @@ int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * opt } /* extract the start of the OPT RR in a QUERY packet if any */ -int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t len, uint16_t* optRDPosition, size_t * remaining) +int getEDNSOptionsStart(const std::vector& packet, const size_t offset, uint16_t* optRDPosition, size_t* remaining) { - assert(packet != nullptr); assert(optRDPosition != nullptr); assert(remaining != nullptr); - const struct dnsheader* dh = reinterpret_cast(packet); + const struct dnsheader* dh = reinterpret_cast(packet.data()); - if (offset >= len) { + if (offset >= packet.size()) { return ENOENT; } @@ -397,10 +399,10 @@ int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t le size_t pos = sizeof(dnsheader) + offset; pos += DNS_TYPE_SIZE + DNS_CLASS_SIZE; - if (pos >= len) + if (pos >= packet.size()) return ENOENT; - if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= len) { + if ((pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE) >= packet.size()) { return ENOENT; } @@ -410,16 +412,17 @@ int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t le } pos += 1; - uint16_t qtype = (reinterpret_cast(packet)[pos])*256 + reinterpret_cast(packet)[pos+1]; + uint16_t qtype = packet.at(pos)*256 + packet.at(pos+1); pos += DNS_TYPE_SIZE; pos += DNS_CLASS_SIZE; - if(qtype != QType::OPT || (len - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) + if (qtype != QType::OPT || (packet.size() - pos) < (DNS_TTL_SIZE + DNS_RDLENGTH_SIZE)) { return ENOENT; + } pos += DNS_TTL_SIZE; *optRDPosition = pos; - *remaining = len - pos; + *remaining = packet.size() - pos; return 0; } @@ -453,39 +456,41 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload res.append(optRData.c_str(), optRData.length()); } -static bool replaceEDNSClientSubnetOption(char * const packet, const size_t packetSize, uint16_t * const len, char * const oldEcsOptionStart, size_t const oldEcsOptionSize, unsigned char * const optRDLen, const string& newECSOption) +static bool replaceEDNSClientSubnetOption(std::vector& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption) { - assert(packet != NULL); - assert(len != NULL); - assert(oldEcsOptionStart != NULL); - assert(optRDLen != NULL); + assert(oldEcsOptionStartPosition < packet.size()); + assert(optRDLenPosition < packet.size()); if (newECSOption.size() == oldEcsOptionSize) { /* same size as the existing option */ - memcpy(oldEcsOptionStart, newECSOption.c_str(), oldEcsOptionSize); + memcpy(&packet.at(oldEcsOptionStartPosition), newECSOption.c_str(), oldEcsOptionSize); } else { /* different size than the existing option */ - const unsigned int newPacketLen = *len + (newECSOption.length() - oldEcsOptionSize); - const size_t beforeOptionLen = oldEcsOptionStart - packet; - const size_t dataBehindSize = *len - beforeOptionLen - oldEcsOptionSize; + const unsigned int newPacketLen = packet.size() + (newECSOption.length() - oldEcsOptionSize); + const size_t beforeOptionLen = oldEcsOptionStartPosition; + const size_t dataBehindSize = packet.size() - beforeOptionLen - oldEcsOptionSize; /* check that it fits in the existing buffer */ - if (newPacketLen > packetSize) { - return false; + if (newPacketLen > packet.size()) { + if (newPacketLen > maximumSize) { + return false; + } + + packet.resize(newPacketLen); } /* fix the size of ECS Option RDLen */ - uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1]; + uint16_t newRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1); newRDLen += (newECSOption.size() - oldEcsOptionSize); - optRDLen[0] = newRDLen / 256; - optRDLen[1] = newRDLen % 256; + packet.at(optRDLenPosition) = newRDLen / 256; + packet.at(optRDLenPosition + 1) = newRDLen % 256; if (dataBehindSize > 0) { - memmove(oldEcsOptionStart, oldEcsOptionStart + oldEcsOptionSize, dataBehindSize); + memmove(&packet.at(oldEcsOptionStartPosition), &packet.at(oldEcsOptionStartPosition + oldEcsOptionSize), dataBehindSize); } - memcpy(oldEcsOptionStart + dataBehindSize, newECSOption.c_str(), newECSOption.size()); - *len = newPacketLen; + memcpy(&packet.at(oldEcsOptionStartPosition + dataBehindSize), newECSOption.c_str(), newECSOption.size()); + packet.resize(newPacketLen); } return true; @@ -495,139 +500,135 @@ static bool replaceEDNSClientSubnetOption(char * const packet, const size_t pack and false otherwise. */ bool parseEDNSOptions(const DNSQuestion& dq) { - assert(dq.dh != nullptr); - assert(dq.consumed <= dq.len); - assert(dq.len <= dq.size); - + const auto dh = dq.getHeader(); if (dq.ednsOptions != nullptr) { return true; } dq.ednsOptions = std::make_shared >(); - if (ntohs(dq.dh->arcount) == 0) { + if (ntohs(dh->arcount) == 0) { /* nothing in additional so no EDNS */ return false; } - if (ntohs(dq.dh->ancount) != 0 || ntohs(dq.dh->nscount) != 0 || ntohs(dq.dh->arcount) > 1) { - return slowParseEDNSOptions(reinterpret_cast(dq.dh), dq.len, dq.ednsOptions); + if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || ntohs(dh->arcount) > 1) { + return slowParseEDNSOptions(dq.getData(), dq.ednsOptions); } - const char* packet = reinterpret_cast(dq.dh); - size_t remaining = 0; uint16_t optRDPosition; - int res = getEDNSOptionsStart(packet, dq.consumed, dq.len, &optRDPosition, &remaining); + int res = getEDNSOptionsStart(dq.getData(), dq.qname->wirelength(), &optRDPosition, &remaining); if (res == 0) { - res = getEDNSOptions(packet + optRDPosition, remaining, *dq.ednsOptions); + res = getEDNSOptions(reinterpret_cast(&dq.getData().at(optRDPosition)), remaining, *dq.ednsOptions); return (res == 0); } return false; } -static bool addECSToExistingOPT(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, unsigned char* optRDLen, bool& ecsAdded) +static bool addECSToExistingOPT(std::vector& packet, size_t maximumSize, const string& newECSOption, size_t optRDLenPosition, bool& ecsAdded) { /* we need to add one EDNS0 ECS option, fixing the size of EDNS0 RDLENGTH */ /* getEDNSOptionsStart has already checked that there is exactly one AR, no NS and no AN */ + uint16_t oldRDLen = (packet.at(optRDLenPosition) * 256) + packet.at(optRDLenPosition + 1); + if (packet.size() != (optRDLenPosition + sizeof(uint16_t) + oldRDLen)) { + /* we are supposed to be the last record, do we have some trailing data to remove? */ + uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast(packet.data()), packet.size()); + packet.resize(realPacketLen); + } - /* check if the existing buffer is large enough */ - const size_t newECSOptionSize = newECSOption.size(); - if (packetSize - *len <= newECSOptionSize) { + if ((maximumSize - packet.size()) < newECSOption.size()) { return false; } - uint16_t newRDLen = (optRDLen[0] * 256) + optRDLen[1]; - newRDLen += newECSOptionSize; - optRDLen[0] = newRDLen / 256; - optRDLen[1] = newRDLen % 256; + uint16_t newRDLen = oldRDLen + newECSOption.size(); + packet.at(optRDLenPosition) = newRDLen / 256; + packet.at(optRDLenPosition + 1) = newRDLen % 256; - memcpy(packet + *len, newECSOption.c_str(), newECSOptionSize); - *len += newECSOptionSize; + packet.insert(packet.end(), newECSOption.begin(), newECSOption.end()); ecsAdded = true; return true; } -static bool addEDNSWithECS(char* const packet, size_t const packetSize, uint16_t* const len, const string& newECSOption, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData) +static bool addEDNSWithECS(std::vector& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded) { /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */ string EDNSRR; - struct dnsheader* dh = reinterpret_cast(packet); generateOptRR(newECSOption, EDNSRR, g_EdnsUDPPayloadSize, 0, false); - /* does it fit in the existing buffer? */ - if (packetSize - *len <= EDNSRR.size()) { + if ((maximumSize - packet.size()) < EDNSRR.size()) { return false; } - uint32_t realPacketLen = getDNSPacketLength(packet, *len); - if (realPacketLen < *len && preserveTrailingData) { - size_t toMove = *len - realPacketLen; - memmove(packet + realPacketLen + EDNSRR.size(), packet + realPacketLen, toMove); - *len += EDNSRR.size(); - } - else { - *len = realPacketLen + EDNSRR.size(); - } +#warning FIXME: we can avoid a copy here by generating in place + packet.insert(packet.end(), EDNSRR.begin(), EDNSRR.end()); + struct dnsheader* dh = reinterpret_cast(packet.data()); uint16_t arcount = ntohs(dh->arcount); arcount++; dh->arcount = htons(arcount); ednsAdded = true; ecsAdded = true; - memcpy(packet + realPacketLen, EDNSRR.c_str(), EDNSRR.size()); - return true; } -bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const unsigned int consumed, uint16_t* const len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData) +bool handleEDNSClientSubnet(std::vector& packet, const size_t maximumSize, const size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption) { - assert(packet != nullptr); - assert(len != nullptr); - assert(consumed <= (size_t) *len); + assert(qnameWireLength <= packet.size()); - const struct dnsheader* dh = reinterpret_cast(packet); + const struct dnsheader* dh = reinterpret_cast(packet.data()); if (ntohs(dh->ancount) != 0 || ntohs(dh->nscount) != 0 || (ntohs(dh->arcount) != 0 && ntohs(dh->arcount) != 1)) { vector newContent; - newContent.reserve(packetSize); + newContent.reserve(packet.size()); - if (!slowRewriteQueryWithExistingEDNS(std::string(packet, *len), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) { + if (!slowRewriteQueryWithExistingEDNS(std::string(reinterpret_cast(packet.data()), packet.size()), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) { ednsAdded = false; ecsAdded = false; return false; } - if (newContent.size() > packetSize) { + if (newContent.size() > maximumSize) { ednsAdded = false; ecsAdded = false; return false; } - memcpy(packet, &newContent.at(0), newContent.size()); - *len = newContent.size(); + packet = std::move(newContent); return true; } uint16_t optRDPosition = 0; size_t remaining = 0; - int res = getEDNSOptionsStart(packet, consumed, *len, &optRDPosition, &remaining); + int res = getEDNSOptionsStart(packet, qnameWireLength, &optRDPosition, &remaining); if (res != 0) { - return addEDNSWithECS(packet, packetSize, len, newECSOption, ednsAdded, ecsAdded, preserveTrailingData); + /* no EDNS but there might be another record in additional (TSIG?) */ + size_t minimumPacketSize = sizeof(dnsheader) + qnameWireLength + sizeof(uint16_t) + sizeof(uint16_t); + if (packet.size() > minimumPacketSize) { + if (ntohs(dh->arcount) == 0) { + /* well now.. */ + packet.resize(minimumPacketSize); + } + else { + uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast(packet.data()), packet.size()); + packet.resize(realPacketLen); + } + } + + return addEDNSWithECS(packet, maximumSize, newECSOption, ednsAdded, ecsAdded); } - unsigned char* optRDLen = reinterpret_cast(packet) + optRDPosition; - char * ecsOptionStart = nullptr; + size_t ecsOptionStartPosition = 0; size_t ecsOptionSize = 0; - res = getEDNSOption(reinterpret_cast(optRDLen), remaining, EDNSOptionCode::ECS, &ecsOptionStart, &ecsOptionSize); + res = getEDNSOption(reinterpret_cast(&packet.at(optRDPosition)), remaining, EDNSOptionCode::ECS, &ecsOptionStartPosition, &ecsOptionSize); if (res == 0) { /* there is already an ECS value */ @@ -635,23 +636,22 @@ bool handleEDNSClientSubnet(char* const packet, const size_t packetSize, const u return true; } - return replaceEDNSClientSubnetOption(packet, packetSize, len, ecsOptionStart, ecsOptionSize, optRDLen, newECSOption); + return replaceEDNSClientSubnetOption(packet, maximumSize, optRDPosition + ecsOptionStartPosition, ecsOptionSize, optRDPosition, newECSOption); } else { /* we have an EDNS OPT RR but no existing ECS option */ - return addECSToExistingOPT(packet, packetSize, len, newECSOption, optRDLen, ecsAdded); + return addECSToExistingOPT(packet, maximumSize, newECSOption, optRDPosition, ecsAdded); } return true; } -bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData) +bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded) { assert(dq.remote != nullptr); string newECSOption; generateECSOption(dq.ecsSet ? dq.ecs.getNetwork() : *dq.remote, newECSOption, dq.ecsSet ? dq.ecs.getBits() : dq.ecsPrefixLength); - char* packet = reinterpret_cast(dq.dh); - return handleEDNSClientSubnet(packet, dq.size, dq.consumed, &dq.len, ednsAdded, ecsAdded, dq.ecsOverride, newECSOption, preserveTrailingData); + return handleEDNSClientSubnet(dq.getMutableData(), dq.getMaximumSize(), dq.qname->wirelength(), ednsAdded, ecsAdded, dq.ecsOverride, newECSOption); } static int removeEDNSOptionFromOptions(unsigned char* optionsStart, const uint16_t optionsLen, const uint16_t optionCodeToRemove, uint16_t* newOptionsLen) @@ -844,18 +844,13 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin return 0; } -bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode) +bool addEDNS(std::vector& packet, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode) { std::string optRecord; generateOptRR(std::string(), optRecord, payloadSize, ednsrcode, dnssecOK); - if (optRecord.size() >= size || (size - optRecord.size()) < len) { - return false; - } - - char * optPtr = reinterpret_cast(dh) + len; - memcpy(optPtr, optRecord.data(), optRecord.size()); - len += optRecord.size(); + packet.insert(packet.end(), optRecord.begin(), optRecord.end()); + auto dh = reinterpret_cast(packet.data()); dh->arcount = htons(ntohs(dh->arcount) + 1); return true; @@ -867,52 +862,47 @@ bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uin */ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum) { - if (ntohs(dq.dh->qdcount) != 1) { + auto& packet = dq.getMutableData(); + auto dh = dq.getHeader(); + if (ntohs(dh->qdcount) != 1) { return false; } - assert(dq.consumed == dq.qname->wirelength()); - size_t queryPartSize = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE; - if (dq.len < queryPartSize) { + size_t queryPartSize = sizeof(dnsheader) + dq.qname->wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE; + if (packet.size() < queryPartSize) { /* something is already wrong, don't build on flawed foundations */ return false; } - size_t available = dq.size - queryPartSize; uint16_t qtype = htons(QType::SOA); uint16_t qclass = htons(QClass::IN); uint16_t rdLength = mname.wirelength() + rname.wirelength() + sizeof(serial) + sizeof(refresh) + sizeof(retry) + sizeof(expire) + sizeof(minimum); size_t soaSize = zone.wirelength() + sizeof(qtype) + sizeof(qclass) + sizeof(ttl) + sizeof(rdLength) + rdLength; - - if (soaSize > available) { - /* not enough space left to add the SOA, sorry! */ - return false; - } - bool hadEDNS = false; bool dnssecOK = false; if (g_addEDNSToSelfGeneratedResponses) { uint16_t payloadSize = 0; uint16_t z = 0; - hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &payloadSize, &z); + hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast(packet.data()), packet.size(), &payloadSize, &z); if (hadEDNS) { dnssecOK = z & EDNS_HEADER_FLAG_DO; } } /* chop off everything after the question */ - dq.len = queryPartSize; + packet.resize(queryPartSize); + dh = dq.getHeader(); if (nxd) { - dq.dh->rcode = RCode::NXDomain; + dh->rcode = RCode::NXDomain; } else { - dq.dh->rcode = RCode::NoError; + dh->rcode = RCode::NoError; } - dq.dh->qr = true; - dq.dh->ancount = 0; - dq.dh->nscount = 0; - dq.dh->arcount = 0; + dh->qr = true; + dh->ancount = 0; + dh->nscount = 0; + dh->arcount = 0; rdLength = htons(rdLength); ttl = htonl(ttl); @@ -941,15 +931,13 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, throw std::runtime_error("Unexpected SOA response size: " + std::to_string(soa.size()) + " vs " + std::to_string(soaSize)); } - memcpy(reinterpret_cast(dq.dh) + queryPartSize, soa.c_str(), soa.size()); - - dq.len += soa.size(); - - dq.dh->arcount = htons(1); + packet.insert(packet.end(), soa.begin(), soa.end()); + dh = dq.getHeader(); + dh->arcount = htons(1); if (hadEDNS) { /* now we need to add a new OPT record */ - return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode); + return addEDNS(packet, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode); } return true; @@ -961,7 +949,8 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq) /* remaining is at least the size of the rdlen + the options if any + the following records if any */ size_t remaining = 0; - int res = getEDNSOptionsStart(reinterpret_cast(dq.dh), dq.consumed, dq.len, &optRDPosition, &remaining); + auto& packet = dq.getMutableData(); + int res = getEDNSOptionsStart(packet, dq.qname->wirelength(), &optRDPosition, &remaining); if (res != 0) { /* if the initial query did not have EDNS0, we are done */ @@ -969,25 +958,25 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq) } const size_t existingOptLen = /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2 + remaining; - if (existingOptLen >= dq.len) { + if (existingOptLen >= packet.size()) { /* something is wrong, bail out */ return false; } - char* optRDLen = reinterpret_cast(dq.dh) + optRDPosition; - char * optPtr = (optRDLen - (/* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2)); + uint8_t* optRDLen = &packet.at(optRDPosition); + uint8_t* optPtr = (optRDLen - (/* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + /* Z */ 2)); - const uint8_t* zPtr = reinterpret_cast(optPtr) + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE; + const uint8_t* zPtr = optPtr + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE; uint16_t z = 0x100 * (*zPtr) + *(zPtr + 1); bool dnssecOK = z & EDNS_HEADER_FLAG_DO; /* remove the existing OPT record, and everything else that follows (any SIG or TSIG would be useless anyway) */ - dq.len -= existingOptLen; - dq.dh->arcount = 0; + packet.resize(packet.size() - existingOptLen); + dq.getHeader()->arcount = 0; if (g_addEDNSToSelfGeneratedResponses) { /* now we need to add a new OPT record */ - return addEDNS(dq.dh, dq.len, dq.size, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode); + return addEDNS(packet, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode); } /* otherwise we are just fine */ @@ -996,45 +985,48 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq) // goal in life - if you send us a reasonably normal packet, we'll get Z for you, otherwise 0 int getEDNSZ(const DNSQuestion& dq) -try { - if (ntohs(dq.dh->qdcount) != 1 || dq.dh->ancount != 0 || ntohs(dq.dh->arcount) != 1 || dq.dh->nscount != 0) { - return 0; - } + try + { + const auto& dh = dq.getHeader(); + if (ntohs(dh->qdcount) != 1 || dh->ancount != 0 || ntohs(dh->arcount) != 1 || dh->nscount != 0) { + return 0; + } - if (dq.len <= sizeof(dnsheader)) { - return 0; - } + if (dq.getData().size() <= sizeof(dnsheader)) { + return 0; + } - size_t pos = sizeof(dnsheader) + dq.consumed + DNS_TYPE_SIZE + DNS_CLASS_SIZE; + size_t pos = sizeof(dnsheader) + dq.qname->wirelength() + DNS_TYPE_SIZE + DNS_CLASS_SIZE; - if (dq.len <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) { - return 0; - } + if (dq.getData().size() <= (pos + /* root */ 1 + DNS_TYPE_SIZE + DNS_CLASS_SIZE)) { + return 0; + } - const char* packet = reinterpret_cast(dq.dh); + auto& packet = dq.getData(); - if (packet[pos] != 0) { - /* not root, so not a valid OPT record */ - return 0; - } + if (packet.at(pos) != 0) { + /* not root, so not a valid OPT record */ + return 0; + } - pos++; + pos++; - uint16_t qtype = (reinterpret_cast(packet)[pos])*256 + reinterpret_cast(packet)[pos+1]; - pos += DNS_TYPE_SIZE; - pos += DNS_CLASS_SIZE; + uint16_t qtype = packet.at(pos)*256 + packet.at(pos+1); + pos += DNS_TYPE_SIZE; + pos += DNS_CLASS_SIZE; + + if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= packet.size()) { + return 0; + } - if (qtype != QType::OPT || (pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE + 1) >= dq.len) { + const uint8_t* z = &packet.at(pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE); + return 0x100 * (*z) + *(z+1); + } + catch(...) + { return 0; } - - const uint8_t* z = reinterpret_cast(packet) + pos + EDNS_EXTENDED_RCODE_SIZE + EDNS_VERSION_SIZE; - return 0x100 * (*z) + *(z+1); -} -catch(...) -{ - return 0; } bool queryHasEDNS(const DNSQuestion& dq) @@ -1042,7 +1034,7 @@ bool queryHasEDNS(const DNSQuestion& dq) uint16_t optRDPosition; size_t ecsRemaining = 0; - int res = getEDNSOptionsStart(reinterpret_cast(dq.dh), dq.consumed, dq.len, &optRDPosition, &ecsRemaining); + int res = getEDNSOptionsStart(dq.getData(), dq.qname->wirelength(), &optRDPosition, &ecsRemaining); if (res == 0) { return true; } @@ -1055,8 +1047,9 @@ bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0) uint16_t optStart; size_t optLen = 0; bool last = false; - const char * packet = reinterpret_cast(dq.dh); - std::string packetStr(packet, dq.len); + const auto& packet = dq.getData(); +#warning FIXME: save an alloc+copy + std::string packetStr(reinterpret_cast(packet.data()), packet.size()); int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last); if (res != 0) { // no EDNS OPT RR @@ -1067,13 +1060,13 @@ bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0) return false; } - if (optStart < dq.len && packetStr.at(optStart) != 0) { + if (optStart < packet.size() && packetStr.at(optStart) != 0) { // OPT RR Name != '.' return false; } static_assert(sizeof(EDNS0Record) == sizeof(uint32_t), "sizeof(EDNS0Record) must match sizeof(uint32_t) AKA RR TTL size"); // copy out 4-byte "ttl" (really the EDNS0 record), after root label (1) + type (2) + class (2). - memcpy(&edns0, packet + optStart + 5, sizeof edns0); + memcpy(&edns0, &packet.at(optStart + 5), sizeof edns0); return true; } diff --git a/pdns/dnsdist-ecs.hh b/pdns/dnsdist-ecs.hh index 7a34084b41..385454a320 100644 --- a/pdns/dnsdist-ecs.hh +++ b/pdns/dnsdist-ecs.hh @@ -33,14 +33,14 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength); int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove); int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector& newContent); -int getEDNSOptionsStart(const char* packet, const size_t offset, const size_t len, uint16_t* optRDPosition, size_t * remaining); +int getEDNSOptionsStart(const std::vector& packet, const size_t offset, uint16_t* optRDPosition, size_t * remaining); bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart = nullptr, uint16_t* optContentLen = nullptr); -bool addEDNS(dnsheader* dh, uint16_t& len, const size_t size, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode); +bool addEDNS(std::vector& packet, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode); bool addEDNSToQueryTurnedResponse(DNSQuestion& dq); bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum); -bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded, bool preserveTrailingData); -bool handleEDNSClientSubnet(char* packet, size_t packetSize, unsigned int consumed, uint16_t* len, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption, bool preserveTrailingData); +bool handleEDNSClientSubnet(DNSQuestion& dq, bool& ednsAdded, bool& ecsAdded); +bool handleEDNSClientSubnet(std::vector& packet, size_t maximumSize, size_t qnameWireLength, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption); bool parseEDNSOptions(const DNSQuestion& dq); diff --git a/pdns/dnsdist-lua-actions.cc b/pdns/dnsdist-lua-actions.cc index 49bc53f17d..3b962af864 100644 --- a/pdns/dnsdist-lua-actions.cc +++ b/pdns/dnsdist-lua-actions.cc @@ -174,28 +174,26 @@ DNSAction::Action TeeAction::operator()(DNSQuestion* dq, std::string* ruleresult d_queries++; if(d_addECS) { - std::string query; - uint16_t len = dq->len; + std::vector query(dq->getData()); bool ednsAdded = false; bool ecsAdded = false; - query.reserve(dq->size); - query.assign((char*) dq->dh, len); std::string newECSOption; generateECSOption(dq->ecsSet ? dq->ecs.getNetwork() : *dq->remote, newECSOption, dq->ecsSet ? dq->ecs.getBits() : dq->ecsPrefixLength); - if (!handleEDNSClientSubnet(const_cast(query.c_str()), query.capacity(), dq->qname->wirelength(), &len, ednsAdded, ecsAdded, dq->ecsOverride, newECSOption, g_preserveTrailingData)) { + if (!handleEDNSClientSubnet(query, dq->getMaximumSize(), dq->qname->wirelength(), ednsAdded, ecsAdded, dq->ecsOverride, newECSOption)) { return DNSAction::Action::None; } - res = send(d_fd, query.c_str(), len, 0); + res = send(d_fd, query.data(), query.size(), 0); } else { - res = send(d_fd, (char*)dq->dh, dq->len, 0); + res = send(d_fd, dq->getData().data(), dq->getData().size(), 0); } - if (res <= 0) + if (res <= 0) { d_senderrors++; + } } return DNSAction::Action::None; } @@ -305,9 +303,9 @@ public: RCodeAction(uint8_t rcode) : d_rcode(rcode) {} DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->dh->rcode = d_rcode; - dq->dh->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->dh, d_responseConfig); + dq->getHeader()->rcode = d_rcode; + dq->getHeader()->qr = true; // for good measure + setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); return Action::HeaderModify; } std::string toString() const override @@ -326,10 +324,10 @@ public: ERCodeAction(uint8_t rcode) : d_rcode(rcode) {} DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->dh->rcode = (d_rcode & 0xF); + dq->getHeader()->rcode = (d_rcode & 0xF); dq->ednsRCode = ((d_rcode & 0xFFF0) >> 4); - dq->dh->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->dh, d_responseConfig); + dq->getHeader()->qr = true; // for good measure + setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); return Action::HeaderModify; } std::string toString() const override @@ -559,10 +557,10 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu shuffle(addrs.begin(), addrs.end(), t_randomEngine); } - unsigned int consumed=0; - DNSName ignore((char*)dq->dh, dq->len, sizeof(dnsheader), false, 0, 0, &consumed); + unsigned int qnameWireLength=0; + DNSName ignore((char*)dq->getData().data(), dq->getData().size(), sizeof(dnsheader), false, 0, 0, &qnameWireLength); - if (dq->size < (sizeof(dnsheader) + consumed + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen)) { + if (dq->getMaximumSize() < (sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen)) { return Action::None; } @@ -573,13 +571,14 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu dnssecOK = getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO; } - dq->len = sizeof(dnsheader) + consumed + 4; // there goes your EDNS - char* dest = ((char*)dq->dh) + dq->len; + auto& data = dq->getMutableData(); + data.resize(sizeof(dnsheader) + qnameWireLength + 4 + numberOfRecords*12 /* recordstart */ + totrdatalen); // there goes your EDNS + uint8_t* dest = &(data.at(sizeof(dnsheader) + qnameWireLength + 4)); - dq->dh->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->dh, d_responseConfig); - dq->dh->ancount = 0; - dq->dh->arcount = 0; // for now, forget about your EDNS, we're marching over it + dq->getHeader()->qr = true; // for good measure + setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); + dq->getHeader()->ancount = 0; + dq->getHeader()->arcount = 0; // for now, forget about your EDNS, we're marching over it uint32_t ttl = htonl(d_responseConfig.ttl); unsigned char recordstart[] = {0xc0, 0x0c, // compressed name @@ -601,8 +600,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu memcpy(dest, recordstart, sizeof(recordstart)); dest += sizeof(recordstart); memcpy(dest, wireData.c_str(), wireData.length()); - dq->len += wireData.length() + sizeof(recordstart); - dq->dh->ancount++; + dq->getHeader()->ancount++; } else if (!d_rawResponse.empty()) { uint16_t rdataLen = htons(d_rawResponse.size()); @@ -613,8 +611,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu memcpy(dest, recordstart, sizeof(recordstart)); dest += sizeof(recordstart); memcpy(dest, d_rawResponse.c_str(), d_rawResponse.size()); - dq->len += d_rawResponse.size() + sizeof(recordstart); - dq->dh->ancount++; + dq->getHeader()->ancount++; raw = true; } else { @@ -631,15 +628,14 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu addr.sin4.sin_family == AF_INET ? (void*)&addr.sin4.sin_addr.s_addr : (void*)&addr.sin6.sin6_addr.s6_addr, addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); dest += (addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)); - dq->len += (addr.sin4.sin_family == AF_INET ? sizeof(addr.sin4.sin_addr.s_addr) : sizeof(addr.sin6.sin6_addr.s6_addr)) + sizeof(recordstart); - dq->dh->ancount++; + dq->getHeader()->ancount++; } } - dq->dh->ancount = htons(dq->dh->ancount); + dq->getHeader()->ancount = htons(dq->getHeader()->ancount); if (hadEDNS && raw == false) { - addEDNS(dq->dh, dq->len, dq->size, dnssecOK, g_PayloadSizeSelfGenAnswers, 0); + addEDNS(dq->getMutableData(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0); } return Action::HeaderModify; @@ -652,7 +648,7 @@ public: {} DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - if(dq->dh->arcount) + if(dq->getHeader()->arcount) return Action::None; std::string mac = getMACAddress(*dq->remote); @@ -665,13 +661,13 @@ public: std::string res; generateOptRR(optRData, res, g_EdnsUDPPayloadSize, 0, false); - if ((dq->size - dq->len) < res.length()) + if (!dq->hasRoomFor(res.length())) { return Action::None; + } - dq->dh->arcount = htons(1); - char* dest = ((char*)dq->dh) + dq->len; - memcpy(dest, res.c_str(), res.length()); - dq->len += res.length(); + dq->getHeader()->arcount = htons(1); + auto& data = dq->getMutableData(); + data.insert(data.end(), res.begin(), res.end()); return Action::None; } @@ -688,7 +684,7 @@ class NoRecurseAction : public DNSAction public: DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->dh->rd = false; + dq->getHeader()->rd = false; return Action::None; } std::string toString() const override @@ -723,10 +719,10 @@ public: if (!d_fp) { if (!d_verboseOnly || g_verbose) { if (d_includeTimestamp) { - infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast(dq->queryTime->tv_sec), static_cast(dq->queryTime->tv_nsec), dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->dh->id); + infolog("[%u.%u] Packet from %s for %s %s with id %d", static_cast(dq->queryTime->tv_sec), static_cast(dq->queryTime->tv_nsec), dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->getHeader()->id); } else { - infolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->dh->id); + infolog("Packet from %s for %s %s with id %d", dq->remote->toStringWithPort(), dq->qname->toString(), QType(dq->qtype).getName(), dq->getHeader()->id); } } } @@ -739,7 +735,7 @@ public: fwrite(&tv_sec, sizeof(tv_sec), 1, d_fp.get()); fwrite(&tv_nsec, sizeof(tv_nsec), 1, d_fp.get()); } - uint16_t id = dq->dh->id; + uint16_t id = dq->getHeader()->id; fwrite(&id, sizeof(id), 1, d_fp.get()); fwrite(out.c_str(), 1, out.size(), d_fp.get()); fwrite(&dq->qtype, sizeof(dq->qtype), 1, d_fp.get()); @@ -754,10 +750,10 @@ public: } else { if (d_includeTimestamp) { - fprintf(d_fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %d\n", static_cast(dq->queryTime->tv_sec), static_cast(dq->queryTime->tv_nsec), dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).getName().c_str(), dq->dh->id); + fprintf(d_fp.get(), "[%llu.%lu] Packet from %s for %s %s with id %d\n", static_cast(dq->queryTime->tv_sec), static_cast(dq->queryTime->tv_nsec), dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).getName().c_str(), dq->getHeader()->id); } else { - fprintf(d_fp.get(), "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).getName().c_str(), dq->dh->id); + fprintf(d_fp.get(), "Packet from %s for %s %s with id %d\n", dq->remote->toStringWithPort().c_str(), dq->qname->toString().c_str(), QType(dq->qtype).getName().c_str(), dq->getHeader()->id); } } } @@ -805,19 +801,19 @@ public: if (!d_fp) { if (!d_verboseOnly || g_verbose) { if (d_includeTimestamp) { - infolog("[%u.%u] Answer to %s for %s %s (%s) with id %d", static_cast(dr->queryTime->tv_sec), static_cast(dr->queryTime->tv_nsec), dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).getName(), RCode::to_s(dr->dh->rcode), dr->dh->id); + infolog("[%u.%u] Answer to %s for %s %s (%s) with id %d", static_cast(dr->queryTime->tv_sec), static_cast(dr->queryTime->tv_nsec), dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).getName(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); } else { - infolog("Answer to %s for %s %s (%s) with id %d", dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).getName(), RCode::to_s(dr->dh->rcode), dr->dh->id); + infolog("Answer to %s for %s %s (%s) with id %d", dr->remote->toStringWithPort(), dr->qname->toString(), QType(dr->qtype).getName(), RCode::to_s(dr->getHeader()->rcode), dr->getHeader()->id); } } } else { if (d_includeTimestamp) { - fprintf(d_fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %d\n", static_cast(dr->queryTime->tv_sec), static_cast(dr->queryTime->tv_nsec), dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).getName().c_str(), RCode::to_s(dr->dh->rcode).c_str(), dr->dh->id); + fprintf(d_fp.get(), "[%llu.%lu] Answer to %s for %s %s (%s) with id %d\n", static_cast(dr->queryTime->tv_sec), static_cast(dr->queryTime->tv_nsec), dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).getName().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); } else { - fprintf(d_fp.get(), "Answer to %s for %s %s (%s) with id %d\n", dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).getName().c_str(), RCode::to_s(dr->dh->rcode).c_str(), dr->dh->id); + fprintf(d_fp.get(), "Answer to %s for %s %s (%s) with id %d\n", dr->remote->toStringWithPort().c_str(), dr->qname->toString().c_str(), QType(dr->qtype).getName().c_str(), RCode::to_s(dr->getHeader()->rcode).c_str(), dr->getHeader()->id); } } return Action::None; @@ -843,7 +839,7 @@ class DisableValidationAction : public DNSAction public: DNSAction::Action operator()(DNSQuestion* dq, std::string* ruleresult) const override { - dq->dh->cd = true; + dq->getHeader()->cd = true; return Action::None; } std::string toString() const override @@ -990,8 +986,7 @@ public: static thread_local std::string data; data.clear(); - const struct dnsheader* dh = reinterpret_cast(dq->dh); - DnstapMessage message(data, !dh->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, dq->remote, dq->local, dq->tcp, reinterpret_cast(dq->dh), dq->len, dq->queryTime, nullptr); + DnstapMessage message(data, !dq->getHeader()->qr ? DnstapMessage::MessageType::client_query : DnstapMessage::MessageType::client_response, d_identity, dq->remote, dq->local, dq->tcp, reinterpret_cast(dq->getData().data()), dq->getData().size(), dq->queryTime, nullptr); { if (d_alterFunc) { std::lock_guard lock(g_luamutex); @@ -1120,7 +1115,7 @@ public: gettime(&now, true); data.clear(); - DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, dr->remote, dr->local, dr->tcp, reinterpret_cast(dr->dh), dr->len, dr->queryTime, &now); + DnstapMessage message(data, DnstapMessage::MessageType::client_response, d_identity, dr->remote, dr->local, dr->tcp, reinterpret_cast(dr->getData().data()), dr->getData().size(), dr->queryTime, &now); { if (d_alterFunc) { std::lock_guard lock(g_luamutex); @@ -1320,7 +1315,7 @@ private: class HTTPStatusAction: public DNSAction { public: - HTTPStatusAction(int code, const std::string& body, const std::string& contentType): d_body(body), d_contentType(contentType), d_code(code) + HTTPStatusAction(int code, const std::vector& body, const std::string& contentType): d_body(body), d_contentType(contentType), d_code(code) { } @@ -1330,9 +1325,9 @@ public: return Action::None; } - dq->du->setHTTPResponse(d_code, d_body, d_contentType); - dq->dh->qr = true; // for good measure - setResponseHeadersFromConfig(*dq->dh, d_responseConfig); + dq->du->setHTTPResponse(d_code, std::vector(d_body), d_contentType); + dq->getHeader()->qr = true; // for good measure + setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); return Action::HeaderModify; } @@ -1343,7 +1338,7 @@ public: ResponseConfig d_responseConfig; private: - std::string d_body; + std::vector d_body; std::string d_contentType; int d_code; }; @@ -1399,7 +1394,7 @@ public: return Action::None; } - setResponseHeadersFromConfig(*dq->dh, d_responseConfig); + setResponseHeadersFromConfig(*dq->getHeader(), d_responseConfig); return Action::Allow; } @@ -1822,7 +1817,7 @@ void setupLuaActions(LuaContext& luaCtx) #ifdef HAVE_DNS_OVER_HTTPS luaCtx.writeFunction("HTTPStatusAction", [](uint16_t status, std::string body, boost::optional contentType, boost::optional vars) { - auto ret = std::shared_ptr(new HTTPStatusAction(status, body, contentType ? *contentType : "")); + auto ret = std::shared_ptr(new HTTPStatusAction(status, std::vector(body.begin(), body.end()), contentType ? *contentType : "")); auto hsa = std::dynamic_pointer_cast(ret); parseResponseConfig(vars, hsa->d_responseConfig); return ret; diff --git a/pdns/dnsdist-lua-bindings-dnsquestion.cc b/pdns/dnsdist-lua-bindings-dnsquestion.cc index b0b4a2f39c..01988b7a23 100644 --- a/pdns/dnsdist-lua-bindings-dnsquestion.cc +++ b/pdns/dnsdist-lua-bindings-dnsquestion.cc @@ -32,13 +32,14 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerMember("qname", [](const DNSQuestion& dq) -> const DNSName { return *dq.qname; }, [](DNSQuestion& dq, const DNSName newName) { (void) newName; }); luaCtx.registerMember("qtype", [](const DNSQuestion& dq) -> uint16_t { return dq.qtype; }, [](DNSQuestion& dq, uint16_t newType) { (void) newType; }); luaCtx.registerMember("qclass", [](const DNSQuestion& dq) -> uint16_t { return dq.qclass; }, [](DNSQuestion& dq, uint16_t newClass) { (void) newClass; }); - luaCtx.registerMember("rcode", [](const DNSQuestion& dq) -> int { return dq.dh->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.dh->rcode = newRCode; }); + luaCtx.registerMember("rcode", [](const DNSQuestion& dq) -> int { return dq.getHeader()->rcode; }, [](DNSQuestion& dq, int newRCode) { dq.getHeader()->rcode = newRCode; }); luaCtx.registerMember("remoteaddr", [](const DNSQuestion& dq) -> const ComboAddress { return *dq.remote; }, [](DNSQuestion& dq, const ComboAddress newRemote) { (void) newRemote; }); /* DNSDist DNSQuestion */ - luaCtx.registerMember("dh", &DNSQuestion::dh); - luaCtx.registerMember("len", [](const DNSQuestion& dq) -> uint16_t { return dq.len; }, [](DNSQuestion& dq, uint16_t newlen) { dq.len = newlen; }); - luaCtx.registerMember("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.dh->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; }); - luaCtx.registerMember("size", [](const DNSQuestion& dq) -> size_t { return dq.size; }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; }); + luaCtx.registerMember("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; }); + luaCtx.registerMember("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); }); + luaCtx.registerMember("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; }); + #warning FIXME we need to provide Lua with a way to update the size + //luaCtx.registerMember("size", [](const DNSQuestion& dq) -> size_t { return dq.getData().size(); }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; }); luaCtx.registerMember("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; }); luaCtx.registerMember("skipCache", [](const DNSQuestion& dq) -> bool { return dq.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.skipCache = newSkipCache; }); luaCtx.registerMember("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; }); @@ -135,16 +136,17 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) luaCtx.registerMember("qname", [](const DNSResponse& dq) -> const DNSName { return *dq.qname; }, [](DNSResponse& dq, const DNSName newName) { (void) newName; }); luaCtx.registerMember("qtype", [](const DNSResponse& dq) -> uint16_t { return dq.qtype; }, [](DNSResponse& dq, uint16_t newType) { (void) newType; }); luaCtx.registerMember("qclass", [](const DNSResponse& dq) -> uint16_t { return dq.qclass; }, [](DNSResponse& dq, uint16_t newClass) { (void) newClass; }); - luaCtx.registerMember("rcode", [](const DNSResponse& dq) -> int { return dq.dh->rcode; }, [](DNSResponse& dq, int newRCode) { dq.dh->rcode = newRCode; }); + luaCtx.registerMember("rcode", [](const DNSResponse& dq) -> int { return dq.getHeader()->rcode; }, [](DNSResponse& dq, int newRCode) { dq.getHeader()->rcode = newRCode; }); luaCtx.registerMember("remoteaddr", [](const DNSResponse& dq) -> const ComboAddress { return *dq.remote; }, [](DNSResponse& dq, const ComboAddress newRemote) { (void) newRemote; }); - luaCtx.registerMember("dh", [](const DNSResponse& dr) -> dnsheader* { return dr.dh; }, [](DNSResponse& dr, dnsheader * newdh) { dr.dh = newdh; }); - luaCtx.registerMember("len", [](const DNSResponse& dq) -> uint16_t { return dq.len; }, [](DNSResponse& dq, uint16_t newlen) { dq.len = newlen; }); - luaCtx.registerMember("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.dh->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; }); - luaCtx.registerMember("size", [](const DNSResponse& dq) -> size_t { return dq.size; }, [](DNSResponse& dq, size_t newSize) { (void) newSize; }); + luaCtx.registerMember("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; }); + luaCtx.registerMember("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); }); + luaCtx.registerMember("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; }); + #warning FIXME we need to provide Lua with a way to update the size + //luaCtx.registerMember("size", [](const DNSResponse& dq) -> size_t { return dq.size; }, [](DNSResponse& dq, size_t newSize) { (void) newSize; }); luaCtx.registerMember("tcp", [](const DNSResponse& dq) -> bool { return dq.tcp; }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; }); luaCtx.registerMember("skipCache", [](const DNSResponse& dq) -> bool { return dq.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.skipCache = newSkipCache; }); - luaCtx.registerFunction editFunc)>("editTTLs", [](const DNSResponse& dr, std::function editFunc) { - editDNSPacketTTL((char*) dr.dh, dr.len, editFunc); + luaCtx.registerFunction editFunc)>("editTTLs", [](DNSResponse& dr, std::function editFunc) { + editDNSPacketTTL(reinterpret_cast(dr.getMutableData().data()), dr.getData().size(), editFunc); }); luaCtx.registerFunction("getTrailingData", [](const DNSResponse& dq) { return dq.getTrailingData(); @@ -238,7 +240,8 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx) if (dq.du == nullptr) { return; } - dq.du->setHTTPResponse(statusCode, body, contentType ? *contentType : ""); + std::vector vect(body.begin(), body.end()); + dq.du->setHTTPResponse(statusCode, std::move(vect), contentType ? *contentType : ""); }); #endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/dnsdist-lua-bindings.cc b/pdns/dnsdist-lua-bindings.cc index 2cd34f9215..d7a8b6abfe 100644 --- a/pdns/dnsdist-lua-bindings.cc +++ b/pdns/dnsdist-lua-bindings.cc @@ -490,6 +490,6 @@ void setupLuaBindings(LuaContext& luaCtx, bool client) headers->push_back({ boost::to_lower_copy(header.first), header.second }); } } - return std::make_shared(regex, status, content, headers); + return std::make_shared(regex, status, std::vector(content.begin(), content.end()), headers); }); } diff --git a/pdns/dnsdist-lua-rules.cc b/pdns/dnsdist-lua-rules.cc index e5af721c6d..0ad71bd8e5 100644 --- a/pdns/dnsdist-lua-rules.cc +++ b/pdns/dnsdist-lua-rules.cc @@ -443,10 +443,11 @@ void setupLuaRules(LuaContext& luaCtx) StopWatch sw; sw.start(); for(int n=0; n < times; ++n) { - const item& i = items[n % items.size()]; - DNSQuestion dq(&i.qname, i.qtype, i.qclass, 0, &i.rem, &i.rem, (struct dnsheader*)&i.packet[0], i.packet.size(), i.packet.size(), false, &sw.d_start); - if(rule->matches(&dq)) + item& i = items[n % items.size()]; + DNSQuestion dq(&i.qname, i.qtype, i.qclass, &i.rem, &i.rem, i.packet, false, &sw.d_start); + if (rule->matches(&dq)) { matches++; + } } double udiff=sw.udiff(); g_outputBuffer=(boost::format("Had %d matches out of %d, %.1f qps, in %.1f usec\n") % matches % times % (1000000*(1.0*times/udiff)) % udiff).str(); diff --git a/pdns/dnsdist-lua.cc b/pdns/dnsdist-lua.cc index 7eeac832f0..8a13a0f005 100644 --- a/pdns/dnsdist-lua.cc +++ b/pdns/dnsdist-lua.cc @@ -1159,8 +1159,6 @@ static void setupLuaConfig(LuaContext& luaCtx, bool client, bool configCheck) luaCtx.writeFunction("setECSOverride", [](bool override) { g_ECSOverride=override; }); - luaCtx.writeFunction("setPreserveTrailingData", [](bool preserve) { g_preserveTrailingData = preserve; }); - luaCtx.writeFunction("showDynBlocks", []() { setLuaNoSideEffect(); auto slow = g_dynblockNMG.getCopy(); diff --git a/pdns/dnsdist-protobuf.cc b/pdns/dnsdist-protobuf.cc index 02f77f3509..1ceb86af72 100644 --- a/pdns/dnsdist-protobuf.cc +++ b/pdns/dnsdist-protobuf.cc @@ -124,7 +124,7 @@ void DNSDistProtoBufMessage::serialize(std::string& data) const m.setTime(ts.tv_sec, ts.tv_nsec / 1000); } - m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.dh->id, d_dq.tcp, d_bytes ? *d_bytes : d_dq.len); + m.setRequest(d_dq.uniqueId ? *d_dq.uniqueId : getUniqueID(), d_requestor ? *d_requestor : *d_dq.remote, d_responder ? *d_responder : *d_dq.local, d_question ? d_question->d_name : *d_dq.qname, d_question ? d_question->d_type : d_dq.qtype, d_question ? d_question->d_class : d_dq.qclass, d_dq.getHeader()->id, d_dq.tcp, d_bytes ? *d_bytes : d_dq.getData().size()); if (d_serverIdentity) { m.setServerIdentity(*d_serverIdentity); @@ -146,8 +146,8 @@ void DNSDistProtoBufMessage::serialize(std::string& data) const } if (d_dr != nullptr) { - m.setResponseCode(d_rcode ? *d_rcode : d_dr->dh->rcode); - m.addRRsFromPacket(reinterpret_cast(d_dr->dh), d_dr->len, d_includeCNAME); + m.setResponseCode(d_rcode ? *d_rcode : d_dr->getHeader()->rcode); + m.addRRsFromPacket(reinterpret_cast(d_dr->getData().data()), d_dr->getData().size(), d_includeCNAME); } else { if (d_rcode) { diff --git a/pdns/dnsdist-snmp.cc b/pdns/dnsdist-snmp.cc index e62f9f0c1c..3d79859fe1 100644 --- a/pdns/dnsdist-snmp.cc +++ b/pdns/dnsdist-snmp.cc @@ -448,9 +448,9 @@ bool DNSDistSNMPAgent::sendDNSTrap(const DNSQuestion& dq, const std::string& rea std::string qname = dq.qname->toStringNoDot(); const uint32_t socketFamily = dq.remote->isIPv4() ? 1 : 2; const uint32_t socketProtocol = dq.tcp ? 2 : 1; - const uint32_t queryType = dq.dh->qr ? 2 : 1; - const uint32_t querySize = (uint32_t) dq.len; - const uint32_t queryID = (uint32_t) ntohs(dq.dh->id); + const uint32_t queryType = dq.getHeader()->qr ? 2 : 1; + const uint32_t querySize = (uint32_t) dq.getData().size(); + const uint32_t queryID = (uint32_t) ntohs(dq.getHeader()->id); const uint32_t qType = (uint32_t) dq.qtype; const uint32_t qClass = (uint32_t) dq.qclass; diff --git a/pdns/dnsdist-tcp.cc b/pdns/dnsdist-tcp.cc index a3f32f0292..93b8790cd4 100644 --- a/pdns/dnsdist-tcp.cc +++ b/pdns/dnsdist-tcp.cc @@ -446,40 +446,20 @@ void IncomingTCPConnectionState::handleResponse(std::shared_ptr(512)); - size_t responseCapacity = response.d_buffer.size(); - auto responseAsCharArray = reinterpret_cast(&response.d_buffer.at(0)); - auto& ids = response.d_idstate; - unsigned int consumed; - if (!responseContentMatches(responseAsCharArray, responseSize, ids.qname, ids.qtype, ids.qclass, response.d_connection->getRemote(), consumed)) { + unsigned int qnameWireLength; + if (!responseContentMatches(response.d_buffer, ids.qname, ids.qtype, ids.qclass, response.d_connection->getRemote(), qnameWireLength)) { return; } - auto dh = reinterpret_cast(responseAsCharArray); - uint16_t addRoom = 0; - DNSResponse dr = makeDNSResponseFromIDState(ids, dh, responseCapacity, responseSize, true); - if (dr.dnsCryptQuery) { - addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; - } + DNSResponse dr = makeDNSResponseFromIDState(ids, response.d_buffer, true); - memcpy(&response.d_cleartextDH, dr.dh, sizeof(response.d_cleartextDH)); + memcpy(&response.d_cleartextDH, dr.getHeader(), sizeof(response.d_cleartextDH)); - std::vector rewrittenResponse; - if (!processResponse(&responseAsCharArray, &responseSize, &responseCapacity, state->d_threadData.localRespRulactions, dr, addRoom, rewrittenResponse, false)) { + if (!processResponse(response.d_buffer, state->d_threadData.localRespRulactions, dr, false)) { return; } - if (!rewrittenResponse.empty()) { - /* responseSize has been updated as well but we don't really care since it will match - the capacity of rewrittenResponse anyway */ - response.d_buffer = std::move(rewrittenResponse); - } else { - /* the size might have been updated (shrinked) if we removed the whole OPT RR, for example) */ - response.d_buffer.resize(responseSize); - } - if (state->d_isXFR && !state->d_xfrStarted) { /* don't bother parsing the content of the response for now */ state->d_xfrStarted = true; @@ -539,26 +519,24 @@ static IOState handleQuery(std::shared_ptr& state, c struct timespec queryRealTime; gettime(&queryRealTime, true); - auto query = reinterpret_cast(&state->d_buffer.at(0)); std::shared_ptr dnsCryptQuery{nullptr}; - auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, query, state->d_querySize, dnsCryptQuery, queryRealTime.tv_sec, true); + auto dnsCryptResponse = checkDNSCryptQuery(*state->d_ci.cs, state->d_buffer, dnsCryptQuery, queryRealTime.tv_sec, true); if (dnsCryptResponse) { TCPResponse response; - response.d_buffer = std::move(*dnsCryptResponse); state->d_state = IncomingTCPConnectionState::State::idle; ++state->d_currentQueriesCount; return state->sendResponse(state, now, std::move(response)); } - const auto& dh = reinterpret_cast(query); + const auto* dh = reinterpret_cast(state->d_buffer.data()); if (!checkQueryHeaders(dh)) { return IOState::NeedRead; } uint16_t qtype, qclass; - unsigned int consumed = 0; - DNSName qname(query, state->d_querySize, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, &state->d_origDest, &state->d_ci.remote, reinterpret_cast(query), state->d_buffer.size(), state->d_querySize, true, &queryRealTime); + unsigned int qnameWireLength = 0; + DNSName qname(reinterpret_cast(state->d_buffer.data()), state->d_buffer.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); + DNSQuestion dq(&qname, qtype, qclass, &state->d_origDest, &state->d_ci.remote, state->d_buffer, true, &queryRealTime); dq.dnsCryptQuery = std::move(dnsCryptQuery); dq.sni = state->d_handler.getServerNameIndication(); @@ -574,8 +552,9 @@ static IOState handleQuery(std::shared_ptr& state, c return IOState::Done; } + // the buffer might have been invalidated by now + dh = dq.getHeader(); if (result == ProcessQueryResult::SendAnswer) { - state->d_buffer.resize(dq.len); TCPResponse response; response.d_selfGenerated = true; response.d_buffer = std::move(state->d_buffer); @@ -592,18 +571,12 @@ static IOState handleQuery(std::shared_ptr& state, c setIDStateFromDNSQuestion(ids, dq, std::move(qname)); ids.origID = ntohs(dh->id); - const uint8_t sizeBytes[] = { static_cast(dq.len / 256), static_cast(dq.len % 256) }; + uint16_t queryLen = state->d_buffer.size(); + const uint8_t sizeBytes[] = { static_cast(queryLen / 256), static_cast(queryLen % 256) }; /* prepend the size. Yes, this is not the most efficient way but it prevents mistakes that could occur if we had to deal with the size during the processing, especially alignment issues */ - /* first we need to resize to the size that is actually used, since we allocated more to be able to insert - EDNS or Proxy Protocol values */ - dq.size = state->d_buffer.size(); - state->d_buffer.resize(dq.len); state->d_buffer.insert(state->d_buffer.begin(), sizeBytes, sizeBytes + 2); - dq.len = dq.len + 2; - dq.dh = reinterpret_cast(&state->d_buffer.at(0)); - state->d_buffer.resize(dq.len); bool proxyProtocolPayloadAdded = false; std::string proxyProtocolPayload; @@ -729,6 +702,7 @@ void IncomingTCPConnectionState::handleIO(std::shared_ptrd_handler.tryRead(state->d_buffer, state->d_currentPos, state->d_querySize); if (iostate == IOState::Done) { DEBUGLOG("query received"); + state->d_buffer.resize(state->d_querySize); iostate = handleQuery(state, now); // if the query has been passed to a backend, or dropped, we can start diff --git a/pdns/dnsdist-xpf.cc b/pdns/dnsdist-xpf.cc index c828aadfb0..ebffa072c8 100644 --- a/pdns/dnsdist-xpf.cc +++ b/pdns/dnsdist-xpf.cc @@ -25,7 +25,7 @@ #include "dnsparser.hh" #include "xpf.hh" -bool addXPF(DNSQuestion& dq, uint16_t optionCode, bool preserveTrailingData) +bool addXPF(DNSQuestion& dq, uint16_t optionCode) { std::string payload = generateXPFPayload(dq.tcp, *dq.remote, *dq.local); uint8_t root = '\0'; @@ -36,32 +36,24 @@ bool addXPF(DNSQuestion& dq, uint16_t optionCode, bool preserveTrailingData) drh.d_clen = htons(payload.size()); size_t recordHeaderLen = sizeof(root) + sizeof(drh); - size_t available = dq.size - dq.len; - - if ((payload.size() + recordHeaderLen) > available) { + if (!dq.hasRoomFor(payload.size() + recordHeaderLen)) { return false; } size_t xpfSize = sizeof(root) + sizeof(drh) + payload.size(); - uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast(dq.dh), dq.len); - if (realPacketLen < dq.len && preserveTrailingData) { - size_t toMove = dq.len - realPacketLen; - memmove(reinterpret_cast(dq.dh) + realPacketLen + xpfSize, reinterpret_cast(dq.dh) + realPacketLen, toMove); - dq.len += xpfSize; - } - else { - dq.len = realPacketLen + xpfSize; - } + auto& data = dq.getMutableData(); + uint32_t realPacketLen = getDNSPacketLength(reinterpret_cast(data.data()), data.size()); + data.resize(realPacketLen + xpfSize); size_t pos = realPacketLen; - memcpy(reinterpret_cast(dq.dh) + pos, &root, sizeof(root)); + memcpy(reinterpret_cast(&data.at(pos)), &root, sizeof(root)); pos += sizeof(root); - memcpy(reinterpret_cast(dq.dh) + pos, &drh, sizeof(drh)); + memcpy(reinterpret_cast(&data.at(pos)), &drh, sizeof(drh)); pos += sizeof(drh); - memcpy(reinterpret_cast(dq.dh) + pos, payload.data(), payload.size()); + memcpy(reinterpret_cast(&data.at(pos)), payload.data(), payload.size()); pos += payload.size(); - dq.dh->arcount = htons(ntohs(dq.dh->arcount) + 1); + dq.getHeader()->arcount = htons(ntohs(dq.getHeader()->arcount) + 1); return true; } diff --git a/pdns/dnsdist-xpf.hh b/pdns/dnsdist-xpf.hh index 5a1b411146..2e66f65588 100644 --- a/pdns/dnsdist-xpf.hh +++ b/pdns/dnsdist-xpf.hh @@ -23,5 +23,5 @@ #include "dnsdist.hh" -bool addXPF(DNSQuestion& dq, uint16_t optionCode, bool preserveTrailingData); +bool addXPF(DNSQuestion& dq, uint16_t optionCode); diff --git a/pdns/dnsdist.cc b/pdns/dnsdist.cc index ac8b13b0e1..4f039355da 100644 --- a/pdns/dnsdist.cc +++ b/pdns/dnsdist.cc @@ -141,48 +141,52 @@ int g_udpTimeout{2}; bool g_servFailOnNoPolicy{false}; bool g_truncateTC{false}; bool g_fixupCase{false}; -bool g_preserveTrailingData{false}; std::set g_capabilitiesToRetain; -static void truncateTC(char* packet, uint16_t* len, size_t responseSize, unsigned int consumed) -try +static size_t const s_initialUDPPacketBufferSize = s_maxPacketCacheEntrySize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; +static_assert(s_initialUDPPacketBufferSize <= UINT16_MAX, "Packet size should fit in a uint16_t"); + +static void truncateTC(std::vector& packet, unsigned int qnameWireLength) { - bool hadEDNS = false; - uint16_t payloadSize = 0; - uint16_t z = 0; + try + { + bool hadEDNS = false; + uint16_t payloadSize = 0; + uint16_t z = 0; - if (g_addEDNSToSelfGeneratedResponses) { - hadEDNS = getEDNSUDPPayloadSizeAndZ(packet, *len, &payloadSize, &z); - } + if (g_addEDNSToSelfGeneratedResponses) { + hadEDNS = getEDNSUDPPayloadSizeAndZ(reinterpret_cast(packet.data()), packet.size(), &payloadSize, &z); + } - *len=static_cast(sizeof(dnsheader)+consumed+DNS_TYPE_SIZE+DNS_CLASS_SIZE); - struct dnsheader* dh = reinterpret_cast(packet); - dh->ancount = dh->arcount = dh->nscount = 0; + packet.resize(static_cast(sizeof(dnsheader)+qnameWireLength+DNS_TYPE_SIZE+DNS_CLASS_SIZE)); + struct dnsheader* dh = reinterpret_cast(packet.data()); + dh->ancount = dh->arcount = dh->nscount = 0; - if (hadEDNS) { - addEDNS(dh, *len, responseSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0); + if (hadEDNS) { + addEDNS(packet, z & EDNS_HEADER_FLAG_DO, payloadSize, 0); + } + } + catch(...) + { + ++g_stats.truncFail; } -} -catch(...) -{ - g_stats.truncFail++; } struct DelayedPacket { int fd; - string packet; + std::vector packet; ComboAddress destination; ComboAddress origDest; void operator()() { ssize_t res; if(origDest.sin4.sin_family == 0) { - res = sendto(fd, packet.c_str(), packet.size(), 0, (struct sockaddr*)&destination, destination.getSocklen()); + res = sendto(fd, packet.data(), packet.size(), 0, (struct sockaddr*)&destination, destination.getSocklen()); } else { - res = sendfromto(fd, packet.c_str(), packet.size(), 0, origDest, destination); + res = sendfromto(fd, packet.data(), packet.size(), 0, origDest, destination); } if (res == -1) { int err = errno; @@ -195,25 +199,18 @@ DelayPipe* g_delay = nullptr; std::string DNSQuestion::getTrailingData() const { - const char* message = reinterpret_cast(this->dh); - const uint16_t messageLen = getDNSPacketLength(message, this->len); - return std::string(message + messageLen, this->len - messageLen); + const char* message = reinterpret_cast(this->getHeader()); + const uint16_t messageLen = getDNSPacketLength(message, this->data.size()); + return std::string(message + messageLen, this->getData().size() - messageLen); } bool DNSQuestion::setTrailingData(const std::string& tail) { - char* message = reinterpret_cast(this->dh); - const uint16_t messageLen = getDNSPacketLength(message, this->len); - const uint16_t tailLen = tail.size(); - if (tailLen > (this->size - messageLen)) { - vinfolog("Trailing data update failed, the new trailing data size was %d, the existing message length was %d, packet size was %d and buffer size %d", tail.size(), messageLen, this->len, this->size); - return false; - } - - /* Update length and copy data from the Lua string. */ - this->len = messageLen + tailLen; - if (tailLen > 0) { - tail.copy(message + messageLen, tailLen); + const char* message = reinterpret_cast(this->data.data()); + const uint16_t messageLen = getDNSPacketLength(message, this->data.size()); + this->data.resize(messageLen); + if (tail.size() > 0) { + this->data.insert(this->data.end(), tail.begin(), tail.end()); } return true; } @@ -238,13 +235,13 @@ void doLatencyStats(double udiff) doAvg(g_stats.latencyAvg1000000, udiff, 1000000); } -bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed) +bool responseContentMatches(const std::vector& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& qnameWireLength) { - if (responseLen < sizeof(dnsheader)) { + if (response.size() < sizeof(dnsheader)) { return false; } - const struct dnsheader* dh = reinterpret_cast(response); + const struct dnsheader* dh = reinterpret_cast(response.data()); if (dh->qr == 0) { ++g_stats.nonCompliantResponses; return false; @@ -263,10 +260,10 @@ bool responseContentMatches(const char* response, const uint16_t responseLen, co uint16_t rqtype, rqclass; DNSName rqname; try { - rqname=DNSName(response, responseLen, sizeof(dnsheader), false, &rqtype, &rqclass, &consumed); + rqname = DNSName(reinterpret_cast(response.data()), response.size(), sizeof(dnsheader), false, &rqtype, &rqclass, &qnameWireLength); } - catch(const std::exception& e) { - if(responseLen > 0 && static_cast(responseLen) > sizeof(dnsheader)) { + catch (const std::exception& e) { + if(response.size() > 0 && static_cast(response.size()) > sizeof(dnsheader)) { infolog("Backend %s sent us a response with id %d that did not parse: %s", remote.toStringWithPort(), ntohs(dh->id), e.what()); } ++g_stats.nonCompliantResponses; @@ -285,7 +282,7 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags) static const uint16_t rdMask = 1 << FLAGS_RD_OFFSET; static const uint16_t cdMask = 1 << FLAGS_CD_OFFSET; static const uint16_t restoreFlagsMask = UINT16_MAX & ~(rdMask | cdMask); - uint16_t * flags = getFlagsFromDNSHeader(dh); + uint16_t* flags = getFlagsFromDNSHeader(dh); /* clear the flags we are about to restore */ *flags &= restoreFlagsMask; /* only keep the flags we want to restore */ @@ -296,28 +293,28 @@ static void restoreFlags(struct dnsheader* dh, uint16_t origFlags) static bool fixUpQueryTurnedResponse(DNSQuestion& dq, const uint16_t origFlags) { - restoreFlags(dq.dh, origFlags); + restoreFlags(dq.getHeader(), origFlags); return addEDNSToQueryTurnedResponse(dq); } -static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* responseSize, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, std::vector& rewrittenResponse, uint16_t addRoom, bool* zeroScope) +static bool fixUpResponse(std::vector& response, const DNSName& qname, uint16_t origFlags, bool ednsAdded, bool ecsAdded, bool* zeroScope) { - if (*responseLen < sizeof(dnsheader)) { + if (response.size() < sizeof(dnsheader)) { return false; } - struct dnsheader* dh = reinterpret_cast(*response); + struct dnsheader* dh = reinterpret_cast(response.data()); restoreFlags(dh, origFlags); - if (*responseLen == sizeof(dnsheader)) { + if (response.size() == sizeof(dnsheader)) { return true; } if (g_fixupCase) { const auto& realname = qname.getStorage(); - if (*responseLen >= (sizeof(dnsheader) + realname.length())) { - memcpy(*response + sizeof(dnsheader), realname.c_str(), realname.length()); + if (response.size() >= (sizeof(dnsheader) + realname.length())) { + memcpy(&response.at(sizeof(dnsheader)), realname.c_str(), realname.length()); } } @@ -326,7 +323,8 @@ static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* respon size_t optLen = 0; bool last = false; - const std::string responseStr(*response, *responseLen); +#warning FIXME: save an alloc+copy + const std::string responseStr(reinterpret_cast(response.data()), response.size()); int res = locateEDNSOptRR(responseStr, &optStart, &optLen, &last); if (res == 0) { @@ -346,20 +344,17 @@ static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* respon therefore we need to remove it entirely */ if (last) { /* simply remove the last AR */ - *responseLen -= optLen; + response.resize(response.size() - optLen); + dh = reinterpret_cast(response.data()); uint16_t arcount = ntohs(dh->arcount); arcount--; dh->arcount = htons(arcount); } else { /* Removing an intermediary RR could lead to compression error */ + std::vector rewrittenResponse; if (rewriteResponseWithoutEDNS(responseStr, rewrittenResponse) == 0) { - *responseLen = rewrittenResponse.size(); - if (addRoom && (UINT16_MAX - *responseLen) > addRoom) { - rewrittenResponse.reserve(*responseLen + addRoom); - } - *responseSize = rewrittenResponse.capacity(); - *response = reinterpret_cast(rewrittenResponse.data()); + response = std::move(rewrittenResponse); } else { warnlog("Error rewriting content"); @@ -373,18 +368,14 @@ static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* respon /* nothing after the OPT RR, we can simply remove the ECS option */ size_t existingOptLen = optLen; - removeEDNSOptionFromOPT(*response + optStart, &optLen, EDNSOptionCode::ECS); - *responseLen -= (existingOptLen - optLen); + removeEDNSOptionFromOPT(reinterpret_cast(&response.at(optStart)), &optLen, EDNSOptionCode::ECS); + response.resize(response.size() - (existingOptLen - optLen)); } else { + std::vector rewrittenResponse; /* Removing an intermediary RR could lead to compression error */ if (rewriteResponseWithoutEDNSOption(responseStr, EDNSOptionCode::ECS, rewrittenResponse) == 0) { - *responseLen = rewrittenResponse.size(); - if (addRoom && (UINT16_MAX - *responseLen) > addRoom) { - rewrittenResponse.reserve(*responseLen + addRoom); - } - *responseSize = rewrittenResponse.capacity(); - *response = reinterpret_cast(rewrittenResponse.data()); + response = std::move(rewrittenResponse); } else { warnlog("Error rewriting content"); @@ -398,21 +389,12 @@ static bool fixUpResponse(char** response, uint16_t* responseLen, size_t* respon } #ifdef HAVE_DNSCRYPT -static bool encryptResponse(char* response, uint16_t* responseLen, size_t responseSize, bool tcp, std::shared_ptr dnsCryptQuery, dnsheader** dh, dnsheader* dhCopy) +static bool encryptResponse(std::vector& response, bool tcp, std::shared_ptr dnsCryptQuery) { if (dnsCryptQuery) { - uint16_t encryptedResponseLen = 0; - - /* save the original header before encrypting it in place */ - if (dh != nullptr && *dh != nullptr && dhCopy != nullptr) { - memcpy(dhCopy, *dh, sizeof(dnsheader)); - *dh = dhCopy; - } - - int res = dnsCryptQuery->encryptResponse(response, *responseLen, responseSize, tcp, &encryptedResponseLen); - if (res == 0) { - *responseLen = encryptedResponseLen; - } else { + #warning FIXME should not be harcoded + int res = dnsCryptQuery->encryptResponse(response, tcp ? std::numeric_limits::max() : 4096, tcp); + if (res != 0) { /* dropping response */ vinfolog("Error encrypting the response, dropping."); return false; @@ -441,7 +423,7 @@ static bool applyRulesToResponse(LocalStateHolderrcode = RCode::ServFail; + dr.getHeader()->rcode = RCode::ServFail; return true; break; /* non-terminal actions follow */ @@ -457,18 +439,18 @@ static bool applyRulesToResponse(LocalStateHolder >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector& rewrittenResponse, bool muted) +bool processResponse(std::vector& response, LocalStateHolder >& localRespRulactions, DNSResponse& dr, bool muted) { if (!applyRulesToResponse(localRespRulactions, dr)) { return false; } bool zeroScope = false; - if (!fixUpResponse(response, responseLen, responseSize, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, rewrittenResponse, addRoom, dr.useZeroScope ? &zeroScope : nullptr)) { + if (!fixUpResponse(response, *dr.qname, dr.origFlags, dr.ednsAdded, dr.ecsAdded, dr.useZeroScope ? &zeroScope : nullptr)) { return false; } - if (dr.packetCache && !dr.skipCache && *responseLen <= s_maxPacketCacheEntrySize) { + if (dr.packetCache && !dr.skipCache && response.size() <= s_maxPacketCacheEntrySize) { if (!dr.useZeroScope) { /* if the query was not suitable for zero-scope, for example because it had an existing ECS entry so the hash is @@ -481,12 +463,12 @@ bool processResponse(char** response, uint16_t* responseLen, size_t* responseSiz zeroScope = false; } // if zeroScope, pass the pre-ECS hash-key and do not pass the subnet to the cache - dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.origFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, *response, *responseLen, dr.tcp, dr.dh->rcode, dr.tempFailureTTL); + dr.packetCache->insert(zeroScope ? dr.cacheKeyNoECS : dr.cacheKey, zeroScope ? boost::none : dr.subnet, dr.origFlags, dr.dnssecOK, *dr.qname, dr.qtype, dr.qclass, response, dr.tcp, dr.getHeader()->rcode, dr.tempFailureTTL); } #ifdef HAVE_DNSCRYPT if (!muted) { - if (!encryptResponse(*response, responseLen, *responseSize, dr.tcp, dr.dnsCryptQuery, nullptr, nullptr)) { + if (!encryptResponse(response, dr.tcp, dr.dnsCryptQuery)) { return false; } } @@ -495,19 +477,19 @@ bool processResponse(char** response, uint16_t* responseLen, size_t* responseSiz return true; } -static bool sendUDPResponse(int origFD, const char* response, const uint16_t responseLen, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) +static bool sendUDPResponse(int origFD, const std::vector& response, const int delayMsec, const ComboAddress& origDest, const ComboAddress& origRemote) { if(delayMsec && g_delay) { - DelayedPacket dp{origFD, string(response,responseLen), origRemote, origDest}; + DelayedPacket dp{origFD, response, origRemote, origDest}; g_delay->submit(dp, delayMsec); } else { ssize_t res; if(origDest.sin4.sin_family == 0) { - res = sendto(origFD, response, responseLen, 0, reinterpret_cast(&origRemote), origRemote.getSocklen()); + res = sendto(origFD, response.data(), response.size(), 0, reinterpret_cast(&origRemote), origRemote.getSocklen()); } else { - res = sendfromto(origFD, response, responseLen, 0, origDest, origRemote); + res = sendfromto(origFD, response.data(), response.size(), 0, origDest, origRemote); } if (res == -1) { int err = errno; @@ -541,22 +523,20 @@ static void pickBackendSocketsReadyForReceiving(const std::shared_ptr dss) -try { +{ + try { setThreadName("dnsdist/respond"); auto localRespRulactions = g_resprulactions.getLocal(); - char packet[s_maxPacketCacheEntrySize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE]; - static_assert(sizeof(packet) <= UINT16_MAX, "Packet size should fit in a uint16_t"); + std::vector response(s_initialUDPPacketBufferSize); + /* when the answer is encrypted in place, we need to get a copy of the original header before encryption to fill the ring buffer */ dnsheader cleartextDH; - vector rewrittenResponse; - uint16_t queryId = 0; std::vector sockets; sockets.reserve(dss->sockets.size()); for(;;) { - dnsheader* dh = reinterpret_cast(packet); try { pickBackendSocketsReadyForReceiving(dss, sockets); if (dss->isStopped()) { @@ -564,9 +544,8 @@ try { } for (const auto& fd : sockets) { - ssize_t got = recv(fd, packet, sizeof(packet), 0); - char * response = packet; - size_t responseSize = sizeof(packet); + response.resize(s_initialUDPPacketBufferSize); + ssize_t got = recv(fd, response.data(), response.size(), 0); if (got == 0 && dss->isStopped()) { break; @@ -576,17 +555,18 @@ try { continue; } - uint16_t responseLen = static_cast(got); + response.resize(static_cast(got)); + dnsheader* dh = reinterpret_cast(response.data()); queryId = dh->id; - if(queryId >= dss->idStates.size()) { + if (queryId >= dss->idStates.size()) { continue; } IDState* ids = &dss->idStates[queryId]; int64_t usageIndicator = ids->usageIndicator; - if(!IDState::isInUse(usageIndicator)) { + if (!IDState::isInUse(usageIndicator)) { /* the corresponding state is marked as not in use, meaning that: - it was already cleaned up by another thread and the state is gone ; - we already got a response for this query and this one is a duplicate. @@ -604,8 +584,8 @@ try { ids->age = 0; int origFD = ids->origFD; - unsigned int consumed = 0; - if (!responseContentMatches(response, responseLen, ids->qname, ids->qtype, ids->qclass, dss->remote, consumed)) { + unsigned int qnameWireLength = 0; + if (!responseContentMatches(response, ids->qname, ids->qtype, ids->qclass, dss->remote, qnameWireLength)) { continue; } @@ -630,20 +610,16 @@ try { continue; } - if(dh->tc && g_truncateTC) { - truncateTC(response, &responseLen, responseSize, consumed); + if (dh->tc && g_truncateTC) { + truncateTC(response, qnameWireLength); } dh->id = ids->origID; - uint16_t addRoom = 0; - DNSResponse dr = makeDNSResponseFromIDState(*ids, dh, sizeof(packet), responseLen, false); - if (dr.dnsCryptQuery) { - addRoom = DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE; - } + DNSResponse dr = makeDNSResponseFromIDState(*ids, response, false); + memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH)); - memcpy(&cleartextDH, dr.dh, sizeof(cleartextDH)); - if (!processResponse(&response, &responseLen, &responseSize, localRespRulactions, dr, addRoom, rewrittenResponse, ids->cs && ids->cs->muted)) { + if (!processResponse(response, localRespRulactions, dr, ids->cs && ids->cs->muted)) { continue; } @@ -651,7 +627,7 @@ try { if (du) { #ifdef HAVE_DNS_OVER_HTTPS // DoH query - du->response = std::string(response, responseLen); + du->response = std::move(response); static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); ssize_t sent = write(du->rsock, &du, sizeof(du)); if (sent != sizeof(du)) { @@ -677,7 +653,7 @@ try { empty.sin4.sin_family = 0; /* if ids->destHarvested is false, origDest holds the listening address. We don't want to use that as a source since it could be 0.0.0.0 for example. */ - sendUDPResponse(origFD, response, responseLen, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote); + sendUDPResponse(origFD, response, dr.delayMsec, ids->destHarvested ? ids->origDest : empty, ids->origRemote); } } @@ -710,27 +686,26 @@ try { dss->latencyUsec = (127.0 * dss->latencyUsec / 128.0) + udiff/128.0; doLatencyStats(udiff); - - rewrittenResponse.clear(); } } - catch(const std::exception& e){ + catch (const std::exception& e){ vinfolog("Got an error in UDP responder thread while parsing a response from %s, id %d: %s", dss->remote.toStringWithPort(), queryId, e.what()); } } } -catch(const std::exception& e) +catch (const std::exception& e) { errlog("UDP responder thread died because of exception: %s", e.what()); } -catch(const PDNSException& e) +catch (const PDNSException& e) { errlog("UDP responder thread died because of PowerDNS exception: %s", e.reason); } -catch(...) +catch (...) { errlog("UDP responder thread died because of an exception: %s", "unknown"); } +} std::mutex g_luamutex; LuaContext g_lua; @@ -787,20 +762,20 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return true; break; case DNSAction::Action::Nxdomain: - dq.dh->rcode = RCode::NXDomain; - dq.dh->qr=true; + dq.getHeader()->rcode = RCode::NXDomain; + dq.getHeader()->qr=true; ++g_stats.ruleNXDomain; return true; break; case DNSAction::Action::Refused: - dq.dh->rcode = RCode::Refused; - dq.dh->qr=true; + dq.getHeader()->rcode = RCode::Refused; + dq.getHeader()->qr=true; ++g_stats.ruleRefused; return true; break; case DNSAction::Action::ServFail: - dq.dh->rcode = RCode::ServFail; - dq.dh->qr=true; + dq.getHeader()->rcode = RCode::ServFail; + dq.getHeader()->qr=true; ++g_stats.ruleServFail; return true; break; @@ -813,11 +788,11 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return true; break; case DNSAction::Action::Truncate: - dq.dh->tc = true; - dq.dh->qr = true; - dq.dh->ra = dq.dh->rd; - dq.dh->aa = false; - dq.dh->ad = false; + dq.getHeader()->tc = true; + dq.getHeader()->qr = true; + dq.getHeader()->ra = dq.getHeader()->rd; + dq.getHeader()->aa = false; + dq.getHeader()->ad = false; return true; break; case DNSAction::Action::HeaderModify: @@ -828,7 +803,7 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s return true; break; case DNSAction::Action::NoRecurse: - dq.dh->rd = false; + dq.getHeader()->rd = false; return true; break; /* non-terminal actions follow */ @@ -848,7 +823,7 @@ bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::s static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const struct timespec& now) { - g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.len, *dq.dh); + g_rings.insertQuery(now, *dq.remote, *dq.qname, dq.qtype, dq.getData().size(), *dq.getHeader()); if(g_qcount.enabled) { string qname = (*dq.qname).toLogString(); @@ -887,27 +862,27 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru vinfolog("Query from %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort()); updateBlockStats(); - dq.dh->rcode = RCode::NXDomain; - dq.dh->qr=true; + dq.getHeader()->rcode = RCode::NXDomain; + dq.getHeader()->qr=true; return true; case DNSAction::Action::Refused: vinfolog("Query from %s refused because of dynamic block", dq.remote->toStringWithPort()); updateBlockStats(); - dq.dh->rcode = RCode::Refused; - dq.dh->qr = true; + dq.getHeader()->rcode = RCode::Refused; + dq.getHeader()->qr = true; return true; case DNSAction::Action::Truncate: if(!dq.tcp) { updateBlockStats(); vinfolog("Query from %s truncated because of dynamic block", dq.remote->toStringWithPort()); - dq.dh->tc = true; - dq.dh->qr = true; - dq.dh->ra = dq.dh->rd; - dq.dh->aa = false; - dq.dh->ad = false; + dq.getHeader()->tc = true; + dq.getHeader()->qr = true; + dq.getHeader()->ra = dq.getHeader()->rd; + dq.getHeader()->aa = false; + dq.getHeader()->ad = false; return true; } else { @@ -917,7 +892,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::NoRecurse: updateBlockStats(); vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort()); - dq.dh->rd = false; + dq.getHeader()->rd = false; return true; default: updateBlockStats(); @@ -946,26 +921,26 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru vinfolog("Query from %s for %s turned into NXDomain because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); updateBlockStats(); - dq.dh->rcode = RCode::NXDomain; - dq.dh->qr=true; + dq.getHeader()->rcode = RCode::NXDomain; + dq.getHeader()->qr=true; return true; case DNSAction::Action::Refused: vinfolog("Query from %s for %s refused because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); updateBlockStats(); - dq.dh->rcode = RCode::Refused; - dq.dh->qr=true; + dq.getHeader()->rcode = RCode::Refused; + dq.getHeader()->qr=true; return true; case DNSAction::Action::Truncate: if(!dq.tcp) { updateBlockStats(); vinfolog("Query from %s for %s truncated because of dynamic block", dq.remote->toStringWithPort(), dq.qname->toLogString()); - dq.dh->tc = true; - dq.dh->qr = true; - dq.dh->ra = dq.dh->rd; - dq.dh->aa = false; - dq.dh->ad = false; + dq.getHeader()->tc = true; + dq.getHeader()->qr = true; + dq.getHeader()->ra = dq.getHeader()->rd; + dq.getHeader()->aa = false; + dq.getHeader()->ad = false; return true; } else { @@ -975,7 +950,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru case DNSAction::Action::NoRecurse: updateBlockStats(); vinfolog("Query from %s setting rd=0 because of dynamic block", dq.remote->toStringWithPort()); - dq.dh->rd = false; + dq.getHeader()->rd = false; return true; default: updateBlockStats(); @@ -1005,19 +980,19 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, const stru return true; } -ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck) +ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const std::vector& request, bool healthCheck) { ssize_t result; if (ss->sourceItf == 0) { - result = send(sd, request, requestLen, 0); + result = send(sd, request.data(), request.size(), 0); } else { struct msghdr msgh; struct iovec iov; cmsgbuf_aligned cbuf; ComboAddress remote(ss->remote); - fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), const_cast(request), requestLen, &remote); + fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), const_cast(reinterpret_cast(request.data())), request.size(), &remote); addCMsgSrcAddr(&msgh, &cbuf, &ss->sourceAddr, ss->sourceItf); result = sendmsg(sd, &msgh, 0); } @@ -1068,28 +1043,25 @@ static bool isUDPQueryAcceptable(ClientState& cs, LocalHolders& holders, const s return true; } -boost::optional> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr& dnsCryptQuery, time_t now, bool tcp) +bool checkDNSCryptQuery(const ClientState& cs, std::vector& query, std::shared_ptr& dnsCryptQuery, time_t now, bool tcp) { if (cs.dnscryptCtx) { #ifdef HAVE_DNSCRYPT vector response; - uint16_t decryptedQueryLen = 0; - dnsCryptQuery = std::make_shared(cs.dnscryptCtx); - bool decrypted = handleDNSCryptQuery(const_cast(query), len, dnsCryptQuery, &decryptedQueryLen, tcp, now, response); + bool decrypted = handleDNSCryptQuery(query, dnsCryptQuery, tcp, now, response); if (!decrypted) { if (response.size() > 0) { - return response; + query = std::move(response); + return true; } throw std::runtime_error("Unable to decrypt DNSCrypt query, dropping."); } - - len = decryptedQueryLen; #endif /* HAVE_DNSCRYPT */ } - return boost::none; + return false; } bool checkQueryHeaders(const struct dnsheader* dh) @@ -1112,10 +1084,10 @@ bool checkQueryHeaders(const struct dnsheader* dh) } #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) -static void queueResponse(const ClientState& cs, const char* response, uint16_t responseLen, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, cmsgbuf_aligned* cbuf) +static void queueResponse(const ClientState& cs, const std::vector& response, const ComboAddress& dest, const ComboAddress& remote, struct mmsghdr& outMsg, struct iovec* iov, cmsgbuf_aligned* cbuf) { outMsg.msg_len = 0; - fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, const_cast(response), responseLen, const_cast(&remote)); + fillMSGHdr(&outMsg.msg_hdr, iov, nullptr, 0, const_cast(reinterpret_cast(&response.at(0))), response.size(), const_cast(&remote)); if (dest.sin4.sin_family == 0) { outMsg.msg_hdr.msg_control = nullptr; @@ -1129,7 +1101,7 @@ static void queueResponse(const ClientState& cs, const char* response, uint16_t /* self-generated responses or cache hits */ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQuestion& dq, bool cacheHit) { - DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.consumed, dq.local, dq.remote, reinterpret_cast(dq.dh), dq.size, dq.len, dq.tcp, dq.queryTime); + DNSResponse dr(dq.qname, dq.qtype, dq.qclass, dq.local, dq.remote, dq.getMutableData(), dq.tcp, dq.queryTime); dr.uniqueId = dq.uniqueId; dr.qTag = dq.qTag; @@ -1144,7 +1116,7 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ #ifdef HAVE_DNSCRYPT if (!cs.muted) { - if (!encryptResponse(reinterpret_cast(dq.dh), &dq.len, dq.size, dq.tcp, dq.dnsCryptQuery, nullptr, nullptr)) { + if (!encryptResponse(dq.getMutableData(), dq.tcp, dq.dnsCryptQuery)) { return false; } } @@ -1154,7 +1126,7 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ ++g_stats.cacheHits; } - switch (dr.dh->rcode) { + switch (dr.getHeader()->rcode) { case RCode::NXDomain: ++g_stats.frontendNXDomain; break; @@ -1172,7 +1144,7 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend) { - const uint16_t queryId = ntohs(dq.dh->id); + const uint16_t queryId = ntohs(dq.getHeader()->id); try { /* we need an accurate ("real") value for the response and @@ -1185,7 +1157,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& return ProcessQueryResult::Drop; } - if(dq.dh->qr) { // something turned it into a response + if (dq.getHeader()->qr) { // something turned it into a response fixUpQueryTurnedResponse(dq, dq.origFlags); if (!prepareOutgoingResponse(holders, cs, dq, false)) { @@ -1204,7 +1176,6 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& const auto servers = serverPool->getServers(); selectedBackend = policy.getSelectedBackend(*servers, dq); - uint16_t cachedResponseSize = dq.size; uint32_t allowExpired = selectedBackend ? 0 : g_staleCacheEntriesTTL; if (dq.packetCache && !dq.skipCache) { @@ -1216,8 +1187,7 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& // we need ECS parsing (parseECS) to be true so we can be sure that the initial incoming query did not have an existing // ECS option, which would make it unsuitable for the zero-scope feature. if (dq.packetCache && !dq.skipCache && (!selectedBackend || !selectedBackend->disableZeroScope) && dq.packetCache->isECSParsingEnabled()) { - if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast(dq.dh), &cachedResponseSize, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) { - dq.len = cachedResponseSize; + if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKeyNoECS, dq.subnet, dq.dnssecOK, allowExpired)) { if (!prepareOutgoingResponse(holders, cs, dq, true)) { return ProcessQueryResult::Drop; @@ -1232,15 +1202,14 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } } - if (!handleEDNSClientSubnet(dq, dq.ednsAdded, dq.ecsAdded, g_preserveTrailingData)) { + if (!handleEDNSClientSubnet(dq, dq.ednsAdded, dq.ecsAdded)) { vinfolog("Dropping query from %s because we couldn't insert the ECS value", dq.remote->toStringWithPort()); return ProcessQueryResult::Drop; } } if (dq.packetCache && !dq.skipCache) { - if (dq.packetCache->get(dq, dq.consumed, dq.dh->id, reinterpret_cast(dq.dh), &cachedResponseSize, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) { - dq.len = cachedResponseSize; + if (dq.packetCache->get(dq, dq.getHeader()->id, &dq.cacheKey, dq.subnet, dq.dnssecOK, allowExpired)) { if (!prepareOutgoingResponse(holders, cs, dq, true)) { return ProcessQueryResult::Drop; @@ -1251,15 +1220,15 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& ++g_stats.cacheMisses; } - if(!selectedBackend) { + if (!selectedBackend) { ++g_stats.noPolicy; vinfolog("%s query for %s|%s from %s, no policy applied", g_servFailOnNoPolicy ? "ServFailed" : "Dropped", dq.qname->toLogString(), QType(dq.qtype).getName(), dq.remote->toStringWithPort()); if (g_servFailOnNoPolicy) { - restoreFlags(dq.dh, dq.origFlags); + restoreFlags(dq.getHeader(), dq.origFlags); - dq.dh->rcode = RCode::ServFail; - dq.dh->qr = true; + dq.getHeader()->rcode = RCode::ServFail; + dq.getHeader()->qr = true; if (!prepareOutgoingResponse(holders, cs, dq, false)) { return ProcessQueryResult::Drop; @@ -1272,19 +1241,19 @@ ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& } if (dq.addXPF && selectedBackend->xpfRRCode != 0) { - addXPF(dq, selectedBackend->xpfRRCode, g_preserveTrailingData); + addXPF(dq, selectedBackend->xpfRRCode); } selectedBackend->queries++; return ProcessQueryResult::PassToBackend; } - catch(const std::exception& e){ + catch (const std::exception& e){ vinfolog("Got an error while parsing a %s query from %s, id %d: %s", (dq.tcp ? "TCP" : "UDP"), dq.remote->toStringWithPort(), queryId, e.what()); } return ProcessQueryResult::Drop; } -static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, char* query, uint16_t len, size_t queryBufferSize, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, cmsgbuf_aligned* respCBuf) +static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct msghdr* msgh, const ComboAddress& remote, ComboAddress& dest, std::vector& query, struct mmsghdr* responsesVect, unsigned int* queuedResponses, struct iovec* respIOV, cmsgbuf_aligned* respCBuf) { assert(responsesVect == nullptr || (queuedResponses != nullptr && respIOV != nullptr && respCBuf != nullptr)); uint16_t queryId = 0; @@ -1301,13 +1270,13 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct gettime(&queryRealTime, true); std::shared_ptr dnsCryptQuery = nullptr; - auto dnsCryptResponse = checkDNSCryptQuery(cs, query, len, dnsCryptQuery, queryRealTime.tv_sec, false); + auto dnsCryptResponse = checkDNSCryptQuery(cs, query, dnsCryptQuery, queryRealTime.tv_sec, false); if (dnsCryptResponse) { - sendUDPResponse(cs.udpFD, reinterpret_cast(dnsCryptResponse->data()), static_cast(dnsCryptResponse->size()), 0, dest, remote); + sendUDPResponse(cs.udpFD, query, 0, dest, remote); return; } - struct dnsheader* dh = reinterpret_cast(query); + struct dnsheader* dh = reinterpret_cast(query.data()); queryId = ntohs(dh->id); if (!checkQueryHeaders(dh)) { @@ -1315,9 +1284,9 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } uint16_t qtype, qclass; - unsigned int consumed = 0; - DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, dh, queryBufferSize, len, false, &queryRealTime); + unsigned int qnameWireLength = 0; + DNSName qname(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); + DNSQuestion dq(&qname, qtype, qclass, dest.sin4.sin_family != 0 ? &dest : &cs.local, &remote, query, false, &queryRealTime); dq.dnsCryptQuery = std::move(dnsCryptQuery); std::shared_ptr ss{nullptr}; auto result = processQuery(dq, cs, holders, ss); @@ -1326,16 +1295,18 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct return; } + // the buffer might have been invalidated by now (resized) + dh = dq.getHeader(); if (result == ProcessQueryResult::SendAnswer) { #if defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) if (dq.delayMsec == 0 && responsesVect != nullptr) { - queueResponse(cs, reinterpret_cast(dq.dh), dq.len, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf); + queueResponse(cs, query, *dq.local, *dq.remote, responsesVect[*queuedResponses], respIOV, respCBuf); (*queuedResponses)++; return; } #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ /* we use dest, always, because we don't want to use the listening address to send a response since it could be 0.0.0.0 */ - sendUDPResponse(cs.udpFD, reinterpret_cast(dq.dh), dq.len, dq.delayMsec, dest, *dq.remote); + sendUDPResponse(cs.udpFD, query, dq.delayMsec, dest, *dq.remote); return; } @@ -1391,6 +1362,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct ids->destHarvested = false; } + dh = dq.getHeader(); dh->id = idOffset; if (ss->useProxyProtocol) { @@ -1398,7 +1370,7 @@ static void processUDPQuery(ClientState& cs, LocalHolders& holders, const struct } int fd = pickBackendSocketForSending(ss); - ssize_t ret = udpClientSendRequestToBackend(ss, fd, query, dq.len); + ssize_t ret = udpClientSendRequestToBackend(ss, fd, query); if(ret < 0) { ++ss->sendErrors; @@ -1417,7 +1389,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde { struct MMReceiver { - char packet[s_maxPacketCacheEntrySize]; + std::vector packet; ComboAddress remote; ComboAddress dest; struct iovec iov; @@ -1430,7 +1402,7 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde - we use it for self-generated responses (from rule or cache) but we only accept incoming payloads up to that size */ - static_assert(s_udpIncomingBufferSize <= sizeof(MMReceiver::packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)"); + static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "the incoming buffer size should not be larger than s_initialUDPPacketBufferSize"); auto recvData = std::unique_ptr(new MMReceiver[vectSize]); auto msgVec = std::unique_ptr(new struct mmsghdr[vectSize]); @@ -1439,7 +1411,8 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde /* initialize the structures needed to receive our messages */ for (size_t idx = 0; idx < vectSize; idx++) { recvData[idx].remote.sin4.sin_family = cs->local.sin4.sin_family; - fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), recvData[idx].packet, cs->dnscryptCtx ? sizeof(recvData[idx].packet) : s_udpIncomingBufferSize, &recvData[idx].remote); + recvData[idx].packet.resize(s_initialUDPPacketBufferSize); + fillMSGHdr(&msgVec[idx].msg_hdr, &recvData[idx].iov, &recvData[idx].cbuf, sizeof(recvData[idx].cbuf), reinterpret_cast(&recvData[idx].packet.at(0)), cs->dnscryptCtx ? recvData[idx].packet.size() : s_udpIncomingBufferSize, &recvData[idx].remote); } /* go now */ @@ -1448,8 +1421,9 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde /* reset the IO vector, since it's also used to send the vector of responses to avoid having to copy the data around */ for (size_t idx = 0; idx < vectSize; idx++) { - recvData[idx].iov.iov_base = recvData[idx].packet; - recvData[idx].iov.iov_len = sizeof(recvData[idx].packet); + recvData[idx].packet.resize(s_initialUDPPacketBufferSize); + recvData[idx].iov.iov_base = &recvData[idx].packet.at(0); + recvData[idx].iov.iov_len = recvData[idx].packet.size(); } /* block until we have at least one message ready, but return @@ -1474,8 +1448,8 @@ static void MultipleMessagesUDPClientThread(ClientState* cs, LocalHolders& holde continue; } - processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, static_cast(got), sizeof(recvData[msgIdx].packet), outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, &recvData[msgIdx].cbuf); - + recvData[msgIdx].packet.resize(got); + processUDPQuery(*cs, holders, msgh, remote, recvData[msgIdx].dest, recvData[msgIdx].packet, outMsgVec.get(), &msgsToSend, &recvData[msgIdx].iov, &recvData[msgIdx].cbuf); } /* immediate (not delayed or sent to a backend) responses (mostly from a rule, dynamic block @@ -1508,13 +1482,13 @@ try else #endif /* defined(HAVE_RECVMMSG) && defined(HAVE_SENDMMSG) && defined(MSG_WAITFORONE) */ { - char packet[s_maxPacketCacheEntrySize]; + std::vector packet(s_initialUDPPacketBufferSize); /* the actual buffer is larger because: - we may have to add EDNS and/or ECS - we use it for self-generated responses (from rule or cache) but we only accept incoming payloads up to that size */ - static_assert(s_udpIncomingBufferSize <= sizeof(packet), "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)"); + static_assert(s_udpIncomingBufferSize <= s_initialUDPPacketBufferSize, "the incoming buffer size should not be larger than sizeof(MMReceiver::packet)"); struct msghdr msgh; struct iovec iov; /* used by HarvestDestinationAddress */ @@ -1523,9 +1497,13 @@ try ComboAddress remote; ComboAddress dest; remote.sin4.sin_family = cs->local.sin4.sin_family; - fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), packet, cs->dnscryptCtx ? sizeof(packet) : s_udpIncomingBufferSize, &remote); + fillMSGHdr(&msgh, &iov, &cbuf, sizeof(cbuf), reinterpret_cast(&packet.at(0)), cs->dnscryptCtx ? packet.size() : s_udpIncomingBufferSize, &remote); for(;;) { + packet.resize(s_initialUDPPacketBufferSize); + iov.iov_base = &packet.at(0); + iov.iov_len = packet.size(); + ssize_t got = recvmsg(cs->udpFD, &msgh, 0); if (got < 0 || static_cast(got) < sizeof(struct dnsheader)) { @@ -1533,7 +1511,8 @@ try continue; } - processUDPQuery(*cs, holders, &msgh, remote, dest, packet, static_cast(got), sizeof(packet), nullptr, nullptr, nullptr, nullptr); + packet.resize(static_cast(got)); + processUDPQuery(*cs, holders, &msgh, remote, dest, packet, nullptr, nullptr, nullptr, nullptr); } } } diff --git a/pdns/dnsdist.hh b/pdns/dnsdist.hh index 83b2a4d225..7ab0be192b 100644 --- a/pdns/dnsdist.hh +++ b/pdns/dnsdist.hh @@ -62,9 +62,9 @@ typedef std::unordered_map QTag; struct DNSQuestion { - DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t queryLen, bool isTcp, const struct timespec* queryTime_): - qname(name), local(lc), remote(rem), dh(header), queryTime(queryTime_), size(bufferSize), consumed(consumed_), tempFailureTTL(boost::none), qtype(type), qclass(class_), len(queryLen), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) { - const uint16_t* flags = getFlagsFromDNSHeader(dh); + DNSQuestion(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, std::vector& data_, bool isTcp, const struct timespec* queryTime_): + data(data_), qname(name), local(lc), remote(rem), queryTime(queryTime_), tempFailureTTL(boost::none), qtype(type), qclass(class_), ecsPrefixLength(rem->sin4.sin_family == AF_INET ? g_ECSSourcePrefixV4 : g_ECSSourcePrefixV6), tcp(isTcp), ecsOverride(g_ECSOverride) { + const uint16_t* flags = getFlagsFromDNSHeader(getHeader()); origFlags = *flags; } DNSQuestion(const DNSQuestion&) = delete; @@ -73,7 +73,48 @@ struct DNSQuestion std::string getTrailingData() const; bool setTrailingData(const std::string&); + const std::vector& getData() const + { + return data; + } + std::vector& getMutableData() + { + return data; + } + + dnsheader* getHeader() + { + if (data.size() < sizeof(dnsheader)) { + throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer"); + } + return reinterpret_cast(&data.at(0)); + } + + const dnsheader* getHeader() const + { + if (data.size() < sizeof(dnsheader)) { + throw std::runtime_error("Trying to access the dnsheader of a too small (" + std::to_string(data.size()) + ") DNSQuestion buffer"); + } + return reinterpret_cast(&data.at(0)); + } + bool hasRoomFor(size_t more) const + { + return data.size() <= getMaximumSize() && (getMaximumSize() - data.size()) >= more; + } + + size_t getMaximumSize() const + { + if (tcp) { + return std::numeric_limits::max(); + } + return 4096; + } + +protected: + std::vector& data; + +public: boost::optional uniqueId; Netmask ecs; boost::optional subnet; @@ -87,18 +128,14 @@ struct DNSQuestion mutable std::shared_ptr > ednsOptions; std::shared_ptr dnsCryptQuery{nullptr}; std::shared_ptr packetCache{nullptr}; - struct dnsheader* dh{nullptr}; const struct timespec* queryTime{nullptr}; struct DOHUnit* du{nullptr}; - size_t size; - unsigned int consumed{0}; int delayMsec{0}; boost::optional tempFailureTTL; uint32_t cacheKeyNoECS; uint32_t cacheKey; const uint16_t qtype; const uint16_t qclass; - uint16_t len; uint16_t ecsPrefixLength; uint16_t origFlags; uint8_t ednsRCode{0}; @@ -116,8 +153,8 @@ struct DNSQuestion struct DNSResponse : DNSQuestion { - DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed_, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t responseLen, bool isTcp, const struct timespec* queryTime_): - DNSQuestion(name, type, class_, consumed_, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { } + DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, const ComboAddress* lc, const ComboAddress* rem, std::vector& data_, bool isTcp, const struct timespec* queryTime_): + DNSQuestion(name, type, class_, lc, rem, data_, isTcp, queryTime_) { } DNSResponse(const DNSResponse&) = delete; DNSResponse& operator=(const DNSResponse&) = delete; DNSResponse(DNSResponse&&) = default; @@ -1140,7 +1177,6 @@ extern bool g_servFailOnNoPolicy; extern bool g_useTCPSinglePipe; extern uint16_t g_downstreamTCPCleanupInterval; extern size_t g_udpVectorSize; -extern bool g_preserveTrailingData; extern bool g_allowEmptyResponse; extern shared_ptr g_defaultBPFFilter; @@ -1175,17 +1211,15 @@ void setLuaSideEffect(); // set to report a side effect, cancelling all _no_ s bool getLuaNoSideEffect(); // set if there were only explicit declarations of _no_ side effect void resetLuaSideEffect(); // reset to indeterminate state -bool responseContentMatches(const char* response, const uint16_t responseLen, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& consumed); -bool processResponse(char** response, uint16_t* responseLen, size_t* responseSize, LocalStateHolder >& localRespRulactions, DNSResponse& dr, size_t addRoom, std::vector& rewrittenResponse, bool muted); +bool responseContentMatches(const std::vector& response, const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& remote, unsigned int& qnameWireLength); +bool processResponse(std::vector& response, LocalStateHolder >& localRespRulactions, DNSResponse& dr, bool muted); bool processRulesResult(const DNSAction::Action& action, DNSQuestion& dq, std::string& ruleresult, bool& drop); bool checkQueryHeaders(const struct dnsheader* dh); extern std::vector> g_dnsCryptLocals; -int handleDNSCryptQuery(char* packet, uint16_t len, std::shared_ptr query, uint16_t* decryptedQueryLen, bool tcp, time_t now, std::vector& response); -boost::optional> checkDNSCryptQuery(const ClientState& cs, const char* query, uint16_t& len, std::shared_ptr& dnsCryptQuery, time_t now, bool tcp); - -bool addXPF(DNSQuestion& dq, uint16_t optionCode); +int handleDNSCryptQuery(std::vector& packet, std::shared_ptr& query, bool tcp, time_t now, std::vector& response); +bool checkDNSCryptQuery(const ClientState& cs, std::vector& query, std::shared_ptr& dnsCryptQuery, time_t now, bool tcp); uint16_t getRandomDNSID(); @@ -1203,8 +1237,8 @@ static const size_t s_maxPacketCacheEntrySize{4096}; // don't cache responses la enum class ProcessQueryResult { Drop, SendAnswer, PassToBackend }; ProcessQueryResult processQuery(DNSQuestion& dq, ClientState& cs, LocalHolders& holders, std::shared_ptr& selectedBackend); -DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP); +DNSResponse makeDNSResponseFromIDState(IDState& ids, std::vector& data, bool isTCP); void setIDStateFromDNSQuestion(IDState& ids, DNSQuestion& dq, DNSName&& qname); int pickBackendSocketForSending(std::shared_ptr& state); -ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const char* request, const size_t requestLen, bool healthCheck=false); +ssize_t udpClientSendRequestToBackend(const std::shared_ptr& ss, const int sd, const std::vector& request, bool healthCheck = false); diff --git a/pdns/dnsdistdist/dnsdist-healthchecks.cc b/pdns/dnsdistdist/dnsdist-healthchecks.cc index 6b78c3d3c4..e66a45ea26 100644 --- a/pdns/dnsdistdist/dnsdist-healthchecks.cc +++ b/pdns/dnsdistdist/dnsdist-healthchecks.cc @@ -227,7 +227,7 @@ bool queueHealthCheck(std::shared_ptr& mplexer, const std::shared sock.bind(ds->sourceAddr); } sock.connect(ds->remote); - ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), reinterpret_cast(&packet[0]), packet.size(), true); + ssize_t sent = udpClientSendRequestToBackend(ds, sock.getHandle(), packet, true); if (sent < 0) { int ret = errno; if (g_verboseHealthChecks) diff --git a/pdns/dnsdistdist/dnsdist-idstate.cc b/pdns/dnsdistdist/dnsdist-idstate.cc index db22f3de0a..9db149ada4 100644 --- a/pdns/dnsdistdist/dnsdist-idstate.cc +++ b/pdns/dnsdistdist/dnsdist-idstate.cc @@ -1,9 +1,9 @@ #include "dnsdist.hh" -DNSResponse makeDNSResponseFromIDState(IDState& ids, struct dnsheader* dh, size_t bufferSize, uint16_t responseLen, bool isTCP) +DNSResponse makeDNSResponseFromIDState(IDState& ids, std::vector& data, bool isTCP) { - DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, ids.qname.wirelength(), &ids.origDest, &ids.origRemote, dh, bufferSize, responseLen, isTCP, &ids.sentTime.d_start); + DNSResponse dr(&ids.qname, ids.qtype, ids.qclass, &ids.origDest, &ids.origRemote, data, isTCP, &ids.sentTime.d_start); dr.origFlags = ids.origFlags; dr.ecsAdded = ids.ecsAdded; dr.ednsAdded = ids.ednsAdded; diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h index efa1c40dd0..cdf0c77bdd 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h +++ b/pdns/dnsdistdist/dnsdist-lua-ffi-interface.h @@ -90,7 +90,7 @@ void dnsdist_ffi_dnsquestion_set_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq, void dnsdist_ffi_dnsquestion_unset_temp_failure_ttl(dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default"))); void dnsdist_ffi_dnsquestion_set_tag(dnsdist_ffi_dnsquestion_t* dq, const char* label, const char* value) __attribute__ ((visibility ("default"))); -void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, const char* contentType) __attribute__ ((visibility ("default"))); +void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType) __attribute__ ((visibility ("default"))); size_t dnsdist_ffi_dnsquestion_get_trailing_data(dnsdist_ffi_dnsquestion_t* dq, const char** out) __attribute__ ((visibility ("default"))); diff --git a/pdns/dnsdistdist/dnsdist-lua-ffi.cc b/pdns/dnsdistdist/dnsdist-lua-ffi.cc index 1cc6cf2cbc..1ed1866e37 100644 --- a/pdns/dnsdistdist/dnsdist-lua-ffi.cc +++ b/pdns/dnsdistdist/dnsdist-lua-ffi.cc @@ -86,27 +86,28 @@ size_t dnsdist_ffi_dnsquestion_get_qname_hash(const dnsdist_ffi_dnsquestion_t* d int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->dh->rcode; + return dq->dq->getHeader()->rcode; } void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->dh; + return dq->dq->getHeader(); } uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->len; + return dq->dq->getData().size(); } +#warning FIXME : we need to provide a way to resize size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->size; + return dq->dq->getData().size(); } uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq) { - return dq->dq->dh->opcode; + return dq->dq->getHeader()->opcode; } bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq) @@ -338,27 +339,28 @@ void dnsdist_ffi_dnsquestion_set_result(dnsdist_ffi_dnsquestion_t* dq, const cha dq->result = std::string(str, strSize); } -void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, const char* contentType) +void dnsdist_ffi_dnsquestion_set_http_response(dnsdist_ffi_dnsquestion_t* dq, uint16_t statusCode, const char* body, size_t bodyLen, const char* contentType) { if (dq->dq->du == nullptr) { return; } #ifdef HAVE_DNS_OVER_HTTPS - dq->dq->du->setHTTPResponse(statusCode, body, contentType); - dq->dq->dh->qr = true; + std::vector bodyVect(body, body + bodyLen); + dq->dq->du->setHTTPResponse(statusCode, std::move(bodyVect), contentType); + dq->dq->getHeader()->qr = true; #endif } void dnsdist_ffi_dnsquestion_set_rcode(dnsdist_ffi_dnsquestion_t* dq, int rcode) { - dq->dq->dh->rcode = rcode; - dq->dq->dh->qr = true; + dq->dq->getHeader()->rcode = rcode; + dq->dq->getHeader()->qr = true; } void dnsdist_ffi_dnsquestion_set_len(dnsdist_ffi_dnsquestion_t* dq, uint16_t len) { - dq->dq->len = len; + dq->dq->getMutableData().resize(len); } void dnsdist_ffi_dnsquestion_set_skip_cache(dnsdist_ffi_dnsquestion_t* dq, bool skipCache) diff --git a/pdns/dnsdistdist/dnsdist-proxy-protocol.cc b/pdns/dnsdistdist/dnsdist-proxy-protocol.cc index e7773a2003..eb10ca7373 100644 --- a/pdns/dnsdistdist/dnsdist-proxy-protocol.cc +++ b/pdns/dnsdistdist/dnsdist-proxy-protocol.cc @@ -29,15 +29,11 @@ std::string getProxyProtocolPayload(const DNSQuestion& dq) bool addProxyProtocol(DNSQuestion& dq, const std::string& payload) { - if ((dq.size - dq.len) < payload.size()) { + if (!dq.hasRoomFor(payload.size())) { return false; } - memmove(reinterpret_cast(dq.dh) + payload.size(), dq.dh, dq.len); - memcpy(dq.dh, payload.c_str(), payload.size()); - dq.len += payload.size(); - - return true; + return addProxyProtocol(dq.getMutableData(), payload); } bool addProxyProtocol(DNSQuestion& dq) @@ -53,9 +49,7 @@ bool addProxyProtocol(std::vector& buffer, const std::string& payload) return false; } - buffer.resize(previousSize + payload.size()); - std::copy_backward(buffer.begin(), buffer.begin() + previousSize, buffer.end()); - std::copy(payload.begin(), payload.end(), buffer.begin()); + buffer.insert(buffer.begin(), payload.begin(), payload.end()); return true; } diff --git a/pdns/dnsdistdist/dnsdist-rules.hh b/pdns/dnsdistdist/dnsdist-rules.hh index 202e0b0553..bbd7018251 100644 --- a/pdns/dnsdistdist/dnsdist-rules.hh +++ b/pdns/dnsdistdist/dnsdist-rules.hh @@ -387,7 +387,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return dq->dh->cd || (getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. + return dq->getHeader()->cd || (getEDNSZ(*dq) & EDNS_HEADER_FLAG_DO); // turns out dig sets ad by default.. } string toString() const override @@ -667,7 +667,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return d_opcode == dq->dh->opcode; + return d_opcode == dq->getHeader()->opcode; } string toString() const override { @@ -743,16 +743,16 @@ public: uint16_t count = 0; switch(d_section) { case 0: - count = ntohs(dq->dh->qdcount); + count = ntohs(dq->getHeader()->qdcount); break; case 1: - count = ntohs(dq->dh->ancount); + count = ntohs(dq->getHeader()->ancount); break; case 2: - count = ntohs(dq->dh->nscount); + count = ntohs(dq->getHeader()->nscount); break; case 3: - count = ntohs(dq->dh->arcount); + count = ntohs(dq->getHeader()->arcount); break; } return count >= d_minCount && count <= d_maxCount; @@ -793,22 +793,22 @@ public: uint16_t count = 0; switch(d_section) { case 0: - count = ntohs(dq->dh->qdcount); + count = ntohs(dq->getHeader()->qdcount); break; case 1: - count = ntohs(dq->dh->ancount); + count = ntohs(dq->getHeader()->ancount); break; case 2: - count = ntohs(dq->dh->nscount); + count = ntohs(dq->getHeader()->nscount); break; case 3: - count = ntohs(dq->dh->arcount); + count = ntohs(dq->getHeader()->arcount); break; } if (count < d_minCount) { return false; } - count = getRecordsOfTypeCount(reinterpret_cast(dq->dh), dq->len, d_section, d_type); + count = getRecordsOfTypeCount(reinterpret_cast(dq->getData().data()), dq->getData().size(), d_section, d_type); return count >= d_minCount && count <= d_maxCount; } string toString() const override @@ -845,8 +845,8 @@ public: } bool matches(const DNSQuestion* dq) const override { - uint16_t length = getDNSPacketLength(reinterpret_cast(dq->dh), dq->len); - return length < dq->len; + uint16_t length = getDNSPacketLength(reinterpret_cast(dq->getData().data()), dq->getData().size()); + return length < dq->getData().size(); } string toString() const override { @@ -902,7 +902,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return d_rcode == dq->dh->rcode; + return d_rcode == dq->getHeader()->rcode; } string toString() const override { @@ -921,7 +921,7 @@ public: bool matches(const DNSQuestion* dq) const override { // avoid parsing EDNS OPT RR when not needed. - if (d_rcode != dq->dh->rcode) { + if (d_rcode != dq->getHeader()->rcode) { return false; } @@ -975,8 +975,7 @@ public: uint16_t optStart; size_t optLen = 0; bool last = false; - const char * packet = reinterpret_cast(dq->dh); - std::string packetStr(packet, dq->len); + std::string packetStr(dq->getData().begin(), dq->getData().end()); int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last); if (res != 0) { // no EDNS OPT RR @@ -987,7 +986,7 @@ public: return false; } - if (optStart < dq->len && packetStr.at(optStart) != 0) { + if (optStart < dq->getData().size() && packetStr.at(optStart) != 0) { // OPT RR Name != '.' return false; } @@ -1010,7 +1009,7 @@ public: } bool matches(const DNSQuestion* dq) const override { - return dq->dh->rd == 1; + return dq->getHeader()->rd == 1; } string toString() const override { diff --git a/pdns/dnsdistdist/doh.cc b/pdns/dnsdistdist/doh.cc index ff9ce10a14..de54c81822 100644 --- a/pdns/dnsdistdist/doh.cc +++ b/pdns/dnsdistdist/doh.cc @@ -328,7 +328,7 @@ static const std::string& getReasonFromStatusCode(uint16_t statusCode) } /* Always called from the main DoH thread */ -static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCode, const std::string& response, const std::vector>& customResponseHeaders, const std::string& contentType, bool addContentType) +static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCode, const std::vector& response, const std::vector>& customResponseHeaders, const std::string& contentType, bool addContentType) { constexpr int overwrite_if_exists = 1; constexpr int maybe_token = 1; @@ -352,7 +352,7 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } if (df.d_sendCacheControlHeaders && !response.empty()) { - uint32_t minTTL = getDNSPacketMinTTL(response.data(), response.size()); + uint32_t minTTL = getDNSPacketMinTTL(reinterpret_cast(response.data()), response.size()); if (minTTL != std::numeric_limits::max()) { std::string cacheControlValue = "max-age=" + std::to_string(minTTL); /* we need to duplicate the header content because h2o keeps a pointer and we will be deleted before the response has been sent */ @@ -362,18 +362,19 @@ static void handleResponse(DOHFrontend& df, st_h2o_req_t* req, uint16_t statusCo } req->res.content_length = response.size(); - h2o_send_inline(req, response.c_str(), response.size()); + h2o_send_inline(req, reinterpret_cast(response.data()), response.size()); } else if (statusCode >= 300 && statusCode < 400) { /* in that case the response is actually a URL */ /* we need to duplicate the URL because h2o uses it for the location header, keeping a pointer, and we will be deleted before the response has been sent */ - h2o_iovec_t url = h2o_strdup(&req->pool, response.c_str(), response.size()); + h2o_iovec_t url = h2o_strdup(&req->pool, reinterpret_cast(response.data()), response.size()); h2o_send_redirect(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), url.base, url.len); ++df.d_redirectresponses; } else { - if (!response.empty()) { - h2o_send_error_generic(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), response.c_str(), H2O_SEND_ERROR_KEEP_HEADERS); + // we need to make sure it's null-terminated */ + if (!response.empty() && response.at(response.size() - 1) == 0) { + h2o_send_error_generic(req, statusCode, getReasonFromStatusCode(statusCode).c_str(), reinterpret_cast(response.data()), H2O_SEND_ERROR_KEEP_HEADERS); } else { switch(statusCode) { @@ -434,26 +435,25 @@ static int processDOHQuery(DOHUnit* du) rings for example */ struct timespec queryRealTime; gettime(&queryRealTime, true); - uint16_t len = du->query.length(); - /* We reserve at least 512 additional bytes to be able to add EDNS, but we also want - at least s_maxPacketCacheEntrySize bytes to be able to spoof the content or fill the answer from the packet cache */ - du->query.resize(std::max(du->query.size() + 512, s_maxPacketCacheEntrySize)); - size_t bufferSize = du->query.size(); - auto query = const_cast(du->query.c_str()); - struct dnsheader* dh = reinterpret_cast(query); - if (!checkQueryHeaders(dh)) { - du->status_code = 400; - return -1; // drop + { + /* don't keep that pointer around, it will be invalidated if the buffer is ever resized */ + struct dnsheader* dh = reinterpret_cast(du->query.data()); + + if (!checkQueryHeaders(dh)) { + du->status_code = 400; + return -1; // drop + } + + queryId = ntohs(dh->id); } uint16_t qtype, qclass; - unsigned int consumed = 0; - DNSName qname(query, len, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, &du->dest, &du->remote, dh, bufferSize, len, false, &queryRealTime); + unsigned int qnameWireLength = 0; + DNSName qname(reinterpret_cast(du->query.data()), du->query.size(), sizeof(dnsheader), false, &qtype, &qclass, &qnameWireLength); + DNSQuestion dq(&qname, qtype, qclass, &du->dest, &du->remote, du->query, false, &queryRealTime); dq.ednsAdded = du->ednsAdded; dq.du = du; - queryId = ntohs(dh->id); dq.sni = std::move(du->sni); std::shared_ptr ss{nullptr}; @@ -466,12 +466,13 @@ static int processDOHQuery(DOHUnit* du) if (result == ProcessQueryResult::SendAnswer) { if (du->response.empty()) { - du->response = std::string(reinterpret_cast(dq.dh), dq.len); + du->response = std::move(du->query); } /* increase the ref counter before sending the pointer */ du->get(); static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + ssize_t sent = write(du->rsock, &du, sizeof(du)); if (sent != sizeof(du)) { if (errno == EAGAIN || errno == EWOULDBLOCK) { @@ -533,9 +534,11 @@ static int processDOHQuery(DOHUnit* du) ids->du = du; ids->cs = &cs; - ids->origID = dh->id; + ids->origID = queryId; setIDStateFromDNSQuestion(*ids, dq, std::move(qname)); + dq.getHeader()->id = idOffset; + /* If we couldn't harvest the real dest addr, still write down the listening addr since it will be useful (especially if it's not an 'any' one). @@ -551,8 +554,6 @@ static int processDOHQuery(DOHUnit* du) ids->destHarvested = false; } - dh->id = idOffset; - if (ss->useProxyProtocol) { addProxyProtocol(dq); } @@ -560,7 +561,7 @@ static int processDOHQuery(DOHUnit* du) int fd = pickBackendSocketForSending(ss); try { /* you can't touch du after this line, because it might already have been freed */ - ssize_t ret = udpClientSendRequestToBackend(ss, fd, query, dq.len); + ssize_t ret = udpClientSendRequestToBackend(ss, fd, du->query); if(ret < 0) { /* we are about to handle the error, make sure that @@ -653,12 +654,12 @@ static void on_generator_dispose(void *_self) /* This executes in the main DoH thread. We allocate a DOHUnit and send it to dnsdistclient() function in the doh client thread via a pipe */ -static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_req_t* req, std::string&& query, const ComboAddress& local, const ComboAddress& remote, std::string&& path) +static void doh_dispatch_query(DOHServerConfig* dsc, h2o_handler_t* self, h2o_req_t* req, std::vector&& query, const ComboAddress& local, const ComboAddress& remote, std::string&& path) { try { /* we only parse it there as a sanity check, we will parse it again later */ uint16_t qtype; - DNSName qname(query.c_str(), query.size(), sizeof(dnsheader), false, &qtype); + DNSName qname(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &qtype); auto du = std::unique_ptr(new DOHUnit); du->dsc = dsc; @@ -830,9 +831,10 @@ try ++dsc->cs->tlsUnknownqueries; } + #warning turn these into string_view? string path(req->path.base, req->path.len); - string pathOnly(req->path_normalized.base, req->path_normalized.len); + if (dsc->paths.count(pathOnly) == 0) { h2o_send_error_404(req, "Not Found", "there is no endpoint configured for this path", 0); return 0; @@ -853,11 +855,12 @@ try else ++dsc->df->d_http1Stats.d_nbQueries; - std::string query; + std::vector query; /* We reserve at least 512 additional bytes to be able to add EDNS, but we also want at least s_maxPacketCacheEntrySize bytes to be able to fill the answer from the packet cache */ query.reserve(std::max(req->entity.len + 512, s_maxPacketCacheEntrySize)); - query.assign(req->entity.base, req->entity.len); + query.resize(req->entity.len); + memcpy(query.data(), req->entity.base, req->entity.len); doh_dispatch_query(dsc, self, req, std::move(query), local, remote, std::move(path)); } else if(req->query_at != SIZE_MAX && (req->path.len - req->query_at > 5)) { @@ -879,7 +882,8 @@ try break; } - string decoded; + std::string decoded; + /* rough estimate so we hopefully don't need a new allocation later */ /* We reserve at least 512 additional bytes to be able to add EDNS, but we also want at least s_maxPacketCacheEntrySize bytes to be able to fill the answer from the packet cache */ @@ -897,7 +901,9 @@ try else ++dsc->df->d_http1Stats.d_nbQueries; - doh_dispatch_query(dsc, self, req, std::move(decoded), local, remote, std::move(path)); +#warning FIXME: performance + auto vect = std::vector(decoded.begin(), decoded.end()); + doh_dispatch_query(dsc, self, req, std::move(vect), local, remote, std::move(path)); } } else @@ -1029,10 +1035,17 @@ std::string DOHUnit::getHTTPQueryString() const } } -void DOHUnit::setHTTPResponse(uint16_t statusCode, const std::string& body_, const std::string& contentType_) +void DOHUnit::setHTTPResponse(uint16_t statusCode, std::vector&& body_, const std::string& contentType_) { status_code = statusCode; - response = body_; + response = std::move(body_); + if (!response.empty() && statusCode >= 400) { + // we need to make sure it's null-terminated */ + if (response.at(response.size() - 1) != 0) { + response.push_back(0); + } + } + contentType = contentType_; } @@ -1068,14 +1081,14 @@ static void dnsdistclient(int qsock) // if there was no EDNS, we add it with a large buffer size // so we can use UDP to talk to the backend. - auto dh = const_cast(reinterpret_cast(du->query.c_str())); + auto dh = const_cast(reinterpret_cast(du->query.data())); if (!dh->arcount) { std::string res; generateOptRR(std::string(), res, 4096, 0, false); - du->query += res; - dh = const_cast(reinterpret_cast(du->query.c_str())); // may have reallocated + du->query.insert(du->query.end(), res.begin(), res.end()); + dh = const_cast(reinterpret_cast(du->query.data())); // may have reallocated dh->arcount = htons(1); du->ednsAdded = true; } @@ -1089,6 +1102,7 @@ static void dnsdistclient(int qsock) du->get(); static_assert(sizeof(du) <= PIPE_BUF, "Writes up to PIPE_BUF are guaranteed not to be interleaved and to either fully succeed or fail"); + ssize_t sent = write(du->rsock, &du, sizeof(du)); if (sent != sizeof(du)) { if (errno == EAGAIN || errno == EWOULDBLOCK) { diff --git a/pdns/dnsdistdist/test-dnsdistkvs_cc.cc b/pdns/dnsdistdist/test-dnsdistkvs_cc.cc index 9322f61dad..783e9fabf0 100644 --- a/pdns/dnsdistdist/test-dnsdistkvs_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistkvs_cc.cc @@ -233,10 +233,7 @@ BOOST_AUTO_TEST_CASE(test_LMDB) { uint16_t qclass = QClass::IN; ComboAddress lc("192.0.2.1:53"); ComboAddress rem("192.0.2.128:42"); - struct dnsheader dh; - memset(&dh, 0, sizeof(dh)); - size_t bufferSize = 0; - size_t queryLen = 0; + std::vector packet(sizeof(dnsheader)); bool isTcp = false; struct timespec queryRealTime; gettime(&queryRealTime, true); @@ -244,7 +241,7 @@ BOOST_AUTO_TEST_CASE(test_LMDB) { /* the internal QPS limiter does not use the real time */ gettime(&expiredTime); - DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime); + DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, isTcp, &queryRealTime); ComboAddress v4Masked(v4ToMask); ComboAddress v6Masked(v6ToMask); v4Masked.truncate(25); @@ -291,10 +288,7 @@ BOOST_AUTO_TEST_CASE(test_CDB) { uint16_t qclass = QClass::IN; ComboAddress lc("192.0.2.1:53"); ComboAddress rem("192.0.2.128:42"); - struct dnsheader dh; - memset(&dh, 0, sizeof(dh)); - size_t bufferSize = 0; - size_t queryLen = 0; + std::vector packet(sizeof(dnsheader)); bool isTcp = false; struct timespec queryRealTime; gettime(&queryRealTime, true); @@ -302,7 +296,7 @@ BOOST_AUTO_TEST_CASE(test_CDB) { /* the internal QPS limiter does not use the real time */ gettime(&expiredTime); - DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime); + DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, isTcp, &queryRealTime); ComboAddress v4Masked(v4ToMask); ComboAddress v6Masked(v6ToMask); v4Masked.truncate(25); diff --git a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc index 4e21675def..67fdae00c9 100644 --- a/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistlbpolicies_cc.cc @@ -59,7 +59,7 @@ std::string DOHUnit::getHTTPQueryString() const return ""; } -void DOHUnit::setHTTPResponse(uint16_t statusCode, const std::string& body_, const std::string& contentType_) +void DOHUnit::setHTTPResponse(uint16_t statusCode, std::vector&& body_, const std::string& contentType_) { } #endif /* HAVE_DNS_OVER_HTTPS */ @@ -91,17 +91,14 @@ static DNSQuestion getDQ(const DNSName* providedName = nullptr) static const ComboAddress lc("127.0.0.1:53"); static const ComboAddress rem("192.0.2.1:42"); static struct timespec queryRealTime; - static struct dnsheader dh; + static std::vector packet(sizeof(dnsheader)); - memset(&dh, 0, sizeof(dh)); uint16_t qtype = QType::A; uint16_t qclass = QClass::IN; - size_t bufferSize = 0; - size_t queryLen = 0; bool isTcp = false; gettime(&queryRealTime, true); - DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime); + DNSQuestion dq(providedName ? providedName : &qname, qtype, qclass, &lc, &rem, packet, isTcp, &queryRealTime); return dq; } diff --git a/pdns/dnsdistdist/test-dnsdistrules_cc.cc b/pdns/dnsdistdist/test-dnsdistrules_cc.cc index 65d55ad01b..bc396d8b7d 100644 --- a/pdns/dnsdistdist/test-dnsdistrules_cc.cc +++ b/pdns/dnsdistdist/test-dnsdistrules_cc.cc @@ -22,10 +22,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { uint16_t qclass = QClass::IN; ComboAddress lc("127.0.0.1:53"); ComboAddress rem("192.0.2.1:42"); - struct dnsheader dh; - memset(&dh, 0, sizeof(dh)); - size_t bufferSize = 0; - size_t queryLen = 0; + std::vector packet(sizeof(dnsheader)); bool isTcp = false; struct timespec queryRealTime; gettime(&queryRealTime, true); @@ -33,7 +30,7 @@ BOOST_AUTO_TEST_CASE(test_MaxQPSIPRule) { /* the internal QPS limiter does not use the real time */ gettime(&expiredTime); - DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, &dh, bufferSize, queryLen, isTcp, &queryRealTime); + DNSQuestion dq(&qname, qtype, qclass, &lc, &rem, packet, isTcp, &queryRealTime); for (size_t idx = 0; idx < maxQPS; idx++) { /* let's use different source ports, it shouldn't matter */ diff --git a/pdns/doh.hh b/pdns/doh.hh index 36d720cd42..ed17ad2604 100644 --- a/pdns/doh.hh +++ b/pdns/doh.hh @@ -28,8 +28,12 @@ struct DOHServerConfig; class DOHResponseMapEntry { public: - DOHResponseMapEntry(const std::string& regex, uint16_t status, const std::string& content, const boost::optional>>& headers): d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status) + DOHResponseMapEntry(const std::string& regex, uint16_t status, const std::vector& content, const boost::optional>>& headers): d_regex(regex), d_customHeaders(headers), d_content(content), d_status(status) { + if (status >= 400 && !d_content.empty() && d_content.at(d_content.size() -1) != 0) { + // we need to make sure it's null-terminated + d_content.push_back(0); + } } bool matches(const std::string& path) const @@ -42,7 +46,7 @@ public: return d_status; } - const std::string& getContent() const + const std::vector& getContent() const { return d_content; } @@ -55,7 +59,7 @@ public: private: Regex d_regex; boost::optional>> d_customHeaders; - std::string d_content; + std::vector d_content; uint16_t d_status; }; @@ -185,8 +189,8 @@ struct DOHUnit } std::vector> headers; - std::string query; - std::string response; + std::vector query; + std::vector response; std::string sni; std::string path; std::string scheme; @@ -215,7 +219,7 @@ struct DOHUnit std::string getHTTPScheme() const; std::string getHTTPQueryString() const; std::unordered_map getHTTPHeaders() const; - void setHTTPResponse(uint16_t statusCode, const std::string& body, const std::string& contentType=""); + void setHTTPResponse(uint16_t statusCode, std::vector&& body, const std::string& contentType=""); }; #endif /* HAVE_DNS_OVER_HTTPS */ diff --git a/pdns/ednsoptions.cc b/pdns/ednsoptions.cc index ecfba3f867..1b1b6ca617 100644 --- a/pdns/ednsoptions.cc +++ b/pdns/ednsoptions.cc @@ -41,12 +41,12 @@ bool getNextEDNSOption(const char* data, size_t dataLen, uint16_t& optionCode, u return true; } -/* extract a specific EDNS0 option from a pointer on the beginning rdLen of the OPT RR */ -int getEDNSOption(char* optRR, const size_t len, uint16_t wantedOption, char ** optionValue, size_t * optionValueSize) +/* extract the position (relative to the optRR pointer!) and size of a specific EDNS0 option from a pointer on the beginning rdLen of the OPT RR */ +int getEDNSOption(const char* optRR, const size_t len, uint16_t wantedOption, size_t* optionValuePosition, size_t * optionValueSize) { - assert(optRR != NULL); - assert(optionValue != NULL); - assert(optionValueSize != NULL); + assert(optRR != nullptr); + assert(optionValuePosition != nullptr); + assert(optionValueSize != nullptr); size_t pos = 0; if (len < DNS_RDLENGTH_SIZE) return EINVAL; @@ -76,7 +76,7 @@ int getEDNSOption(char* optRR, const size_t len, uint16_t wantedOption, char ** } if (optionCode == wantedOption) { - *optionValue = optRR + pos - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE); + *optionValuePosition = pos - (EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE); *optionValueSize = optionLen + EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE; return 0; } diff --git a/pdns/ednsoptions.hh b/pdns/ednsoptions.hh index 58e8e8868b..a26eb230c1 100644 --- a/pdns/ednsoptions.hh +++ b/pdns/ednsoptions.hh @@ -27,8 +27,8 @@ struct EDNSOptionCode enum EDNSOptionCodeEnum {NSID=3, DAU=5, DHU=6, N3U=7, ECS=8, EXPIRE=9, COOKIE=10, TCPKEEPALIVE=11, PADDING=12, CHAIN=13, KEYTAG=14, EXTENDEDERROR=15}; }; -/* extract a specific EDNS0 option from a pointer on the beginning rdLen of the OPT RR */ -int getEDNSOption(char* optRR, size_t len, uint16_t wantedOption, char ** optionValue, size_t * optionValueSize); +/* extract the position (relative to the optRR pointer!) and size of a specific EDNS0 option from a pointer on the beginning rdLen of the OPT RR */ +int getEDNSOption(const char* optRR, size_t len, uint16_t wantedOption, size_t* optionValuePosition, size_t* optionValueSize); struct EDNSOptionViewValue { diff --git a/pdns/fuzz_dnsdistcache.cc b/pdns/fuzz_dnsdistcache.cc index 79812ae4c7..1b12f35602 100644 --- a/pdns/fuzz_dnsdistcache.cc +++ b/pdns/fuzz_dnsdistcache.cc @@ -43,11 +43,12 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { uint16_t qtype; uint16_t qclass; unsigned int consumed; + std::vector vect(data, data+size); const DNSName qname(reinterpret_cast(data), size, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - pcSkipCookies.getKey(qname.getStorage(), consumed, data, size, false); - pcHashCookies.getKey(qname.getStorage(), consumed, data, size, false); + pcSkipCookies.getKey(qname.getStorage(), consumed, vect, false); + pcHashCookies.getKey(qname.getStorage(), consumed, vect, false); boost::optional subnet; - DNSDistPacketCache::getClientSubnet(reinterpret_cast(data), consumed, size, subnet); + DNSDistPacketCache::getClientSubnet(vect, consumed, subnet); } catch(const std::exception& e) { } diff --git a/pdns/iputils.cc b/pdns/iputils.cc index f1e1f987fc..a032629eaa 100644 --- a/pdns/iputils.cc +++ b/pdns/iputils.cc @@ -261,7 +261,7 @@ int sendOnNBSocket(int fd, const struct msghdr *msgh) return sendErr; } -ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to) +ssize_t sendfromto(int sock, const void* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to) { struct msghdr msgh; struct iovec iov; @@ -269,7 +269,7 @@ ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const Comb /* Set up iov and msgh structures. */ memset(&msgh, 0, sizeof(struct msghdr)); - iov.iov_base = (void*)data; + iov.iov_base = const_cast(data); iov.iov_len = len; msgh.msg_iov = &iov; msgh.msg_iovlen = 1; diff --git a/pdns/iputils.hh b/pdns/iputils.hh index 3552a8e26f..2d9f5af444 100644 --- a/pdns/iputils.hh +++ b/pdns/iputils.hh @@ -1434,7 +1434,7 @@ bool HarvestDestinationAddress(const struct msghdr* msgh, ComboAddress* destinat bool HarvestTimestamp(struct msghdr* msgh, struct timeval* tv); void fillMSGHdr(struct msghdr* msgh, struct iovec* iov, cmsgbuf_aligned* cbuf, size_t cbufsize, char* data, size_t datalen, ComboAddress* addr); int sendOnNBSocket(int fd, const struct msghdr *msgh); -ssize_t sendfromto(int sock, const char* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to); +ssize_t sendfromto(int sock, const void* data, size_t len, int flags, const ComboAddress& from, const ComboAddress& to); size_t sendMsgWithOptions(int fd, const char* buffer, size_t len, const ComboAddress* dest, const ComboAddress* local, unsigned int localItf, int flags); /* requires a non-blocking, connected TCP socket */ diff --git a/pdns/pdns_recursor.cc b/pdns/pdns_recursor.cc index 8261f95ac6..b3bb33b3c3 100644 --- a/pdns/pdns_recursor.cc +++ b/pdns/pdns_recursor.cc @@ -2241,13 +2241,13 @@ static void getQNameAndSubnet(const std::string& question, DNSName* dnsname, uin /* OPT root label (1) followed by type (2) */ if(lookForECS && ntohs(drh->d_type) == QType::OPT) { if (!options) { - char* ecsStart = nullptr; + size_t ecsStartPosition = 0; size_t ecsLen = 0; /* we need to pass the record len */ - int res = getEDNSOption(const_cast(reinterpret_cast(&question.at(pos - sizeof(drh->d_clen)))), questionLen - pos + sizeof(drh->d_clen), EDNSOptionCode::ECS, &ecsStart, &ecsLen); + int res = getEDNSOption(reinterpret_cast(&question.at(pos - sizeof(drh->d_clen))), questionLen - pos + sizeof(drh->d_clen), EDNSOptionCode::ECS, &ecsStartPosition, &ecsLen); if (res == 0 && ecsLen > 4) { EDNSSubnetOpts eso; - if(getEDNSSubnetOptsFromString(ecsStart + 4, ecsLen - 4, &eso)) { + if(getEDNSSubnetOptsFromString(&question.at(pos - sizeof(drh->d_clen) + ecsStartPosition + 4), ecsLen - 4, &eso)) { *ednssubnet=eso; foundECS = true; } diff --git a/pdns/recursordist/test-ednsoptions_cc.cc b/pdns/recursordist/test-ednsoptions_cc.cc index 20cd19481c..cb07d377d6 100644 --- a/pdns/recursordist/test-ednsoptions_cc.cc +++ b/pdns/recursordist/test-ednsoptions_cc.cc @@ -58,13 +58,13 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOption) BOOST_REQUIRE_EQUAL(query.at(pos), 0); BOOST_REQUIRE(query.at(pos + 2) == QType::OPT); - char* ecsStart = nullptr; + size_t ecsStartPosition = 0; size_t ecsLen = 0; - int res = getEDNSOption(reinterpret_cast(query.data()) + pos + 9, questionLen - pos - 9, EDNSOptionCode::ECS, &ecsStart, &ecsLen); + int res = getEDNSOption(reinterpret_cast(&query.at(pos + 9)), questionLen - pos - 9, EDNSOptionCode::ECS, &ecsStartPosition, &ecsLen); BOOST_CHECK_EQUAL(res, 0); EDNSSubnetOpts eso; - BOOST_REQUIRE(getEDNSSubnetOptsFromString(ecsStart + 4, ecsLen - 4, &eso)); + BOOST_REQUIRE(getEDNSSubnetOptsFromString(reinterpret_cast(&query.at(pos + 9 + ecsStartPosition + 4)), ecsLen - 4, &eso)); BOOST_CHECK(eso.source == ecs); } diff --git a/pdns/test-dnscrypt_cc.cc b/pdns/test-dnscrypt_cc.cc index 89658906c4..bbc08e83ef 100644 --- a/pdns/test-dnscrypt_cc.cc +++ b/pdns/test-dnscrypt_cc.cc @@ -52,12 +52,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQuery) { vector plainQuery; DNSPacketWriter pw(plainQuery, name, QType::TXT, QClass::IN, 0); pw.getHeader()->rd = 0; - uint16_t len = plainQuery.size(); std::shared_ptr query = std::make_shared(ctx); - uint16_t decryptedLen = 0; - - query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now); + query->parsePacket(plainQuery, false, now); BOOST_CHECK_EQUAL(query->isValid(), true); BOOST_CHECK_EQUAL(query->isEncrypted(), false); @@ -94,12 +91,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidA) { vector plainQuery; DNSPacketWriter pw(plainQuery, name, QType::A, QClass::IN, 0); pw.getHeader()->rd = 0; - uint16_t len = plainQuery.size(); std::shared_ptr query = std::make_shared(ctx); - uint16_t decryptedLen = 0; - - query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now); + query->parsePacket(plainQuery, false, now); BOOST_CHECK_EQUAL(query->isValid(), false); } @@ -120,12 +114,9 @@ BOOST_AUTO_TEST_CASE(DNSCryptPlaintextQueryInvalidProviderName) { vector plainQuery; DNSPacketWriter pw(plainQuery, name, QType::TXT, QClass::IN, 0); pw.getHeader()->rd = 0; - uint16_t len = plainQuery.size(); std::shared_ptr query = std::make_shared(ctx); - uint16_t decryptedLen = 0; - - query->parsePacket((char*) plainQuery.data(), len, false, &decryptedLen, now); + query->parsePacket(plainQuery, false, now); BOOST_CHECK_EQUAL(query->isValid(), false); } @@ -157,24 +148,22 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValid) { requiredSize = DNSCryptQuery::s_minUDPLength; } - uint16_t len = plainQuery.size(); plainQuery.resize(requiredSize); - uint16_t encryptedResponseLen = 0; - int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); + size_t initialSize = plainQuery.size(); + int res = ctx->encryptQuery(plainQuery, 4096, clientPublicKey, clientPrivateKey, clientNonce, false, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, 0); - BOOST_CHECK(encryptedResponseLen > len); + BOOST_CHECK(plainQuery.size() > initialSize); std::shared_ptr query = std::make_shared(ctx); - uint16_t decryptedLen = 0; - query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now); + query->parsePacket(plainQuery, false, now); BOOST_CHECK_EQUAL(query->isValid(), true); BOOST_CHECK_EQUAL(query->isEncrypted(), true); - MOADNSParser mdp(true, (char*) plainQuery.data(), decryptedLen); + MOADNSParser mdp(true, (char*) plainQuery.data(), plainQuery.size()); BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1U); BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0U); @@ -209,11 +198,7 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidButShort) { DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0); pw.getHeader()->rd = 1; - uint16_t len = plainQuery.size(); - uint16_t encryptedResponseLen = 0; - - int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); - + int res = ctx->encryptQuery(plainQuery, /* not enough room */ plainQuery.size(), clientPublicKey, clientPrivateKey, clientNonce, false, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, ENOBUFS); } @@ -245,14 +230,11 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) { requiredSize = DNSCryptQuery::s_minUDPLength; } - uint16_t len = plainQuery.size(); - plainQuery.resize(requiredSize); - uint16_t encryptedResponseLen = 0; - - int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); + size_t initialSize = plainQuery.size(); + int res = ctx->encryptQuery(plainQuery, 4096, clientPublicKey, clientPrivateKey, clientNonce, false, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, 0); - BOOST_CHECK(encryptedResponseLen > len); + BOOST_CHECK(plainQuery.size() > initialSize); DNSCryptCert newResolverCert; DNSCryptContext::generateCertificate(2, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, newResolverCert); @@ -260,14 +242,13 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryValidWithOldKey) { ctx->markInactive(resolverCert.getSerial()); std::shared_ptr query = std::make_shared(ctx); - uint16_t decryptedLen = 0; - query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now); + query->parsePacket(plainQuery, false, now); BOOST_CHECK_EQUAL(query->isValid(), true); BOOST_CHECK_EQUAL(query->isEncrypted(), true); - MOADNSParser mdp(true, (char*) plainQuery.data(), decryptedLen); + MOADNSParser mdp(true, (char*) plainQuery.data(), plainQuery.size()); BOOST_CHECK_EQUAL(mdp.d_header.qdcount, 1U); BOOST_CHECK_EQUAL(mdp.d_header.ancount, 0U); @@ -302,19 +283,11 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) { DNSPacketWriter pw(plainQuery, name, QType::AAAA, QClass::IN, 0); pw.getHeader()->rd = 1; - size_t requiredSize = plainQuery.size() + sizeof(DNSCryptQueryHeader) + DNSCRYPT_MAC_SIZE; - if (requiredSize < DNSCryptQuery::s_minUDPLength) { - requiredSize = DNSCryptQuery::s_minUDPLength; - } - - uint16_t len = plainQuery.size(); - plainQuery.resize(requiredSize); - uint16_t encryptedResponseLen = 0; - - int res = ctx->encryptQuery((char*) plainQuery.data(), len, plainQuery.capacity(), clientPublicKey, clientPrivateKey, clientNonce, false, &encryptedResponseLen, std::make_shared(resolverCert)); + size_t initialSize = plainQuery.size(); + int res = ctx->encryptQuery(plainQuery, 4096, clientPublicKey, clientPrivateKey, clientNonce, false, std::make_shared(resolverCert)); BOOST_CHECK_EQUAL(res, 0); - BOOST_CHECK(encryptedResponseLen > len); + BOOST_CHECK(plainQuery.size() > initialSize); DNSCryptCert newResolverCert; DNSCryptContext::generateCertificate(2, now, now + (24 * 60 * 3600), DNSCryptExchangeVersion::VERSION1, providerPrivateKey, resolverPrivateKey, newResolverCert); @@ -325,9 +298,8 @@ BOOST_AUTO_TEST_CASE(DNSCryptEncryptedQueryInvalidWithWrongKey) { /* we have removed the old certificate, we can't decrypt this query */ std::shared_ptr query = std::make_shared(ctx); - uint16_t decryptedLen = 0; - query->parsePacket((char*) plainQuery.data(), encryptedResponseLen, false, &decryptedLen, now); + query->parsePacket(plainQuery, false, now); BOOST_CHECK_EQUAL(query->isValid(), false); } diff --git a/pdns/test-dnsdist_cc.cc b/pdns/test-dnsdist_cc.cc index e24d35bc7d..eba97ce7cb 100644 --- a/pdns/test-dnsdist_cc.cc +++ b/pdns/test-dnsdist_cc.cc @@ -42,9 +42,9 @@ BOOST_AUTO_TEST_SUITE(test_dnsdist_cc) static const uint16_t ECSSourcePrefixV4 = 24; static const uint16_t ECSSourcePrefixV6 = 56; -static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=true, bool hasXPF=false, uint16_t additionals=0, uint16_t answers=0, uint16_t authorities=0) +static void validateQuery(const std::vector& packet, bool hasEdns=true, bool hasXPF=false, uint16_t additionals=0, uint16_t answers=0, uint16_t authorities=0) { - MOADNSParser mdp(true, packet, packetSize); + MOADNSParser mdp(true, reinterpret_cast(packet.data()), packet.size()); BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com."); @@ -55,14 +55,14 @@ static void validateQuery(const char * packet, size_t packetSize, bool hasEdns=t BOOST_CHECK_EQUAL(mdp.d_header.arcount, expectedARCount); } -static void validateECS(const char* packet, size_t packetSize, const ComboAddress& expected) +static void validateECS(const std::vector& packet, const ComboAddress& expected) { ComboAddress rem("::1"); unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, packetSize, sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &rem, const_cast(reinterpret_cast(packet)), packetSize, packetSize, false, nullptr); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &rem, const_cast&>(packet), false, nullptr); BOOST_CHECK(parseEDNSOptions(dq)); BOOST_REQUIRE(dq.ednsOptions != nullptr); BOOST_CHECK_EQUAL(dq.ednsOptions->size(), 1U); @@ -76,9 +76,9 @@ static void validateECS(const char* packet, size_t packetSize, const ComboAddres BOOST_CHECK_EQUAL(expectedOption.substr(EDNS_OPTION_CODE_SIZE + EDNS_OPTION_LENGTH_SIZE), std::string(ecsOption->second.values.at(0).content, ecsOption->second.values.at(0).size)); } -static void validateResponse(const char * packet, size_t packetSize, bool hasEdns, uint8_t additionalCount=0) +static void validateResponse(const std::vector& packet, bool hasEdns, uint8_t additionalCount=0) { - MOADNSParser mdp(false, packet, packetSize); + MOADNSParser mdp(false, reinterpret_cast(packet.data()), packet.size()); BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com."); @@ -101,109 +101,69 @@ BOOST_AUTO_TEST_CASE(test_addXPF) vector query; DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; - const uint16_t len = query.size(); vector queryWithXPF; { - char packet[1500]; - memcpy(packet, query.data(), query.size()); + std::vector packet = query; /* large enough packet */ unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); - BOOST_CHECK(addXPF(dq, xpfOptionCode, false)); - BOOST_CHECK(static_cast(dq.len) > query.size()); - validateQuery(packet, dq.len, false, true); - queryWithXPF.resize(dq.len); - memcpy(queryWithXPF.data(), packet, dq.len); + BOOST_CHECK(addXPF(dq, xpfOptionCode)); + BOOST_CHECK(packet.size() > query.size()); + validateQuery(packet, false, true); + queryWithXPF = packet; } { - char packet[1500]; - memcpy(packet, query.data(), query.size()); + std::vector packet = query; - /* not large enough packet */ + /* packet is already too large for the 4096 limit over UDP */ + packet.resize(4096); unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime); - dq.size = dq.len; + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); - BOOST_CHECK(!addXPF(dq, xpfOptionCode, false)); - BOOST_CHECK_EQUAL(static_cast(dq.len), query.size()); - validateQuery(packet, dq.len, false, false); + BOOST_REQUIRE(!addXPF(dq, xpfOptionCode)); + BOOST_CHECK_EQUAL(packet.size(), 4096); + packet.resize(query.size()); + validateQuery(packet, false, false); } { - char packet[1500]; - memcpy(packet, query.data(), query.size()); + std::vector packet = query; /* packet with trailing data (overriding it) */ unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); /* add trailing data */ const size_t trailingDataSize = 10; /* Making sure we have enough room to allow for fake trailing data */ - BOOST_REQUIRE(sizeof(packet) > dq.len && (sizeof(packet) - dq.len) > trailingDataSize); + packet.resize(packet.size() + trailingDataSize); for (size_t idx = 0; idx < trailingDataSize; idx++) { - packet[dq.len + idx] = 'A'; + packet.push_back('A'); } - dq.len += trailingDataSize; - BOOST_CHECK(addXPF(dq, xpfOptionCode, false)); - BOOST_CHECK_EQUAL(static_cast(dq.len), queryWithXPF.size()); - BOOST_CHECK_EQUAL(memcmp(queryWithXPF.data(), packet, queryWithXPF.size()), 0); - validateQuery(packet, dq.len, false, true); - } - - { - char packet[1500]; - memcpy(packet, query.data(), query.size()); - - /* packet with trailing data (preserving trailing data) */ - unsigned int consumed = 0; - uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime); - - /* add trailing data */ - const size_t trailingDataSize = 10; - /* Making sure we have enough room to allow for fake trailing data */ - BOOST_REQUIRE(sizeof(packet) > dq.len && (sizeof(packet) - dq.len) > trailingDataSize); - for (size_t idx = 0; idx < trailingDataSize; idx++) { - packet[dq.len + idx] = 'A'; - } - dq.len += trailingDataSize; - - BOOST_CHECK(addXPF(dq, xpfOptionCode, true)); - BOOST_CHECK(static_cast(dq.len) > queryWithXPF.size()); - BOOST_CHECK_EQUAL(memcmp(queryWithXPF.data(), packet, queryWithXPF.size()), 0); - for (size_t idx = 0; idx < trailingDataSize; idx++) { - BOOST_CHECK_EQUAL(packet[queryWithXPF.size() + idx], 'A'); - } - validateQuery(packet, dq.len, false, true); + BOOST_CHECK(addXPF(dq, xpfOptionCode)); + BOOST_CHECK_EQUAL(packet.size(), queryWithXPF.size()); + BOOST_CHECK_EQUAL(memcmp(queryWithXPF.data(), packet.data(), queryWithXPF.size()), 0); + validateQuery(packet, false, true); } } @@ -222,89 +182,60 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNS) uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + std::vector packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); - BOOST_CHECK(static_cast(len) > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, len); - validateECS(packet, len, remote); - vector queryWithEDNS; - queryWithEDNS.resize(len); - memcpy(queryWithEDNS.data(), packet, len); + validateQuery(packet); + validateECS(packet, remote); + vector queryWithEDNS = packet; /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); - BOOST_CHECK_EQUAL(static_cast(len), query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, false, newECSOption)); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, false); + packet.resize(query.size()); + validateQuery(packet, false); /* packet with trailing data (overriding it) */ - memcpy(packet, query.data(), query.size()); + packet = query; ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); /* add trailing data */ const size_t trailingDataSize = 10; /* Making sure we have enough room to allow for fake trailing data */ - BOOST_REQUIRE(sizeof(packet) > len && (sizeof(packet) - len) > trailingDataSize); + packet.resize(packet.size() + trailingDataSize); for (size_t idx = 0; idx < trailingDataSize; idx++) { packet[len + idx] = 'A'; } - len += trailingDataSize; - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); - BOOST_REQUIRE_EQUAL(static_cast(len), queryWithEDNS.size()); - BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0); - BOOST_CHECK_EQUAL(ednsAdded, true); - BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, len); - /* packet with trailing data (preserving trailing data) */ - memcpy(packet, query.data(), query.size()); - ednsAdded = false; - ecsAdded = false; - consumed = 0; - len = query.size(); - qname = DNSName(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); - BOOST_CHECK_EQUAL(qname, name); - BOOST_CHECK(qtype == QType::A); - /* add trailing data */ - /* Making sure we have enough room to allow for fake trailing data */ - BOOST_REQUIRE(sizeof(packet) > len && (sizeof(packet) - len) > trailingDataSize); - for (size_t idx = 0; idx < trailingDataSize; idx++) { - packet[len + idx] = 'A'; - } - len += trailingDataSize; - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, true)); - BOOST_REQUIRE_EQUAL(static_cast(len), queryWithEDNS.size() + trailingDataSize); - BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet, queryWithEDNS.size()), 0); - for (size_t idx = 0; idx < trailingDataSize; idx++) { - BOOST_CHECK_EQUAL(packet[queryWithEDNS.size() + idx], 'A'); - } + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_REQUIRE_EQUAL(packet.size(), queryWithEDNS.size()); + BOOST_CHECK_EQUAL(memcmp(queryWithEDNS.data(), packet.data(), queryWithEDNS.size()), 0); BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, len); + validateQuery(packet); } BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) @@ -318,45 +249,48 @@ BOOST_AUTO_TEST_CASE(addECSWithoutEDNSAlreadyParsed) DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; - /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(packet), sizeof(packet), query.size(), false, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(!parseEDNSOptions(dq)); /* And now we add our own ECS */ - BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false)); - BOOST_CHECK_GT(static_cast(dq.len), query.size()); + BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded)); + BOOST_CHECK_GT(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, dq.len); - validateECS(packet, dq.len, remote); + validateQuery(packet); + validateECS(packet, remote); + + /* trailing data */ + packet = query; + packet.resize(2048); - /* not large enough packet */ ednsAdded = false; ecsAdded = false; consumed = 0; - qname = DNSName(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(query.data()), query.size(), query.size(), false, nullptr); + DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); - BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false)); - BOOST_CHECK_EQUAL(static_cast(dq2.len), query.size()); - BOOST_CHECK_EQUAL(ednsAdded, false); - BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), dq2.len, false); + BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded)); + BOOST_CHECK_GT(packet.size(), query.size()); + BOOST_CHECK_LT(packet.size(), 2048); + BOOST_CHECK_EQUAL(ednsAdded, true); + BOOST_CHECK_EQUAL(ecsAdded, true); + validateQuery(packet); + validateECS(packet, remote); } BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { @@ -372,39 +306,37 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECS) { pw.getHeader()->rd = 1; pw.addOpt(512, 0, 0); pw.commit(); - uint16_t len = query.size(); - /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, len); - validateECS(packet, len, remote); + validateQuery(packet); + validateECS(packet, remote); /* not large enough packet */ consumed = 0; ednsAdded = false; ecsAdded = false; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + packet = query; + + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, false, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, false, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len); + validateQuery(packet); } BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { @@ -419,45 +351,47 @@ BOOST_AUTO_TEST_CASE(addECSWithEDNSNoECSAlreadyParsed) { pw.addOpt(512, 0, 0); pw.commit(); - /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(packet), sizeof(packet), query.size(), false, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(parseEDNSOptions(dq)); /* And now we add our own ECS */ - BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false)); - BOOST_CHECK_GT(static_cast(dq.len), query.size()); + BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded)); + BOOST_CHECK_GT(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, dq.len); - validateECS(packet, dq.len, remote); + validateQuery(packet); + validateECS(packet, remote); - /* not large enough packet */ + /* trailing data */ + packet = query; + packet.resize(2048); consumed = 0; ednsAdded = false; ecsAdded = false; - qname = DNSName(reinterpret_cast(query.data()), query.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq2(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(query.data()), query.size(), query.size(), false, nullptr); + DNSQuestion dq2(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); - BOOST_CHECK(!handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded, false)); - BOOST_CHECK_EQUAL(static_cast(dq2.len), query.size()); + BOOST_CHECK(handleEDNSClientSubnet(dq2, ednsAdded, ecsAdded)); + BOOST_CHECK_GT(packet.size(), query.size()); + BOOST_CHECK_LT(packet.size(), 2048); BOOST_CHECK_EQUAL(ednsAdded, false); - BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), dq2.len); + BOOST_CHECK_EQUAL(ecsAdded, true); + validateQuery(packet); + validateECS(packet, remote); } BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { @@ -479,24 +413,22 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSize) { opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); pw.addOpt(512, 0, 0, opts); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len); - validateECS(packet, len, remote); + validateQuery(packet); + validateECS(packet, remote); } BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) { @@ -517,31 +449,29 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSameSizeAlreadyParsed) { pw.addOpt(512, 0, 0, opts); pw.commit(); - /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); BOOST_CHECK(qclass == QClass::IN); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(packet), sizeof(packet), query.size(), false, nullptr); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); dq.ecsOverride = true; /* Parse the options before handling ECS, simulating a Lua rule asking for EDNS Options */ BOOST_CHECK(parseEDNSOptions(dq)); /* And now we add our own ECS */ - BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded, false)); - BOOST_CHECK_EQUAL(static_cast(dq.len), query.size()); + BOOST_CHECK(handleEDNSClientSubnet(dq, ednsAdded, ecsAdded)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, dq.len); - validateECS(packet, dq.len, remote); + validateQuery(packet); + validateECS(packet, remote); } BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) { @@ -563,24 +493,21 @@ BOOST_AUTO_TEST_CASE(replaceECSWithSmaller) { opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); pw.addOpt(512, 0, 0, opts); pw.commit(); - uint16_t len = query.size(); - /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len < query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() < query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len); - validateECS(packet, len, remote); + validateQuery(packet); + validateECS(packet, remote); } BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { @@ -596,45 +523,46 @@ BOOST_AUTO_TEST_CASE(replaceECSWithLarger) { DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; EDNSSubnetOpts ecsOpts; + // smaller (less specific so less bits) option + static_assert(8 < ECSSourcePrefixV4, "The ECS scope should be smaller"); ecsOpts.source = Netmask(origRemote, 8); string origECSOption = makeEDNSSubnetOptsString(ecsOpts); DNSPacketWriter::optvect_t opts; opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); pw.addOpt(512, 0, 0, opts); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len); - validateECS(packet, len, remote); + validateQuery(packet); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len); + validateQuery(packet); } BOOST_AUTO_TEST_CASE(replaceECSFollowedByTSIG) { @@ -657,39 +585,38 @@ BOOST_AUTO_TEST_CASE(replaceECSFollowedByTSIG) { pw.addOpt(512, 0, 0, opts); pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len, true, false, 1); - validateECS(packet, len, remote); + validateQuery(packet, true, false, 1); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, true, false, 1); + validateQuery(packet, true, false, 1); } BOOST_AUTO_TEST_CASE(replaceECSAfterAN) { @@ -713,39 +640,38 @@ BOOST_AUTO_TEST_CASE(replaceECSAfterAN) { opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); pw.addOpt(512, 0, 0, opts); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len, true, false, 0, 1, 0); - validateECS(packet, len, remote); + validateQuery(packet, true, false, 0, 1, 0); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, true, false, 0, 1, 0); + validateQuery(packet, true, false, 0, 1, 0); } BOOST_AUTO_TEST_CASE(replaceECSAfterAuth) { @@ -769,39 +695,38 @@ BOOST_AUTO_TEST_CASE(replaceECSAfterAuth) { opts.push_back(make_pair(EDNSOptionCode::ECS, origECSOption)); pw.addOpt(512, 0, 0, opts); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len, true, false, 0, 0, 1); - validateECS(packet, len, remote); + validateQuery(packet, true, false, 0, 0, 1); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, true, false, 0, 0, 1); + validateQuery(packet, true, false, 0, 0, 1); } BOOST_AUTO_TEST_CASE(replaceECSBetweenTwoRecords) { @@ -826,39 +751,38 @@ BOOST_AUTO_TEST_CASE(replaceECSBetweenTwoRecords) { pw.addOpt(512, 0, 0, opts); pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(packet, len, true, false, 2); - validateECS(packet, len, remote); + validateQuery(packet, true, false, 2); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, true, false, 2); + validateQuery(packet, true, false, 2); } BOOST_AUTO_TEST_CASE(insertECSInEDNSBetweenTwoRecords) { @@ -878,39 +802,38 @@ BOOST_AUTO_TEST_CASE(insertECSInEDNSBetweenTwoRecords) { pw.addOpt(512, 0, 0); pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, true); - validateQuery(packet, len, true, false, 2); - validateECS(packet, len, remote); + validateQuery(packet, true, false, 2); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(query, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, true, false, 2); + validateQuery(packet, true, false, 2); } BOOST_AUTO_TEST_CASE(insertECSAfterTSIG) { @@ -927,40 +850,39 @@ BOOST_AUTO_TEST_CASE(insertECSAfterTSIG) { pw.getHeader()->rd = 1; pw.startRecord(DNSName("tsigname."), QType::TSIG, 0, QClass::ANY, DNSResourceRecord::ADDITIONAL, false); pw.commit(); - uint16_t len = query.size(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(handleEDNSClientSubnet(packet, sizeof packet, consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK((size_t) len > query.size()); + BOOST_CHECK(handleEDNSClientSubnet(packet, 4096, consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK(packet.size() > query.size()); BOOST_CHECK_EQUAL(ednsAdded, true); BOOST_CHECK_EQUAL(ecsAdded, true); /* the MOADNSParser does not allow anything except XPF after a TSIG */ - BOOST_CHECK_THROW(validateQuery(packet, len, true, false, 1), MOADNSException); - validateECS(packet, len, remote); + BOOST_CHECK_THROW(validateQuery(packet, true, false, 1), MOADNSException); + validateECS(packet, remote); /* not large enough packet */ + packet = query; + ednsAdded = false; ecsAdded = false; consumed = 0; - len = query.size(); - qname = DNSName(reinterpret_cast(query.data()), len, sizeof(dnsheader), false, &qtype, NULL, &consumed); + qname = DNSName(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - BOOST_CHECK(!handleEDNSClientSubnet(reinterpret_cast(query.data()), query.size(), consumed, &len, ednsAdded, ecsAdded, true, newECSOption, false)); - BOOST_CHECK_EQUAL((size_t) len, query.size()); + BOOST_CHECK(!handleEDNSClientSubnet(packet, packet.size(), consumed, ednsAdded, ecsAdded, true, newECSOption)); + BOOST_CHECK_EQUAL(packet.size(), query.size()); BOOST_CHECK_EQUAL(ednsAdded, false); BOOST_CHECK_EQUAL(ecsAdded, false); - validateQuery(reinterpret_cast(query.data()), len, true, false); + validateQuery(packet, true, false); } @@ -984,13 +906,13 @@ BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); size_t const ednsOptRRSize = sizeof(struct dnsrecordheader) + 1 /* root in OPT RR */; BOOST_CHECK_EQUAL(newResponse.size(), response.size() - ednsOptRRSize); - validateResponse((const char *) newResponse.data(), newResponse.size(), false, 1); + validateResponse(newResponse, false, 1); } BOOST_AUTO_TEST_CASE(removeEDNSWhenIntermediary) { @@ -1016,13 +938,13 @@ BOOST_AUTO_TEST_CASE(removeEDNSWhenIntermediary) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); size_t const ednsOptRRSize = sizeof(struct dnsrecordheader) + 1 /* root in OPT RR */; BOOST_CHECK_EQUAL(newResponse.size(), response.size() - ednsOptRRSize); - validateResponse((const char *) newResponse.data(), newResponse.size(), false, 2); + validateResponse(newResponse, false, 2); } BOOST_AUTO_TEST_CASE(removeEDNSWhenLast) { @@ -1047,13 +969,13 @@ BOOST_AUTO_TEST_CASE(removeEDNSWhenLast) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); size_t const ednsOptRRSize = sizeof(struct dnsrecordheader) + 1 /* root in OPT RR */; BOOST_CHECK_EQUAL(newResponse.size(), response.size() - ednsOptRRSize); - validateResponse((const char *) newResponse.data(), newResponse.size(), false, 1); + validateResponse(newResponse, false, 1); } BOOST_AUTO_TEST_CASE(removeECSWhenOnlyOption) { @@ -1096,11 +1018,11 @@ BOOST_AUTO_TEST_CASE(removeECSWhenOnlyOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) response.data(), responseLen, true, 1); + validateResponse(response, true, 1); } BOOST_AUTO_TEST_CASE(removeECSWhenFirstOption) { @@ -1148,11 +1070,11 @@ BOOST_AUTO_TEST_CASE(removeECSWhenFirstOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) response.data(), responseLen, true, 1); + validateResponse(response, true, 1); } BOOST_AUTO_TEST_CASE(removeECSWhenIntermediaryOption) { @@ -1204,11 +1126,11 @@ BOOST_AUTO_TEST_CASE(removeECSWhenIntermediaryOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) response.data(), responseLen, true, 1); + validateResponse(response, true, 1); } BOOST_AUTO_TEST_CASE(removeECSWhenLastOption) { @@ -1256,11 +1178,11 @@ BOOST_AUTO_TEST_CASE(removeECSWhenLastOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) response.data(), responseLen, sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) response.data(), responseLen, true, 1); + validateResponse(response, true, 1); } BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenOnlyOption) { @@ -1293,11 +1215,11 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenOnlyOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) newResponse.data(), newResponse.size(), true, 1); + validateResponse(newResponse, true, 1); } BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenFirstOption) { @@ -1335,11 +1257,11 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenFirstOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) newResponse.data(), newResponse.size(), true, 1); + validateResponse(newResponse, true, 1); } BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenIntermediaryOption) { @@ -1379,11 +1301,11 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenIntermediaryOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) newResponse.data(), newResponse.size(), true, 1); + validateResponse(newResponse, true, 1); } BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) { @@ -1421,28 +1343,25 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) { unsigned int consumed = 0; uint16_t qtype; - DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, NULL, &consumed); + DNSName qname((const char*) newResponse.data(), newResponse.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); BOOST_CHECK_EQUAL(qname, name); BOOST_CHECK(qtype == QType::A); - validateResponse((const char *) newResponse.data(), newResponse.size(), true, 1); + validateResponse(newResponse, true, 1); } -static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& realTime, vector& query, size_t len) +static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& realTime, vector& query) { - dnsheader* dh = reinterpret_cast(query.data()); - - return DNSQuestion(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime); + return DNSQuestion(&qname, qtype, qclass, &lc, &rem, query, false, &realTime); } static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, vector& query, bool resizeBuffer=true) { - size_t length = query.size(); if (resizeBuffer) { query.resize(4096); } - auto dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query, length); + auto dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query); BOOST_CHECK(addEDNSToQueryTurnedResponse(dq)); @@ -1455,8 +1374,7 @@ static int getZ(const DNSName& qname, const uint16_t qtype, const uint16_t qclas ComboAddress rem("127.0.0.1"); struct timespec queryRealTime; gettime(&queryRealTime, true); - size_t length = query.size(); - DNSQuestion dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query, length); + DNSQuestion dq = getDNSQuestion(qname, qtype, qclass, lc, rem, queryRealTime, query); return getEDNSZ(dq); } @@ -1591,7 +1509,7 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); - BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &udpPayloadSize, &z), false); + BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), false); BOOST_CHECK_EQUAL(z, 0); BOOST_CHECK_EQUAL(udpPayloadSize, 0); } @@ -1604,9 +1522,9 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { pw.commit(); query.resize(query.size() - (/* RDLEN */ sizeof(uint16_t) + /* last byte of TTL / Z */ 1)); - auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); + auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query, false); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); - BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &udpPayloadSize, &z), false); + BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), false); BOOST_CHECK_EQUAL(z, 0); BOOST_CHECK_EQUAL(udpPayloadSize, 0); } @@ -1620,7 +1538,7 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); - BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &udpPayloadSize, &z), true); + BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, 0); BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers); } @@ -1634,7 +1552,7 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO); - BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &udpPayloadSize, &z), true); + BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO); BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers); } @@ -1648,7 +1566,7 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); BOOST_CHECK_EQUAL(getEDNSZ(dq), 0); - BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &udpPayloadSize, &z), true); + BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, 0); BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers); } @@ -1662,7 +1580,7 @@ BOOST_AUTO_TEST_CASE(test_addEDNSToQueryTurnedResponse) { auto dq = turnIntoResponse(qname, qtype, qclass, lc, rem, queryRealTime, query); BOOST_CHECK_EQUAL(getEDNSZ(dq), EDNS_HEADER_FLAG_DO); - BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.dh), dq.len, &udpPayloadSize, &z), true); + BOOST_CHECK_EQUAL(getEDNSUDPPayloadSizeAndZ(reinterpret_cast(dq.getData().data()), dq.getData().size(), &udpPayloadSize, &z), true); BOOST_CHECK_EQUAL(z, EDNS_HEADER_FLAG_DO); BOOST_CHECK_EQUAL(udpPayloadSize, g_PayloadSizeSelfGenAnswers); } @@ -1692,13 +1610,13 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart) { pw.getHeader()->rcode = RCode::NXDomain; pw.commit(); - int res = getEDNSOptionsStart(reinterpret_cast(query.data()), qname.wirelength(), query.size(), &optRDPosition, &remaining); + int res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining); BOOST_CHECK_EQUAL(res, ENOENT); /* truncated packet (should not matter) */ query.resize(query.size() - 1); - res = getEDNSOptionsStart(reinterpret_cast(query.data()), qname.wirelength(), query.size(), &optRDPosition, &remaining); + res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining); BOOST_CHECK_EQUAL(res, ENOENT); } @@ -1710,7 +1628,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart) { pw.addOpt(512, 0, 0); pw.commit(); - int res = getEDNSOptionsStart(reinterpret_cast(query.data()), qname.wirelength(), query.size(), &optRDPosition, &remaining); + int res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining); BOOST_CHECK_EQUAL(res, 0); BOOST_CHECK_EQUAL(optRDPosition, optRDExpectedOffset); @@ -1719,7 +1637,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart) { /* truncated packet */ query.resize(query.size() - 1); - res = getEDNSOptionsStart(reinterpret_cast(query.data()), qname.wirelength(), query.size(), &optRDPosition, &remaining); + res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining); BOOST_CHECK_EQUAL(res, ENOENT); } @@ -1730,7 +1648,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart) { pw.addOpt(512, 0, 0, opts); pw.commit(); - int res = getEDNSOptionsStart(reinterpret_cast(query.data()), qname.wirelength(), query.size(), &optRDPosition, &remaining); + int res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining); BOOST_CHECK_EQUAL(res, 0); BOOST_CHECK_EQUAL(optRDPosition, optRDExpectedOffset); @@ -1738,7 +1656,7 @@ BOOST_AUTO_TEST_CASE(test_getEDNSOptionsStart) { /* truncated options (should not matter for this test) */ query.resize(query.size() - 1); - res = getEDNSOptionsStart(reinterpret_cast(query.data()), qname.wirelength(), query.size(), &optRDPosition, &remaining); + res = getEDNSOptionsStart(query, qname.wirelength(), &optRDPosition, &remaining); BOOST_CHECK_EQUAL(res, 0); BOOST_CHECK_EQUAL(optRDPosition, optRDExpectedOffset); BOOST_CHECK_EQUAL(remaining, query.size() - optRDExpectedOffset); @@ -1925,28 +1843,24 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { vector queryWithEDNS; DNSPacketWriter pw(query, name, QType::A, QClass::IN, 0); pw.getHeader()->rd = 1; - const uint16_t len = query.size(); DNSPacketWriter pwEDNS(queryWithEDNS, name, QType::A, QClass::IN, 0); pwEDNS.getHeader()->rd = 1; pwEDNS.addOpt(1232, 0, 0); pwEDNS.commit(); - const uint16_t ednsLen = queryWithEDNS.size(); /* test NXD */ { /* no incoming EDNS */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); - BOOST_CHECK(static_cast(dq.len) > query.size()); - MOADNSParser mdp(true, packet, dq.len); + BOOST_CHECK(packet.size() > query.size()); + MOADNSParser mdp(true, reinterpret_cast(packet.data()), packet.size()); BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com."); BOOST_CHECK_EQUAL(mdp.d_header.rcode, RCode::NXDomain); @@ -1961,18 +1875,16 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { } { /* now with incoming EDNS */ - char packet[1500]; - memcpy(packet, queryWithEDNS.data(), queryWithEDNS.size()); + auto packet = queryWithEDNS; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, ednsLen, sizeof(dnsheader), false, &qtype, nullptr, &consumed); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), queryWithEDNS.size(), false, &queryTime); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, true, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); - BOOST_CHECK(static_cast(dq.len) > queryWithEDNS.size()); - MOADNSParser mdp(true, packet, dq.len); + BOOST_CHECK(packet.size() > queryWithEDNS.size()); + MOADNSParser mdp(true, reinterpret_cast(packet.data()), packet.size()); BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com."); BOOST_CHECK_EQUAL(mdp.d_header.rcode, RCode::NXDomain); @@ -1991,18 +1903,16 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { /* test No Data */ { /* no incoming EDNS */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, len, sizeof(dnsheader), false, &qtype, nullptr, &consumed); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), query.size(), false, &queryTime); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); - BOOST_CHECK(static_cast(dq.len) > query.size()); - MOADNSParser mdp(true, packet, dq.len); + BOOST_CHECK(packet.size() > query.size()); + MOADNSParser mdp(true, reinterpret_cast(packet.data()), packet.size()); BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com."); BOOST_CHECK_EQUAL(mdp.d_header.rcode, RCode::NoError); @@ -2017,18 +1927,16 @@ BOOST_AUTO_TEST_CASE(test_setNegativeAndAdditionalSOA) { } { /* now with incoming EDNS */ - char packet[1500]; - memcpy(packet, queryWithEDNS.data(), queryWithEDNS.size()); + auto packet = queryWithEDNS; unsigned int consumed = 0; uint16_t qtype; - DNSName qname(packet, ednsLen, sizeof(dnsheader), false, &qtype, nullptr, &consumed); - auto dh = reinterpret_cast(packet); - DNSQuestion dq(&qname, qtype, QClass::IN, qname.wirelength(), &remote, &remote, dh, sizeof(packet), queryWithEDNS.size(), false, &queryTime); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, nullptr, &consumed); + DNSQuestion dq(&qname, qtype, QClass::IN, &remote, &remote, packet, false, &queryTime); BOOST_CHECK(setNegativeAndAdditionalSOA(dq, false, DNSName("zone."), 42, DNSName("mname."), DNSName("rname."), 1, 2, 3, 4 , 5)); - BOOST_CHECK(static_cast(dq.len) > queryWithEDNS.size()); - MOADNSParser mdp(true, packet, dq.len); + BOOST_CHECK(packet.size() > queryWithEDNS.size()); + MOADNSParser mdp(true, reinterpret_cast(packet.data()), packet.size()); BOOST_CHECK_EQUAL(mdp.d_qname.toString(), "www.powerdns.com."); BOOST_CHECK_EQUAL(mdp.d_header.rcode, RCode::NoError); @@ -2059,14 +1967,13 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { pw.commit(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(packet), sizeof(packet), query.size(), false, nullptr); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2081,14 +1988,13 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { pw.commit(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(packet), sizeof(packet), query.size(), false, nullptr); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); BOOST_CHECK(!parseEDNSOptions(dq)); } @@ -2103,14 +2009,13 @@ BOOST_AUTO_TEST_CASE(getEDNSOptionsWithoutEDNS) { pw.commit(); /* large enough packet */ - char packet[1500]; - memcpy(packet, query.data(), query.size()); + auto packet = query; unsigned int consumed = 0; uint16_t qtype; uint16_t qclass; - DNSName qname(packet, query.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); - DNSQuestion dq(&qname, qtype, qclass, consumed, nullptr, &remote, reinterpret_cast(packet), sizeof(packet), query.size(), false, nullptr); + DNSName qname(reinterpret_cast(packet.data()), packet.size(), sizeof(dnsheader), false, &qtype, &qclass, &consumed); + DNSQuestion dq(&qname, qtype, qclass, nullptr, &remote, packet, false, nullptr); BOOST_CHECK(!parseEDNSOptions(dq)); } diff --git a/pdns/test-dnsdistpacketcache_cc.cc b/pdns/test-dnsdistpacketcache_cc.cc index 3ea30998c1..428e12cb98 100644 --- a/pdns/test-dnsdistpacketcache_cc.cc +++ b/pdns/test-dnsdistpacketcache_cc.cc @@ -44,24 +44,20 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { pwR.startRecord(a, QType::A, 7200, QClass::IN, DNSResourceRecord::ANSWER); pwR.xfr32BitInt(0x01020304); pwR.commit(); - uint16_t responseLen = response.size(); - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - auto dh = reinterpret_cast(query.data()); - DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, false, 0, boost::none); - found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); if (found == true) { - BOOST_CHECK_EQUAL(responseBufSize, responseLen); - int match = memcmp(responseBuf, response.data(), responseLen); + BOOST_CHECK_EQUAL(dq.getData().size(), response.size()); + int match = memcmp(dq.getData().data(), response.data(), dq.getData().size()); BOOST_CHECK_EQUAL(match, 0); BOOST_CHECK(!subnet); } @@ -80,12 +76,10 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { vector query; DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnet, dnssecOK); if (found == true) { auto removed = PC.expungeByName(a); BOOST_CHECK_EQUAL(removed, 1U); @@ -102,13 +96,10 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheSimple) { vector query; DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; - uint16_t len = query.size(); uint32_t key = 0; boost::optional subnet; - char response[4096]; - uint16_t responseSize = sizeof(response); - DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), len, query.size(), false, &queryTime); - if(PC.get(dq, a.wirelength(), pwQ.getHeader()->id, response, &responseSize, &key, subnet, dnssecOK)) { + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + if (PC.get(dq, pwQ.getHeader()->id, &key, subnet, dnssecOK)) { matches++; } } @@ -152,27 +143,23 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheServFailTTL) { pwR.getHeader()->rcode = RCode::ServFail; pwR.getHeader()->id = pwQ.getHeader()->id; pwR.commit(); - uint16_t responseLen = response.size(); - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - auto dh = reinterpret_cast(query.data()); - DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); // Insert with failure-TTL of 0 (-> should not enter cache). - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(0)); - found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, false, RCode::ServFail, boost::optional(0)); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); // Insert with failure-TTL non-zero (-> should enter cache). - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, RCode::ServFail, boost::optional(300)); - found = PC.get(dq, a.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, false, RCode::ServFail, boost::optional(300)); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); } @@ -210,26 +197,21 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNoDataTTL) { pwR.addOpt(4096, 0, 0); pwR.commit(); - uint16_t responseLen = response.size(); - - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - auto dh = reinterpret_cast(query.data()); - DNSQuestion dq(&name, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NoError, boost::none); - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, false, RCode::NoError, boost::none); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); sleep(2); /* it should have expired by now */ - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); } @@ -267,26 +249,21 @@ BOOST_AUTO_TEST_CASE(test_PacketCacheNXDomainTTL) { pwR.addOpt(4096, 0, 0); pwR.commit(); - uint16_t responseLen = response.size(); - - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - auto dh = reinterpret_cast(query.data()); - DNSQuestion dq(&name, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, name.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&name, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnet, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); - PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, name, QType::A, QClass::IN, reinterpret_cast(response.data()), responseLen, false, RCode::NXDomain, boost::none); - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, name, QType::A, QClass::IN, response, false, RCode::NXDomain, boost::none); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, true); BOOST_CHECK(!subnet); sleep(2); /* it should have expired by now */ - found = PC.get(dq, name.wirelength(), pwR.getHeader()->id, responseBuf, &responseBufSize, &key, subnet, dnssecOK, 0, true); + found = PC.get(dq, pwR.getHeader()->id, &key, subnet, dnssecOK, 0, true); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK(!subnet); } @@ -320,17 +297,13 @@ static void threadMangler(unsigned int offset) pwR.startRecord(a, QType::A, 3600, QClass::IN, DNSResourceRecord::ANSWER); pwR.xfr32BitInt(0x01020304); pwR.commit(); - uint16_t responseLen = response.size(); - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - auto dh = reinterpret_cast(query.data()); - DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, dh, query.size(), query.size(), false, &queryTime); - g_PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + g_PC.get(dq, 0, &key, subnet, dnssecOK); - g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dh)), dnssecOK, a, QType::A, QClass::IN, (const char*) response.data(), responseLen, false, 0, boost::none); + g_PC.insert(key, subnet, *(getFlagsFromDNSHeader(dq.getHeader())), dnssecOK, a, QType::A, QClass::IN, response, false, 0, boost::none); } } catch(PDNSException& e) { @@ -356,12 +329,10 @@ static void threadReader(unsigned int offset) DNSPacketWriter pwQ(query, a, QType::A, QClass::IN, 0); pwQ.getHeader()->rd = 1; - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); uint32_t key = 0; boost::optional subnet; - DNSQuestion dq(&a, QType::A, QClass::IN, 0, &remote, &remote, (struct dnsheader*) query.data(), query.size(), query.size(), false, &queryTime); - bool found = g_PC.get(dq, a.wirelength(), 0, responseBuf, &responseBufSize, &key, subnet, dnssecOK); + DNSQuestion dq(&a, QType::A, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = g_PC.get(dq, 0, &key, subnet, dnssecOK); if (!found) { g_missing++; } @@ -433,13 +404,11 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { pwQ.addOpt(512, 0, 0, ednsOptions); pwQ.commit(); - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); ComboAddress remote("192.0.2.1"); struct timespec queryTime; gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, dnssecOK); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnetOut, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_REQUIRE(subnetOut); BOOST_CHECK_EQUAL(subnetOut->toString(), opt.source.toString()); @@ -455,10 +424,10 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { pwR.addOpt(512, 0, 0, ednsOptions); pwR.commit(); - PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), dnssecOK, qname, qtype, QClass::IN, reinterpret_cast(response.data()), response.size(), false, RCode::NoError, boost::none); + PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), dnssecOK, qname, qtype, QClass::IN, response, false, RCode::NoError, boost::none); BOOST_CHECK_EQUAL(PC.getSize(), 1U); - found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, dnssecOK); + found = PC.get(dq, 0, &key, subnetOut, dnssecOK); BOOST_CHECK_EQUAL(found, true); BOOST_REQUIRE(subnetOut); BOOST_CHECK_EQUAL(subnetOut->toString(), opt.source.toString()); @@ -478,13 +447,11 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { pwQ.addOpt(512, 0, 0, ednsOptions); pwQ.commit(); - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); ComboAddress remote("192.0.2.1"); struct timespec queryTime; gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &secondKey, subnetOut, dnssecOK); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &secondKey, subnetOut, dnssecOK); BOOST_CHECK_EQUAL(found, false); BOOST_CHECK_EQUAL(secondKey, key); BOOST_REQUIRE(subnetOut); @@ -516,7 +483,7 @@ BOOST_AUTO_TEST_CASE(test_PCCollision) { ednsOptions.push_back(std::make_pair(EDNSOptionCode::ECS, makeEDNSSubnetOptsString(opt))); pwFQ.addOpt(512, 0, 0, ednsOptions); pwFQ.commit(); - secondKey = pc.getKey(qname.toDNSString(), qname.wirelength(), secondQuery.data(), secondQuery.size(), false); + secondKey = pc.getKey(qname.toDNSString(), qname.wirelength(), secondQuery, false); auto pair = colMap.insert(std::make_pair(secondKey, opt.source)); total++; if (!pair.second) { @@ -556,13 +523,11 @@ BOOST_AUTO_TEST_CASE(test_PCDNSSECCollision) { pwQ.addOpt(512, 0, EDNS_HEADER_FLAG_DO); pwQ.commit(); - char responseBuf[4096]; - uint16_t responseBufSize = sizeof(responseBuf); ComboAddress remote("192.0.2.1"); struct timespec queryTime; gettime(&queryTime); - DNSQuestion dq(&qname, QType::AAAA, QClass::IN, 0, &remote, &remote, pwQ.getHeader(), query.size(), query.size(), false, &queryTime); - bool found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, true); + DNSQuestion dq(&qname, QType::AAAA, QClass::IN, &remote, &remote, query, false, &queryTime); + bool found = PC.get(dq, 0, &key, subnetOut, true); BOOST_CHECK_EQUAL(found, false); vector response; @@ -576,13 +541,13 @@ BOOST_AUTO_TEST_CASE(test_PCDNSSECCollision) { pwR.addOpt(512, 0, EDNS_HEADER_FLAG_DO); pwR.commit(); - PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), /* DNSSEC OK is set */ true, qname, qtype, QClass::IN, reinterpret_cast(response.data()), response.size(), false, RCode::NoError, boost::none); + PC.insert(key, subnetOut, *(getFlagsFromDNSHeader(pwR.getHeader())), /* DNSSEC OK is set */ true, qname, qtype, QClass::IN, response, false, RCode::NoError, boost::none); BOOST_CHECK_EQUAL(PC.getSize(), 1U); - found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, false); + found = PC.get(dq, 0, &key, subnetOut, false); BOOST_CHECK_EQUAL(found, false); - found = PC.get(dq, qname.wirelength(), 0, responseBuf, &responseBufSize, &key, subnetOut, true); + found = PC.get(dq, 0, &key, subnetOut, true); BOOST_CHECK_EQUAL(found, true); } diff --git a/regression-tests.dnsdist/test_Trailing.py b/regression-tests.dnsdist/test_Trailing.py index fdf009a827..ee44a99b5f 100644 --- a/regression-tests.dnsdist/test_Trailing.py +++ b/regression-tests.dnsdist/test_Trailing.py @@ -24,7 +24,7 @@ class TestTrailingDataToBackend(DNSDistTest): addAction("added.trailing.tests.powerdns.com.", LuaAction(replaceTrailingData)) function fillBuffer(dq) - local available = dq.size - dq.len + local available = 4096 local tail = string.rep("A", available) local success = dq:setTrailingData(tail) if not success then