]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent allocations and copies by using the right types
authorRemi Gacogne <remi.gacogne@powerdns.com>
Tue, 13 Oct 2020 08:37:50 +0000 (10:37 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Mon, 11 Jan 2021 09:22:00 +0000 (10:22 +0100)
pdns/dnsdist-ecs.cc
pdns/dnsdist-ecs.hh
pdns/dnsdist-lua-actions.cc
pdns/dnsdist-lua-bindings-dnsquestion.cc
pdns/dnsdist.cc
pdns/dnsdistdist/dnsdist-lua-ffi-interface.h
pdns/dnsdistdist/dnsdist-lua-ffi.cc
pdns/dnsdistdist/dnsdist-rules.hh
pdns/dnsdistdist/doh.cc
pdns/test-dnsdist_cc.cc

index 7a3a28c85c0d75015b6be1a232c35552215c3aa3..a45ecd86be2aa94062189e211690f23106a8245e 100644 (file)
@@ -41,7 +41,7 @@ uint16_t g_ECSSourcePrefixV6 = 56;
 bool g_ECSOverride{false};
 bool g_addEDNSToSelfGeneratedResponses{true};
 
-int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent)
+int rewriteResponseWithoutEDNS(const std::vector<uint8_t>& initialPacket, vector<uint8_t>& newContent)
 {
   assert(initialPacket.size() >= sizeof(dnsheader));
   const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
@@ -52,7 +52,7 @@ int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>
   if (ntohs(dh->qdcount) == 0)
     return ENOENT;
 
-  PacketReader pr(initialPacket);
+  GenericPacketReader<std::vector<uint8_t>> pr(initialPacket);
 
   size_t idx = 0;
   DNSName rrname;
@@ -149,7 +149,7 @@ static bool addOrReplaceECSOption(std::vector<std::pair<uint16_t, std::string>>&
   return true;
 }
 
-static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
+static bool slowRewriteQueryWithExistingEDNS(const std::vector<uint8_t>& initialPacket, vector<uint8_t>& newContent, bool& ednsAdded, bool& ecsAdded, bool overrideExisting, const string& newECSOption)
 {
   assert(initialPacket.size() >= sizeof(dnsheader));
   const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
@@ -165,7 +165,7 @@ static bool slowRewriteQueryWithExistingEDNS(const std::string& initialPacket, v
     throw std::runtime_error("slowRewriteQueryWithExistingEDNS() should not be called for queries that have no EDNS");
   }
 
-  PacketReader pr(initialPacket);
+  GenericPacketReader<std::vector<uint8_t>> pr(initialPacket);
 
   size_t idx = 0;
   DNSName rrname;
@@ -317,7 +317,7 @@ static bool slowParseEDNSOptions(const std::vector<uint8_t>& packet, std::shared
   return true;
 }
 
-int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last)
+int locateEDNSOptRR(const std::vector<uint8_t>& packet, uint16_t * optStart, size_t * optLen, bool * last)
 {
   assert(optStart != NULL);
   assert(optLen != NULL);
@@ -327,7 +327,8 @@ int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * opt
   if (ntohs(dh->arcount) == 0)
     return ENOENT;
 
-  PacketReader pr(packet);
+  GenericPacketReader<std::vector<uint8_t>> pr(packet);
+
   size_t idx = 0;
   DNSName rrname;
   uint16_t qdcount = ntohs(dh->qdcount);
@@ -436,7 +437,7 @@ void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPref
   generateEDNSOption(EDNSOptionCode::ECS, payload, res);
 }
 
-void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
+bool generateOptRR(const std::string& optRData, std::vector<uint8_t>& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK)
 {
   const uint8_t name = 0;
   dnsrecordheader dh;
@@ -445,15 +446,22 @@ void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayload
   edns0.version = 0;
   edns0.extFlags = dnssecOK ? htons(EDNS_HEADER_FLAG_DO) : 0;
 
+  if ((maximumSize - res.size()) < (sizeof(name) + sizeof(dh) + optRData.length())) {
+    return false;
+  }
+
   dh.d_type = htons(QType::OPT);
   dh.d_class = htons(udpPayloadSize);
   static_assert(sizeof(EDNS0Record) == sizeof(dh.d_ttl), "sizeof(EDNS0Record) must match sizeof(dnsrecordheader.d_ttl)");
   memcpy(&dh.d_ttl, &edns0, sizeof edns0);
   dh.d_clen = htons(static_cast<uint16_t>(optRData.length()));
-  res.reserve(sizeof(name) + sizeof(dh) + optRData.length());
-  res.assign(reinterpret_cast<const char *>(&name), sizeof name);
-  res.append(reinterpret_cast<const char *>(&dh), sizeof(dh));
-  res.append(optRData.c_str(), optRData.length());
+
+  res.reserve(res.size() + sizeof(name) + sizeof(dh) + optRData.length());
+  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&name), reinterpret_cast<const uint8_t*>(&name) + sizeof(name));
+  res.insert(res.end(), reinterpret_cast<const uint8_t*>(&dh), reinterpret_cast<const uint8_t*>(&dh) + sizeof(dh));
+  res.insert(res.end(), reinterpret_cast<const uint8_t*>(optRData.data()), reinterpret_cast<const uint8_t*>(optRData.data()) + optRData.length());
+
+  return true;
 }
 
 static bool replaceEDNSClientSubnetOption(std::vector<uint8_t>& packet, size_t maximumSize, size_t const oldEcsOptionStartPosition, size_t const oldEcsOptionSize, size_t const optRDLenPosition, const string& newECSOption)
@@ -556,17 +564,10 @@ static bool addECSToExistingOPT(std::vector<uint8_t>& packet, size_t maximumSize
 
 static bool addEDNSWithECS(std::vector<uint8_t>& packet, size_t maximumSize, const string& newECSOption, bool& ednsAdded, bool& ecsAdded)
 {
-  /* we need to add a EDNS0 RR with one EDNS0 ECS option, fixing the AR count */
-  string EDNSRR;
-  generateOptRR(newECSOption, EDNSRR, g_EdnsUDPPayloadSize, 0, false);
-
-  if ((maximumSize - packet.size()) < EDNSRR.size()) {
+  if (!generateOptRR(newECSOption, packet, maximumSize, g_EdnsUDPPayloadSize, 0, false)) {
     return false;
   }
 
-#warning FIXME: we can avoid a copy here by generating in place
-  packet.insert(packet.end(), EDNSRR.begin(), EDNSRR.end());
-
   struct dnsheader* dh = reinterpret_cast<struct dnsheader*>(packet.data());
   uint16_t arcount = ntohs(dh->arcount);
   arcount++;
@@ -587,7 +588,7 @@ bool handleEDNSClientSubnet(std::vector<uint8_t>& packet, const size_t maximumSi
     vector<uint8_t> newContent;
     newContent.reserve(packet.size());
 
-    if (!slowRewriteQueryWithExistingEDNS(std::string(reinterpret_cast<const char*>(packet.data()), packet.size()), newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
+    if (!slowRewriteQueryWithExistingEDNS(packet, newContent, ednsAdded, ecsAdded, overrideExisting, newECSOption)) {
       ednsAdded = false;
       ecsAdded = false;
       return false;
@@ -708,7 +709,7 @@ int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optio
   return 0;
 }
 
-bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
+bool isEDNSOptionInOpt(const std::vector<uint8_t>& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart, uint16_t* optContentLen)
 {
   if (optLen < optRecordMinimumSize) {
     return false;
@@ -747,7 +748,7 @@ bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const s
   return false;
 }
 
-int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent)
+int rewriteResponseWithoutEDNSOption(const std::vector<uint8_t>& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent)
 {
   assert(initialPacket.size() >= sizeof(dnsheader));
   const struct dnsheader* dh = reinterpret_cast<const struct dnsheader*>(initialPacket.data());
@@ -758,7 +759,7 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
   if (ntohs(dh->qdcount) == 0)
     return ENOENT;
 
-  PacketReader pr(initialPacket);
+  GenericPacketReader<std::vector<uint8_t>> pr(initialPacket);
 
   size_t idx = 0;
   DNSName rrname;
@@ -844,12 +845,12 @@ int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uin
   return 0;
 }
 
-bool addEDNS(std::vector<uint8_t>& packet, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
+bool addEDNS(std::vector<uint8_t>& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode)
 {
-  std::string optRecord;
-  generateOptRR(std::string(), optRecord, payloadSize, ednsrcode, dnssecOK);
+  if (!generateOptRR(std::string(), packet, maximumSize, payloadSize, ednsrcode, dnssecOK)) {
+    return false;
+  }
 
-  packet.insert(packet.end(), optRecord.begin(), optRecord.end());
   auto dh = reinterpret_cast<dnsheader*>(packet.data());
   dh->arcount = htons(ntohs(dh->arcount) + 1);
 
@@ -937,7 +938,7 @@ bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone,
 
   if (hadEDNS) {
     /* now we need to add a new OPT record */
-    return addEDNS(packet, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
+    return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
   }
 
   return true;
@@ -976,7 +977,7 @@ bool addEDNSToQueryTurnedResponse(DNSQuestion& dq)
 
   if (g_addEDNSToSelfGeneratedResponses) {
     /* now we need to add a new OPT record */
-    return addEDNS(packet, dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
+    return addEDNS(packet, dq.getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, dq.ednsRCode);
   }
 
   /* otherwise we are just fine */
@@ -1048,9 +1049,7 @@ bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
   size_t optLen = 0;
   bool last = false;
   const auto& packet = dq.getData();
-#warning FIXME: save an alloc+copy
-  std::string packetStr(reinterpret_cast<const char*>(packet.data()), packet.size());
-  int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+  int res = locateEDNSOptRR(packet, &optStart, &optLen, &last);
   if (res != 0) {
     // no EDNS OPT RR
     return false;
@@ -1060,7 +1059,7 @@ bool getEDNS0Record(const DNSQuestion& dq, EDNS0Record& edns0)
     return false;
   }
 
-  if (optStart < packet.size() && packetStr.at(optStart) != 0) {
+  if (optStart < packet.size() && packet.at(optStart) != 0) {
     // OPT RR Name != '.'
     return false;
   }
index 385454a320b7ea07c559b4e04a23a7bbd6b76753..60492576958c2fc30d9fc8035d59934df55802ed 100644 (file)
@@ -27,15 +27,15 @@ static const size_t optRecordMinimumSize = 11;
 extern size_t g_EdnsUDPPayloadSize;
 extern uint16_t g_PayloadSizeSelfGenAnswers;
 
-int rewriteResponseWithoutEDNS(const std::string& initialPacket, vector<uint8_t>& newContent);
-int locateEDNSOptRR(const std::string& packet, uint16_t * optStart, size_t * optLen, bool * last);
-void generateOptRR(const std::string& optRData, string& res, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK);
+int rewriteResponseWithoutEDNS(const std::vector<uint8_t>& initialPacket, vector<uint8_t>& newContent);
+int locateEDNSOptRR(const std::vector<uint8_t> & packet, uint16_t * optStart, size_t * optLen, bool * last);
+bool generateOptRR(const std::string& optRData, std::vector<uint8_t>& res, size_t maximumSize, uint16_t udpPayloadSize, uint8_t ednsrcode, bool dnssecOK);
 void generateECSOption(const ComboAddress& source, string& res, uint16_t ECSPrefixLength);
 int removeEDNSOptionFromOPT(char* optStart, size_t* optLen, const uint16_t optionCodeToRemove);
-int rewriteResponseWithoutEDNSOption(const std::string& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
+int rewriteResponseWithoutEDNSOption(const std::vector<uint8_t>& initialPacket, const uint16_t optionCodeToSkip, vector<uint8_t>& newContent);
 int getEDNSOptionsStart(const std::vector<uint8_t>& packet, const size_t offset, uint16_t* optRDPosition, size_t * remaining);
-bool isEDNSOptionInOpt(const std::string& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart = nullptr, uint16_t* optContentLen = nullptr);
-bool addEDNS(std::vector<uint8_t>& packet, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode);
+bool isEDNSOptionInOpt(const std::vector<uint8_t>& packet, const size_t optStart, const size_t optLen, const uint16_t optionCodeToFind, size_t* optContentStart = nullptr, uint16_t* optContentLen = nullptr);
+bool addEDNS(std::vector<uint8_t>& packet, size_t maximumSize, bool dnssecOK, uint16_t payloadSize, uint8_t ednsrcode);
 bool addEDNSToQueryTurnedResponse(DNSQuestion& dq);
 bool setNegativeAndAdditionalSOA(DNSQuestion& dq, bool nxd, const DNSName& zone, uint32_t ttl, const DNSName& mname, const DNSName& rname, uint32_t serial, uint32_t refresh, uint32_t retry, uint32_t expire, uint32_t minimum);
 
index 3b962af8640732f2d9288c06becbe6fbe1f9f2ad..0a3d7abaa9bbfa5f8253e429a3a90bc7c95076fa 100644 (file)
@@ -635,7 +635,7 @@ DNSAction::Action SpoofAction::operator()(DNSQuestion* dq, std::string* ruleresu
   dq->getHeader()->ancount = htons(dq->getHeader()->ancount);
 
   if (hadEDNS && raw == false) {
-    addEDNS(dq->getMutableData(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
+    addEDNS(dq->getMutableData(), dq->getMaximumSize(), dnssecOK, g_PayloadSizeSelfGenAnswers, 0);
   }
 
   return Action::HeaderModify;
@@ -658,16 +658,10 @@ public:
     std::string optRData;
     generateEDNSOption(d_code, mac, optRData);
 
-    std::string res;
-    generateOptRR(optRData, res, g_EdnsUDPPayloadSize, 0, false);
-
-    if (!dq->hasRoomFor(res.length())) {
-      return Action::None;
-    }
-
-    dq->getHeader()->arcount = htons(1);
     auto& data = dq->getMutableData();
-    data.insert(data.end(), res.begin(), res.end());
+    if (generateOptRR(optRData, data, dq->getMaximumSize(), g_EdnsUDPPayloadSize, 0, false)) {
+      dq->getHeader()->arcount = htons(1);
+    }
 
     return Action::None;
   }
index 01988b7a23c973b841bbbfbd00b5f4ac856d77ed..553b1f521762891b9af1435e4b73249a5f28afd5 100644 (file)
@@ -38,8 +38,6 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerMember<dnsheader* (DNSQuestion::*)>("dh", [](const DNSQuestion& dq) -> dnsheader* { return const_cast<DNSQuestion&>(dq).getHeader(); }, [](DNSQuestion& dq, const dnsheader* dh) { *(dq.getHeader()) = *dh; });
   luaCtx.registerMember<uint16_t (DNSQuestion::*)>("len", [](const DNSQuestion& dq) -> uint16_t { return dq.getData().size(); }, [](DNSQuestion& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
   luaCtx.registerMember<uint8_t (DNSQuestion::*)>("opcode", [](const DNSQuestion& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSQuestion& dq, uint8_t newOpcode) { (void) newOpcode; });
-  #warning FIXME we need to provide Lua with a way to update the size
-  //luaCtx.registerMember<size_t (DNSQuestion::*)>("size", [](const DNSQuestion& dq) -> size_t { return dq.getData().size(); }, [](DNSQuestion& dq, size_t newSize) { (void) newSize; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("tcp", [](const DNSQuestion& dq) -> bool { return dq.tcp; }, [](DNSQuestion& dq, bool newTcp) { (void) newTcp; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("skipCache", [](const DNSQuestion& dq) -> bool { return dq.skipCache; }, [](DNSQuestion& dq, bool newSkipCache) { dq.skipCache = newSkipCache; });
   luaCtx.registerMember<bool (DNSQuestion::*)>("useECS", [](const DNSQuestion& dq) -> bool { return dq.useECS; }, [](DNSQuestion& dq, bool useECS) { dq.useECS = useECS; });
@@ -141,8 +139,6 @@ void setupLuaBindingsDNSQuestion(LuaContext& luaCtx)
   luaCtx.registerMember<dnsheader* (DNSResponse::*)>("dh", [](const DNSResponse& dr) -> dnsheader* { return const_cast<DNSResponse&>(dr).getHeader(); }, [](DNSResponse& dr, const dnsheader* dh) { *(dr.getHeader()) = *dh; });
   luaCtx.registerMember<uint16_t (DNSResponse::*)>("len", [](const DNSResponse& dq) -> uint16_t { return dq.getData().size(); }, [](DNSResponse& dq, uint16_t newlen) { dq.getMutableData().resize(newlen); });
   luaCtx.registerMember<uint8_t (DNSResponse::*)>("opcode", [](const DNSResponse& dq) -> uint8_t { return dq.getHeader()->opcode; }, [](DNSResponse& dq, uint8_t newOpcode) { (void) newOpcode; });
-  #warning FIXME we need to provide Lua with a way to update the size
-  //luaCtx.registerMember<size_t (DNSResponse::*)>("size", [](const DNSResponse& dq) -> size_t { return dq.size; }, [](DNSResponse& dq, size_t newSize) { (void) newSize; });
   luaCtx.registerMember<bool (DNSResponse::*)>("tcp", [](const DNSResponse& dq) -> bool { return dq.tcp; }, [](DNSResponse& dq, bool newTcp) { (void) newTcp; });
   luaCtx.registerMember<bool (DNSResponse::*)>("skipCache", [](const DNSResponse& dq) -> bool { return dq.skipCache; }, [](DNSResponse& dq, bool newSkipCache) { dq.skipCache = newSkipCache; });
   luaCtx.registerFunction<void(DNSResponse::*)(std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc)>("editTTLs", [](DNSResponse& dr, std::function<uint32_t(uint8_t section, uint16_t qclass, uint16_t qtype, uint32_t ttl)> editFunc) {
index 4f039355da76657f07b096bb458a880369b93709..a64483c8d2f1daaf745b511721f8c8e657dc4da9 100644 (file)
@@ -147,7 +147,7 @@ std::set<std::string> g_capabilitiesToRetain;
 static size_t const s_initialUDPPacketBufferSize = s_maxPacketCacheEntrySize + DNSCRYPT_MAX_RESPONSE_PADDING_AND_MAC_SIZE;
 static_assert(s_initialUDPPacketBufferSize <= UINT16_MAX, "Packet size should fit in a uint16_t");
 
-static void truncateTC(std::vector<uint8_t>& packet, unsigned int qnameWireLength)
+static void truncateTC(std::vector<uint8_t>& packet, size_t maximumSize, unsigned int qnameWireLength)
 {
   try
   {
@@ -164,7 +164,7 @@ static void truncateTC(std::vector<uint8_t>& packet, unsigned int qnameWireLengt
     dh->ancount = dh->arcount = dh->nscount = 0;
 
     if (hadEDNS) {
-      addEDNS(packet, z & EDNS_HEADER_FLAG_DO, payloadSize, 0);
+      addEDNS(packet, maximumSize, z & EDNS_HEADER_FLAG_DO, payloadSize, 0);
     }
   }
   catch(...)
@@ -323,19 +323,17 @@ static bool fixUpResponse(std::vector<uint8_t>& response, const DNSName& qname,
     size_t optLen = 0;
     bool last = false;
 
-#warning FIXME: save an alloc+copy
-    const std::string responseStr(reinterpret_cast<const char*>(response.data()), response.size());
-    int res = locateEDNSOptRR(responseStr, &optStart, &optLen, &last);
+    int res = locateEDNSOptRR(response, &optStart, &optLen, &last);
 
     if (res == 0) {
       if (zeroScope) { // this finds if an EDNS Client Subnet scope was set, and if it is 0
         size_t optContentStart = 0;
         uint16_t optContentLen = 0;
         /* we need at least 4 bytes after the option length (family: 2, source prefix-length: 1, scope prefix-length: 1) */
-        if (isEDNSOptionInOpt(responseStr, optStart, optLen, EDNSOptionCode::ECS, &optContentStart, &optContentLen) && optContentLen >= 4) {
+        if (isEDNSOptionInOpt(response, optStart, optLen, EDNSOptionCode::ECS, &optContentStart, &optContentLen) && optContentLen >= 4) {
           /* see if the EDNS Client Subnet SCOPE PREFIX-LENGTH byte in position 3 is set to 0, which is the only thing
              we care about. */
-          *zeroScope = responseStr.at(optContentStart + 3) == 0;
+          *zeroScope = response.at(optContentStart + 3) == 0;
         }
       }
 
@@ -353,7 +351,7 @@ static bool fixUpResponse(std::vector<uint8_t>& response, const DNSName& qname,
         else {
           /* Removing an intermediary RR could lead to compression error */
           std::vector<uint8_t> rewrittenResponse;
-          if (rewriteResponseWithoutEDNS(responseStr, rewrittenResponse) == 0) {
+          if (rewriteResponseWithoutEDNS(response, rewrittenResponse) == 0) {
             response = std::move(rewrittenResponse);
           }
           else {
@@ -374,7 +372,7 @@ static bool fixUpResponse(std::vector<uint8_t>& response, const DNSName& qname,
         else {
           std::vector<uint8_t> rewrittenResponse;
           /* Removing an intermediary RR could lead to compression error */
-          if (rewriteResponseWithoutEDNSOption(responseStr, EDNSOptionCode::ECS, rewrittenResponse) == 0) {
+          if (rewriteResponseWithoutEDNSOption(response, EDNSOptionCode::ECS, rewrittenResponse) == 0) {
             response = std::move(rewrittenResponse);
           }
           else {
@@ -389,11 +387,10 @@ static bool fixUpResponse(std::vector<uint8_t>& response, const DNSName& qname,
 }
 
 #ifdef HAVE_DNSCRYPT
-static bool encryptResponse(std::vector<uint8_t>& response, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery)
+static bool encryptResponse(std::vector<uint8_t>& response, size_t maximumSize, bool tcp, std::shared_ptr<DNSCryptQuery> dnsCryptQuery)
 {
   if (dnsCryptQuery) {
-    #warning FIXME should not be harcoded
-    int res = dnsCryptQuery->encryptResponse(response, tcp ? std::numeric_limits<uint16_t>::max() : 4096, tcp);
+    int res = dnsCryptQuery->encryptResponse(response, maximumSize, tcp);
     if (res != 0) {
       /* dropping response */
       vinfolog("Error encrypting the response, dropping.");
@@ -468,7 +465,7 @@ bool processResponse(std::vector<uint8_t>& response, LocalStateHolder<vector<DNS
 
 #ifdef HAVE_DNSCRYPT
   if (!muted) {
-    if (!encryptResponse(response, dr.tcp, dr.dnsCryptQuery)) {
+    if (!encryptResponse(response, dr.getMaximumSize(), dr.tcp, dr.dnsCryptQuery)) {
       return false;
     }
   }
@@ -610,13 +607,12 @@ void responderThread(std::shared_ptr<DownstreamState> dss)
           continue;
         }
 
-        if (dh->tc && g_truncateTC) {
-          truncateTC(response, qnameWireLength);
-        }
-
         dh->id = ids->origID;
 
         DNSResponse dr = makeDNSResponseFromIDState(*ids, response, false);
+        if (dh->tc && g_truncateTC) {
+          truncateTC(response, dr.getMaximumSize(), qnameWireLength);
+        }
         memcpy(&cleartextDH, dr.getHeader(), sizeof(cleartextDH));
 
         if (!processResponse(response, localRespRulactions, dr, ids->cs && ids->cs->muted)) {
@@ -1116,7 +1112,7 @@ static bool prepareOutgoingResponse(LocalHolders& holders, ClientState& cs, DNSQ
 
 #ifdef HAVE_DNSCRYPT
   if (!cs.muted) {
-    if (!encryptResponse(dq.getMutableData(), dq.tcp, dq.dnsCryptQuery)) {
+    if (!encryptResponse(dq.getMutableData(), dq.getMaximumSize(), dq.tcp, dq.dnsCryptQuery)) {
       return false;
     }
   }
index cdf0c77bdd59bc0165ac8c98a3e53896ab5e6485..969cd2083e4ed3660aa4b8c2ab460e5584aaa6a6 100644 (file)
@@ -57,6 +57,7 @@ int dnsdist_ffi_dnsquestion_get_rcode(const dnsdist_ffi_dnsquestion_t* dq) __att
 void* dnsdist_ffi_dnsquestion_get_header(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
+bool dnsdist_ffi_dnsquestion_set_size(dnsdist_ffi_dnsquestion_t* dq, size_t newSize) __attribute__ ((visibility ("default")));
 uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 bool dnsdist_ffi_dnsquestion_get_tcp(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
 bool dnsdist_ffi_dnsquestion_get_skip_cache(const dnsdist_ffi_dnsquestion_t* dq) __attribute__ ((visibility ("default")));
index 1ed1866e377ec2444bd4a73fd62245d3f240c698..4a96a208d8d012ffb93777763670c4210e790d70 100644 (file)
@@ -99,12 +99,22 @@ uint16_t dnsdist_ffi_dnsquestion_get_len(const dnsdist_ffi_dnsquestion_t* dq)
   return dq->dq->getData().size();
 }
 
-#warning FIXME : we need to provide a way to resize
 size_t dnsdist_ffi_dnsquestion_get_size(const dnsdist_ffi_dnsquestion_t* dq)
 {
   return dq->dq->getData().size();
 }
 
+bool dnsdist_ffi_dnsquestion_set_size(dnsdist_ffi_dnsquestion_t* dq, size_t newSize)
+{
+  try {
+    dq->dq->getMutableData().resize(newSize);
+    return true;
+  }
+  catch (const std::exception& e) {
+    return false;
+  }
+}
+
 uint8_t dnsdist_ffi_dnsquestion_get_opcode(const dnsdist_ffi_dnsquestion_t* dq)
 {
   return dq->dq->getHeader()->opcode;
index bbd70182510fae7cf71ae40945dcab0c9b9b095d..abbe81329b7e064f5782316eeb53cb7fbf59559f 100644 (file)
@@ -975,8 +975,7 @@ public:
     uint16_t optStart;
     size_t optLen = 0;
     bool last = false;
-    std::string packetStr(dq->getData().begin(), dq->getData().end());
-    int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+    int res = locateEDNSOptRR(dq->getData(), &optStart, &optLen, &last);
     if (res != 0) {
       // no EDNS OPT RR
       return false;
@@ -986,12 +985,12 @@ public:
       return false;
     }
 
-    if (optStart < dq->getData().size() && packetStr.at(optStart) != 0) {
+    if (optStart < dq->getData().size() && dq->getData().at(optStart) != 0) {
       // OPT RR Name != '.'
       return false;
     }
 
-    return isEDNSOptionInOpt(packetStr, optStart, optLen, d_optcode);
+    return isEDNSOptionInOpt(dq->getData(), optStart, optLen, d_optcode);
   }
   string toString() const override
   {
index de54c818228161e66981ca53d410f3afb8ec6e8f..de8b322733e2ab30588040edcc0806ba660a97d0 100644 (file)
@@ -206,7 +206,7 @@ struct DOHServerConfig
   DOHServerConfig& operator=(const DOHServerConfig&) = delete;
 
   LocalHolders holders;
-  std::unordered_set<std::string> paths;
+  std::set<std::string> paths;
   h2o_globalconf_t h2o_config;
   h2o_context_t h2o_ctx;
   std::shared_ptr<DOHAcceptContext> accept_ctx{nullptr};
@@ -831,15 +831,18 @@ try
       ++dsc->cs->tlsUnknownqueries;
   }
 
-  #warning turn these into string_view?
-  string path(req->path.base, req->path.len);
-  string pathOnly(req->path_normalized.base, req->path_normalized.len);
-
+  // would be nice to be able to use a pdns_string_view there, but we would need heterogeneous lookups
+  // (having string in the set and compare them to string_view, for example. Note that comparing
+  // two boost::string_view uses the pointer, not the content).
+  const std::string pathOnly(req->path_normalized.base, req->path_normalized.len);
   if (dsc->paths.count(pathOnly) == 0) {
     h2o_send_error_404(req, "Not Found", "there is no endpoint configured for this path", 0);
     return 0;
   }
 
+  // would be nice to be able to use a pdns_string_view there,
+  // but regex (called by matches() internally) requires a null-terminated string
+  string path(req->path.base, req->path.len);
   for (const auto& entry : dsc->df->d_responsesMap) {
     if (entry->matches(path)) {
       const auto& customHeaders = entry->getHeaders();
@@ -882,7 +885,7 @@ try
         break;
       }
 
-      std::string decoded;
+      std::vector<uint8_t> decoded;
 
       /* rough estimate so we hopefully don't need a new allocation later */
       /* We reserve at least 512 additional bytes to be able to add EDNS, but we also want
@@ -901,9 +904,7 @@ try
         else
           ++dsc->df->d_http1Stats.d_nbQueries;
 
-#warning FIXME: performance
-        auto vect = std::vector<uint8_t>(decoded.begin(), decoded.end());
-        doh_dispatch_query(dsc, self, req, std::move(vect), local, remote, std::move(path));
+        doh_dispatch_query(dsc, self, req, std::move(decoded), local, remote, std::move(path));
       }
     }
     else
@@ -1084,13 +1085,11 @@ static void dnsdistclient(int qsock)
       auto dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data()));
 
       if (!dh->arcount) {
-        std::string res;
-        generateOptRR(std::string(), res, 4096, 0, false);
-
-        du->query.insert(du->query.end(), res.begin(), res.end());
-        dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data())); // may have reallocated
-        dh->arcount = htons(1);
-        du->ednsAdded = true;
+        if (generateOptRR(std::string(), du->query, 4096, 4096, 0, false)) {
+          dh = const_cast<struct dnsheader*>(reinterpret_cast<const struct dnsheader*>(du->query.data())); // may have reallocated
+          dh->arcount = htons(1);
+          du->ednsAdded = true;
+        }
       }
       else {
         // we leave existing EDNS in place
index eba97ce7cbd25dc807e0d4166eeb11a104cd2bac..df7ca35099f7a4e11f325ecaf9f3124df81842df 100644 (file)
@@ -901,7 +901,7 @@ BOOST_AUTO_TEST_CASE(removeEDNSWhenFirst) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNS(std::string((const char *) response.data(), response.size()), newResponse);
+  int res = rewriteResponseWithoutEDNS(response, newResponse);
   BOOST_CHECK_EQUAL(res, 0);
 
   unsigned int consumed = 0;
@@ -933,7 +933,7 @@ BOOST_AUTO_TEST_CASE(removeEDNSWhenIntermediary) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNS(std::string((const char *) response.data(), response.size()), newResponse);
+  int res = rewriteResponseWithoutEDNS(response, newResponse);
   BOOST_CHECK_EQUAL(res, 0);
 
   unsigned int consumed = 0;
@@ -963,7 +963,7 @@ BOOST_AUTO_TEST_CASE(removeEDNSWhenLast) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNS(std::string((const char *) response.data(), response.size()), newResponse);
+  int res = rewriteResponseWithoutEDNS(response, newResponse);
 
   BOOST_CHECK_EQUAL(res, 0);
 
@@ -1004,7 +1004,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenOnlyOption) {
   size_t optLen = 0;
   bool last = false;
 
-  int res = locateEDNSOptRR(std::string((char *) response.data(), response.size()), &optStart, &optLen, &last);
+  int res = locateEDNSOptRR(response, &optStart, &optLen, &last);
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK_EQUAL(last, true);
 
@@ -1056,7 +1056,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenFirstOption) {
   size_t optLen = 0;
   bool last = false;
 
-  int res = locateEDNSOptRR(std::string((char *) response.data(), response.size()), &optStart, &optLen, &last);
+  int res = locateEDNSOptRR(response, &optStart, &optLen, &last);
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK_EQUAL(last, true);
 
@@ -1112,7 +1112,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenIntermediaryOption) {
   size_t optLen = 0;
   bool last = false;
 
-  int res = locateEDNSOptRR(std::string((char *) response.data(), response.size()), &optStart, &optLen, &last);
+  int res = locateEDNSOptRR(response, &optStart, &optLen, &last);
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK_EQUAL(last, true);
 
@@ -1164,7 +1164,7 @@ BOOST_AUTO_TEST_CASE(removeECSWhenLastOption) {
   size_t optLen = 0;
   bool last = false;
 
-  int res = locateEDNSOptRR(std::string((char *) response.data(), response.size()), &optStart, &optLen, &last);
+  int res = locateEDNSOptRR(response, &optStart, &optLen, &last);
   BOOST_CHECK_EQUAL(res, 0);
   BOOST_CHECK_EQUAL(last, true);
 
@@ -1208,7 +1208,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenOnlyOption) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNSOption(std::string((const char *) response.data(), response.size()), EDNSOptionCode::ECS, newResponse);
+  int res = rewriteResponseWithoutEDNSOption(response, EDNSOptionCode::ECS, newResponse);
   BOOST_CHECK_EQUAL(res, 0);
 
   BOOST_CHECK_EQUAL(newResponse.size(), response.size() - (origECSOptionStr.size() + 4));
@@ -1250,7 +1250,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenFirstOption) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNSOption(std::string((const char *) response.data(), response.size()), EDNSOptionCode::ECS, newResponse);
+  int res = rewriteResponseWithoutEDNSOption(response, EDNSOptionCode::ECS, newResponse);
   BOOST_CHECK_EQUAL(res, 0);
 
   BOOST_CHECK_EQUAL(newResponse.size(), response.size() - (origECSOptionStr.size() + 4));
@@ -1294,7 +1294,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenIntermediaryOption) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNSOption(std::string((const char *) response.data(), response.size()), EDNSOptionCode::ECS, newResponse);
+  int res = rewriteResponseWithoutEDNSOption(response, EDNSOptionCode::ECS, newResponse);
   BOOST_CHECK_EQUAL(res, 0);
 
   BOOST_CHECK_EQUAL(newResponse.size(), response.size() - (origECSOptionStr.size() + 4));
@@ -1336,7 +1336,7 @@ BOOST_AUTO_TEST_CASE(rewritingWithoutECSWhenLastOption) {
   pw.commit();
 
   vector<uint8_t> newResponse;
-  int res = rewriteResponseWithoutEDNSOption(std::string((const char *) response.data(), response.size()), EDNSOptionCode::ECS, newResponse);
+  int res = rewriteResponseWithoutEDNSOption(response, EDNSOptionCode::ECS, newResponse);
   BOOST_CHECK_EQUAL(res, 0);
 
   BOOST_CHECK_EQUAL(newResponse.size(), response.size() - (origECSOptionStr.size() + 4));
@@ -1670,8 +1670,7 @@ BOOST_AUTO_TEST_CASE(test_isEDNSOptionInOpt) {
     uint16_t optStart;
     size_t optLen;
     bool last = false;
-    std::string packetStr(reinterpret_cast<const char*>(query.data()), query.size());
-    int res = locateEDNSOptRR(packetStr, &optStart, &optLen, &last);
+    int res = locateEDNSOptRR(query, &optStart, &optLen, &last);
     if (res != 0) {
       // no EDNS OPT RR
       return false;
@@ -1681,12 +1680,12 @@ BOOST_AUTO_TEST_CASE(test_isEDNSOptionInOpt) {
       return false;
     }
 
-    if (optStart < query.size() && packetStr.at(optStart) != 0) {
+    if (optStart < query.size() && query.at(optStart) != 0) {
       // OPT RR Name != '.'
       return false;
     }
 
-    return isEDNSOptionInOpt(packetStr, optStart, optLen, code, optContentStart, optContentLen);
+    return isEDNSOptionInOpt(query, optStart, optLen, code, optContentStart, optContentLen);
   };
 
   const DNSName qname("www.powerdns.com.");