]> git.ipfire.org Git - thirdparty/pdns.git/commitdiff
dnsdist: Prevent copies of DNSQuestion and DNSResponse objects
authorRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 10 Apr 2019 10:22:32 +0000 (12:22 +0200)
committerRemi Gacogne <remi.gacogne@powerdns.com>
Wed, 10 Apr 2019 12:26:46 +0000 (14:26 +0200)
pdns/dnsdist-lua-actions.cc
pdns/dnsdist.cc
pdns/dnsdist.hh
pdns/test-dnsdist_cc.cc

index 4f255b6e8c0766ec2f8338cfcac9b657af856d7b..c9cadfc5ee01ea23fc2c3e8b77e8ce683c3a2f90 100644 (file)
@@ -745,7 +745,7 @@ private:
 class DnstapLogAction : public DNSAction, public boost::noncopyable
 {
 public:
-  DnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSQuestion&, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
+  DnstapLogAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
   {
   }
   DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
@@ -755,7 +755,7 @@ public:
     {
       if (d_alterFunc) {
         std::lock_guard<std::mutex> lock(g_luamutex);
-        (*d_alterFunc)(*dq, &message);
+        (*d_alterFunc)(dq, &message);
       }
     }
     std::string data;
@@ -771,13 +771,13 @@ public:
 private:
   std::string d_identity;
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSQuestion&, DnstapMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > d_alterFunc;
 };
 
 class RemoteLogAction : public DNSAction, public boost::noncopyable
 {
 public:
-  RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey)
+  RemoteLogAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey)
   {
   }
   DNSAction::Action operator()(DNSQuestion* dq, string* ruleresult) const override
@@ -801,7 +801,7 @@ public:
 
     if (d_alterFunc) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      (*d_alterFunc)(*dq, &message);
+      (*d_alterFunc)(dq, &message);
     }
 
     std::string data;
@@ -816,7 +816,7 @@ public:
   }
 private:
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > d_alterFunc;
   std::string d_serverID;
   std::string d_ipEncryptKey;
 };
@@ -871,7 +871,7 @@ private:
 class DnstapLogResponseAction : public DNSResponseAction, public boost::noncopyable
 {
 public:
-  DnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSResponse&, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
+  DnstapLogResponseAction(const std::string& identity, std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > alterFunc): d_identity(identity), d_logger(logger), d_alterFunc(alterFunc)
   {
   }
   DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
@@ -883,7 +883,7 @@ public:
     {
       if (d_alterFunc) {
         std::lock_guard<std::mutex> lock(g_luamutex);
-        (*d_alterFunc)(*dr, &message);
+        (*d_alterFunc)(dr, &message);
       }
     }
     std::string data;
@@ -899,13 +899,13 @@ public:
 private:
   std::string d_identity;
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSResponse&, DnstapMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > d_alterFunc;
 };
 
 class RemoteLogResponseAction : public DNSResponseAction, public boost::noncopyable
 {
 public:
-  RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME)
+  RemoteLogResponseAction(std::shared_ptr<RemoteLoggerInterface>& logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > alterFunc, const std::string& serverID, const std::string& ipEncryptKey, bool includeCNAME): d_logger(logger), d_alterFunc(alterFunc), d_serverID(serverID), d_ipEncryptKey(ipEncryptKey), d_includeCNAME(includeCNAME)
   {
   }
   DNSResponseAction::Action operator()(DNSResponse* dr, string* ruleresult) const override
@@ -929,7 +929,7 @@ public:
 
     if (d_alterFunc) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      (*d_alterFunc)(*dr, &message);
+      (*d_alterFunc)(dr, &message);
     }
 
     std::string data;
@@ -944,7 +944,7 @@ public:
   }
 private:
   std::shared_ptr<RemoteLoggerInterface> d_logger;
-  boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > d_alterFunc;
+  boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > d_alterFunc;
   std::string d_serverID;
   std::string d_ipEncryptKey;
   bool d_includeCNAME;
@@ -1226,7 +1226,7 @@ void setupLuaActions()
       return std::shared_ptr<DNSResponseAction>(new LuaResponseAction(func));
     });
 
-  g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSQuestion&, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
+  g_lua.writeFunction("RemoteLogAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       // avoids potentially-evaluated-expression warning with clang.
       RemoteLoggerInterface& rl = *logger.get();
       if (typeid(rl) != typeid(RemoteLogger)) {
@@ -1252,7 +1252,7 @@ void setupLuaActions()
 #endif
     });
 
-  g_lua.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSResponse&, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<bool> includeCNAME, boost::optional<std::unordered_map<std::string, std::string>> vars) {
+  g_lua.writeFunction("RemoteLogResponseAction", [](std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DNSDistProtoBufMessage*)> > alterFunc, boost::optional<bool> includeCNAME, boost::optional<std::unordered_map<std::string, std::string>> vars) {
       // avoids potentially-evaluated-expression warning with clang.
       RemoteLoggerInterface& rl = *logger.get();
       if (typeid(rl) != typeid(RemoteLogger)) {
@@ -1278,7 +1278,7 @@ void setupLuaActions()
 #endif
     });
 
-  g_lua.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSQuestion&, DnstapMessage*)> > alterFunc) {
+  g_lua.writeFunction("DnstapLogAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSQuestion*, DnstapMessage*)> > alterFunc) {
 #ifdef HAVE_PROTOBUF
       return std::shared_ptr<DNSAction>(new DnstapLogAction(identity, logger, alterFunc));
 #else
@@ -1286,7 +1286,7 @@ void setupLuaActions()
 #endif
     });
 
-  g_lua.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(const DNSResponse&, DnstapMessage*)> > alterFunc) {
+  g_lua.writeFunction("DnstapLogResponseAction", [](const std::string& identity, std::shared_ptr<RemoteLoggerInterface> logger, boost::optional<std::function<void(DNSResponse*, DnstapMessage*)> > alterFunc) {
 #ifdef HAVE_PROTOBUF
       return std::shared_ptr<DNSResponseAction>(new DnstapLogResponseAction(identity, logger, alterFunc));
 #else
index 461c5f389649257d2ca3a08b69294b1b3bfcfd58..0e3d00c48e7f9ec822748686f7fafb61becf3d8d 100644 (file)
@@ -1006,7 +1006,7 @@ static bool applyRulesToQuery(LocalHolders& holders, DNSQuestion& dq, string& po
     bool countQuery{true};
     if(g_qcount.filter) {
       std::lock_guard<std::mutex> lock(g_luamutex);
-      std::tie (countQuery, qname) = g_qcount.filter(dq);
+      std::tie (countQuery, qname) = g_qcount.filter(&dq);
     }
 
     if(countQuery) {
index 663e18061bdb115b3deaef847b91d12ef1fa5d5e..d0a3a17845d908a3c6ee30392504bc8633ee2055 100644 (file)
@@ -65,6 +65,9 @@ struct DNSQuestion
     const uint16_t* flags = getFlagsFromDNSHeader(dh);
     origFlags = *flags;
   }
+  DNSQuestion(const DNSQuestion&) = delete;
+  DNSQuestion& operator=(const DNSQuestion&) = delete;
+  DNSQuestion(DNSQuestion&&) = default;
 
 #ifdef HAVE_PROTOBUF
   boost::optional<boost::uuids::uuid> uniqueId;
@@ -108,6 +111,9 @@ struct DNSResponse : DNSQuestion
 {
   DNSResponse(const DNSName* name, uint16_t type, uint16_t class_, unsigned int consumed, const ComboAddress* lc, const ComboAddress* rem, struct dnsheader* header, size_t bufferSize, uint16_t responseLen, bool isTcp, const struct timespec* queryTime_):
     DNSQuestion(name, type, class_, consumed, lc, rem, header, bufferSize, responseLen, isTcp, queryTime_) { }
+  DNSResponse(const DNSResponse&) = delete;
+  DNSResponse& operator=(const DNSResponse&) = delete;
+  DNSResponse(DNSResponse&&) = default;
 };
 
 /* so what could you do:
@@ -565,7 +571,7 @@ struct IDState
 };
 
 typedef std::unordered_map<string, unsigned int> QueryCountRecords;
-typedef std::function<std::tuple<bool, string>(DNSQuestion dq)> QueryCountFilter;
+typedef std::function<std::tuple<bool, string>(const DNSQuestion* dq)> QueryCountFilter;
 struct QueryCount {
   QueryCount()
   {
index f34968e7416f8a3f65739b823b08262c3bf1b79a..100dd47109fbf539896e8d769d4cd45b1a4610b2 100644 (file)
@@ -1105,8 +1105,7 @@ static DNSQuestion getDNSQuestion(const DNSName& qname, const uint16_t qtype, co
 {
   dnsheader* dh = reinterpret_cast<dnsheader*>(query.data());
 
-  DNSQuestion dq(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime);
-  return dq;
+  return DNSQuestion(&qname, qtype, qclass, qname.wirelength(), &lc, &rem, dh, query.size(), len, false, &realTime);
 }
 
 static DNSQuestion turnIntoResponse(const DNSName& qname, const uint16_t qtype, const uint16_t qclass, const ComboAddress& lc, const ComboAddress& rem, const struct timespec& queryRealTime, vector<uint8_t>&  query, bool resizeBuffer=true)