]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2639 in SNORT/snort3 from ~STECHEW/snort3:control_request_fix_sha...
authorSteve Chew (stechew) <stechew@cisco.com>
Sun, 6 Dec 2020 02:25:31 +0000 (02:25 +0000)
committerSteve Chew (stechew) <stechew@cisco.com>
Sun, 6 Dec 2020 02:25:31 +0000 (02:25 +0000)
Squashed commit of the following:

commit ac1f3fa3866ba47d09512acc3fb3e969b27f5603
Author: Steve Chew <stechew@cisco.com>
Date:   Fri Nov 20 11:48:19 2020 -0500

    main: convert Request to shared_ptr to avoid memory problems.

13 files changed:
src/main.cc
src/main.h
src/main/analyzer_command.cc
src/main/analyzer_command.h
src/main/control.cc
src/main/control.h
src/main/control_mgmt.cc
src/main/control_mgmt.h
src/main/request.cc
src/main/request.h
src/network_inspectors/appid/appid_module.cc
src/network_inspectors/rna/rna_module.cc
src/network_inspectors/rna/test/rna_module_mock.h

index d70ea9f6739698cbaf36c5d1fb5ee868e0487741..12506f74539dc20225528adbbf089f874307d37c 100644 (file)
@@ -34,7 +34,6 @@
 #include "lua/lua.h"
 #include "main/analyzer.h"
 #include "main/analyzer_command.h"
-#include "main/request.h"
 #include "main/shell.h"
 #include "main/snort.h"
 #include "main/snort_config.h"
@@ -112,15 +111,14 @@ static int main_read()
     return pig_poke->get(-1);
 }
 
-static Request request;
-static Request* current_request = &request;
+static SharedRequest current_request = std::make_shared<Request>();
 #ifdef SHELL
 static int current_fd = -1;
 #endif
 
-Request& get_current_request()
+SharedRequest get_current_request()
 {
-    return *current_request;
+    return current_request;
 }
 
 //-------------------------------------------------------------------------
@@ -369,10 +367,15 @@ int main_reload_config(lua_State* L)
         else
             current_request->respond("== reload failed - bad config\n");
 
-        HostAttributesManager::load_failure_cleanup();
         return 0;
     }
 
+    if ( !sc->attribute_hosts_file.empty() )
+    {
+        if ( !HostAttributesManager::load_hosts_file(sc, sc->attribute_hosts_file.c_str()) )
+            current_request->respond("== reload failed - host attributes file failed to load\n");
+    }
+
     int32_t num_hosts = HostAttributesManager::get_num_host_entries();
     if ( num_hosts >= 0 )
         LogMessage( "host attribute table: %d hosts loaded\n", num_hosts);
index d23e158cdebeb9ab640ef77f879d41dbb3c210a3..22d2710d12fc9266348b5b76a6195dbe902ff7cd 100644 (file)
@@ -26,7 +26,7 @@
 struct lua_State;
 
 const char* get_prompt();
-Request& get_current_request();
+SharedRequest get_current_request();
 
 // commands provided by the snort module
 int main_delete_inspector(lua_State* = nullptr);
index b381a9bd4ba008337b3ccaf4c9f697eb00fd7c8c..ea2a509807b9b5d36088caf570c4ad4d88a6b13d 100644 (file)
@@ -32,7 +32,6 @@
 #include "utils/stats.h"
 
 #include "analyzer.h"
-#include "request.h"
 #include "snort.h"
 #include "snort_config.h"
 #include "swapper.h"
@@ -95,7 +94,7 @@ ACGetStats::~ACGetStats()
     LogMessage("==================================================\n"); // Marking End of stats
 }
 
-ACSwap::ACSwap(Swapper* ps, Request* req, bool from_shell) : ps(ps), request(req), from_shell(from_shell)
+ACSwap::ACSwap(Swapper* ps, SharedRequest req, bool from_shell) : ps(ps), request(req), from_shell(from_shell)
 {
     assert(Swapper::get_reload_in_progress() == false);
     Swapper::set_reload_in_progress(true);
@@ -173,7 +172,7 @@ ACSwap::~ACSwap()
     request->respond("== reload complete\n", from_shell, true);
 }
 
-ACHostAttributesSwap::ACHostAttributesSwap(Request* req, bool from_shell)
+ACHostAttributesSwap::ACHostAttributesSwap(SharedRequest req, bool from_shell)
     : request(req), from_shell(from_shell)
 {
     assert(Swapper::get_reload_in_progress() == false);
index ee8c09831f932a3e59a49f92d3c1d68c178f8f83..1c9d904ec80c94cd255e87aa40464010345ac22d 100644 (file)
 #ifndef ANALYZER_COMMANDS_H
 #define ANALYZER_COMMANDS_H
 
-#include "main/snort_types.h"
+#include "request.h"
+#include "snort_types.h"
 
 class Analyzer;
-class Request;
 class Swapper;
 
 namespace snort
@@ -106,26 +106,26 @@ class ACSwap : public snort::AnalyzerCommand
 {
 public:
     ACSwap() = delete;
-    ACSwap(Swapper* ps, Request* req, bool from_shell);
+    ACSwap(Swapper* ps, SharedRequest req, bool from_shell);
     bool execute(Analyzer&, void**) override;
     const char* stringify() override { return "SWAP"; }
     ~ACSwap() override;
 private:
     Swapper *ps;
-    Request* request;
+    SharedRequest request;
     bool from_shell;
 };
 
 class ACHostAttributesSwap : public snort::AnalyzerCommand
 {
 public:
-    ACHostAttributesSwap(Request* req, bool from_shell);
+    ACHostAttributesSwap(SharedRequest req, bool from_shell);
     bool execute(Analyzer&, void**) override;
     const char* stringify() override { return "HOST_ATTRIBUTES_SWAP"; }
     ~ACHostAttributesSwap() override;
 
 private:
-    Request* request;
+    SharedRequest request;
     bool from_shell;
 };
 
index fd82e35234df0c0c5def49484eb3dbbdd00cdf8d..ed380b3e088f2e973918d351c0d859f0d0b86e5d 100644 (file)
@@ -27,7 +27,6 @@
 #include "utils/util.h"
 
 #include "control_mgmt.h"
-#include "request.h"
 #include "shell.h"
 
 using namespace snort;
@@ -42,7 +41,7 @@ ControlConn::ControlConn(int i, bool local)
     fd = i;
     local_control = local;
     sh = new Shell;
-    request = new Request(fd);
+    request = std::make_shared<Request>(fd);
     configure();
     show_prompt();
 }
@@ -52,7 +51,6 @@ ControlConn::~ControlConn()
     if( !local_control )
         close(fd);
     delete sh;
-    delete request;
 }
 
 void ControlConn::configure() const
@@ -60,7 +58,7 @@ void ControlConn::configure() const
     ModuleManager::load_commands(sh);
 }
 
-int ControlConn::shell_execute(int& current_fd, Request*& current_request)
+int ControlConn::shell_execute(int& current_fd, SharedRequest& current_request)
 {
     if ( !request->read() )
         return -1;
index 6f17f87e7a070ac456a918576fbbca9b641e8204..062ea1452055639f41fd02c5b819786b3e5270f5 100644 (file)
@@ -23,6 +23,7 @@
 #ifndef CONTROL_H
 #define CONTROL_H
 
+#include "main/request.h"
 #include "main/snort_types.h"
 
 class ControlConn
@@ -36,7 +37,7 @@ public:
 
     int get_fd() const { return fd; }
     class Shell* get_shell() const { return sh; }
-    class Request* get_request() const { return request; }
+    SharedRequest get_request() const { return request; }
     bool is_local_control() const { return local_control; }
 
     void block();
@@ -45,7 +46,7 @@ public:
     bool is_blocked() const { return blocked; }
 
     void configure() const;
-    int shell_execute(int& current_fd, Request*& current_request);
+    int shell_execute(int& current_fd, SharedRequest& current_request);
     bool show_prompt() const;
 
 private:
@@ -53,7 +54,7 @@ private:
     bool blocked = false;
     bool local_control;
     class Shell *sh;
-    class Request* request;
+    SharedRequest request;
 };
 
 #endif
index 4aaf8b7893afdfb53941d941d61d1bc6296fa31f..1a2c4c84d6b32eb07964886cc65c6d176756cebb 100644 (file)
@@ -34,7 +34,6 @@
 #include "utils/stats.h"
 #include "utils/util.h"
 #include "control.h"
-#include "request.h"
 #include "snort_config.h"
 #include "utils/util_cstring.h"
 
@@ -181,14 +180,14 @@ int ControlMgmt::socket_term()
     return 0;
 }
 
-bool ControlMgmt::process_control_commands(int& current_fd, Request*& current_request, int evnt_fd)
+bool ControlMgmt::process_control_commands(int& current_fd, SharedRequest& current_request, int evnt_fd)
 {
     auto control_conn = controls.find(evnt_fd);
 
     if (control_conn == controls.end())
         return false;
 
-    Request* old_request = current_request;
+    SharedRequest old_request = current_request;
     int fd = control_conn->second->shell_execute(current_fd, current_request);
     current_fd = -1;
     current_request = old_request;
@@ -204,7 +203,7 @@ bool ControlMgmt::process_control_commands(int& current_fd, Request*& current_re
     return true;
 }
 
-bool ControlMgmt::service_users(int& current_fd, Request*& current_request)
+bool ControlMgmt::service_users(int& current_fd, SharedRequest& current_request)
 {
     bool ret = false;
     struct epoll_event events[MAX_EPOLL_EVENTS];
@@ -336,7 +335,7 @@ int ControlMgmt::socket_term()
     return 0;
 }
 
-bool ControlMgmt::process_control_commands(int& current_fd, Request*& current_request)
+bool ControlMgmt::process_control_commands(int& current_fd, SharedRequest& current_request)
 {
     bool ret = false;
 
@@ -346,7 +345,7 @@ bool ControlMgmt::process_control_commands(int& current_fd, Request*& current_re
         int fd = (*control)->get_fd();
         if (FD_ISSET(fd, &inputs))
         {
-            Request* old_request = current_request;
+            SharedRequest old_request = current_request;
             fd = (*control)->shell_execute(current_fd, current_request);
             current_fd = -1;
             current_request = old_request;
@@ -370,7 +369,7 @@ bool ControlMgmt::process_control_commands(int& current_fd, Request*& current_re
     return ret;
 }
 
-bool ControlMgmt::service_users(int& current_fd, Request*& current_request)
+bool ControlMgmt::service_users(int& current_fd, SharedRequest& current_request)
 {
     FD_ZERO(&inputs);
     int max_fd = -1;
index bd9543c9206f39de53cc9ed8b7ab2e1ca8bce4cc..552df9b657a4b69553f8f89420707822e89ccc19 100644 (file)
@@ -25,6 +25,8 @@
 
 #include <vector>
 
+#include "request.h"
+
 class ControlConn;
 
 class ControlMgmt
@@ -37,8 +39,8 @@ public:
     static int socket_term();
     static int socket_conn();
 
-    static bool process_control_commands(int& current_fd, class Request*& current_request, int);
-    static bool process_control_commands(int& current_fd, class Request*& current_request);
+    static bool process_control_commands(int& current_fd, SharedRequest& current_request, int);
+    static bool process_control_commands(int& current_fd, SharedRequest& current_request);
 
     static ControlConn* find_control(int fd);
     static bool find_control(int fd, std::vector<ControlConn*>::iterator& control);
@@ -47,7 +49,7 @@ public:
     static void delete_control(int fd);
     static void delete_control(std::vector<ControlConn*>::iterator& control);
 
-    static bool service_users(int& current_fd, class Request*& current_request);
+    static bool service_users(int& current_fd, SharedRequest& current_request);
 
 private:
     static int setup_socket_family();
index 21e2957d2364c9e6b6288752d9ce9f9f5a362576..0cce5034b25673f28a9696607561825baddb3175 100644 (file)
@@ -113,7 +113,7 @@ bool Request::send_queued_response()
 }
 #endif
 
-Request& get_dispatched_request()
+SharedRequest get_dispatched_request()
 {
     return get_current_request();
 }
index 7e6620ac39ff32b4dd8f7ede7ba7aa2c18a37363..5a6e80c60f1fefd6ae8427bb374198d6d7ba07b4 100644 (file)
@@ -22,6 +22,7 @@
 #ifndef REQUEST_H
 #define REQUEST_H
 
+#include <memory>
 #include <mutex>
 #include <queue>
 
@@ -48,6 +49,8 @@ private:
     std::mutex queued_response_mutex;
 };
 
-SO_PUBLIC Request& get_dispatched_request();
+using SharedRequest = std::shared_ptr<Request>;
+
+SO_PUBLIC SharedRequest get_dispatched_request();
 
 #endif
index 83bbe2747b86f368d1fd584166a1442e1bde40a5..ee7cd5062701663e5f28dfa7024479ebb1657372 100644 (file)
@@ -165,14 +165,14 @@ class ACThirdPartyAppIdContextUnload : public AnalyzerCommand
 public:
     bool execute(Analyzer&, void**) override;
     ACThirdPartyAppIdContextUnload(const AppIdInspector& inspector, ThirdPartyAppIdContext* tp_ctxt,
-        Request& current_request, bool from_shell): inspector(inspector),
+        SharedRequest current_request, bool from_shell): inspector(inspector),
         tp_ctxt(tp_ctxt), request(current_request), from_shell(from_shell) { }
     ~ACThirdPartyAppIdContextUnload() override;
     const char* stringify() override { return "THIRD-PARTY_CONTEXT_UNLOAD"; }
 private:
     const AppIdInspector& inspector;
     ThirdPartyAppIdContext* tp_ctxt =  nullptr;
-    Request& request;
+    SharedRequest request;
     bool from_shell;
 };
 
@@ -199,7 +199,7 @@ ACThirdPartyAppIdContextUnload::~ACThirdPartyAppIdContextUnload()
     ctxt.create_tp_appid_ctxt();
     main_broadcast_command(new ACThirdPartyAppIdContextSwap(inspector));
     LogMessage("== reload third-party complete\n");
-    request.respond("== reload third-party complete\n", from_shell, true);
+    request->respond("== reload third-party complete\n", from_shell, true);
     Swapper::set_reload_in_progress(false);
 }
 
@@ -208,14 +208,14 @@ class ACOdpContextSwap : public AnalyzerCommand
 public:
     bool execute(Analyzer&, void**) override;
     ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt,
-        Request& current_request, bool from_shell) : inspector(inspector),
+        SharedRequest current_request, bool from_shell) : inspector(inspector),
         odp_ctxt(odp_ctxt), request(current_request), from_shell(from_shell) { }
     ~ACOdpContextSwap() override;
     const char* stringify() override { return "ODP_CONTEXT_SWAP"; }
 private:
     const AppIdInspector& inspector;
     OdpContext& odp_ctxt;
-    Request& request;
+    SharedRequest request;
     bool from_shell;
 };
 
@@ -252,7 +252,7 @@ ACOdpContextSwap::~ACOdpContextSwap()
         ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path);
     }
     LogMessage("== reload detectors complete\n");
-    request.respond("== reload detectors complete\n", from_shell, true);
+    request->respond("== reload detectors complete\n", from_shell, true);
     Swapper::set_reload_in_progress(false);
 }
 
@@ -301,28 +301,28 @@ static int disable_debug(lua_State*)
 static int reload_third_party(lua_State* L)
 {
     bool from_shell = ( L != nullptr );
-    Request& current_request = get_current_request();
+    SharedRequest current_request = get_current_request();
     if (Swapper::get_reload_in_progress())
     {
-        current_request.respond("== reload pending; retry\n", from_shell);
+        current_request->respond("== reload pending; retry\n", from_shell);
         return 0;
     }
-    current_request.respond(".. reloading third-party\n", from_shell);
+    current_request->respond(".. reloading third-party\n", from_shell);
     AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME);
     if (!inspector)
     {
-        current_request.respond("== reload third-party failed - appid not enabled\n", from_shell);
+        current_request->respond("== reload third-party failed - appid not enabled\n", from_shell);
         return 0;
     }
     const AppIdContext& ctxt = inspector->get_ctxt();
     ThirdPartyAppIdContext* old_ctxt = ctxt.get_tp_appid_ctxt();
     if (!old_ctxt)
     {
-        current_request.respond("== reload third-party failed - third-party module doesn't exist\n", from_shell);
+        current_request->respond("== reload third-party failed - third-party module doesn't exist\n", from_shell);
         return 0;
     }
     Swapper::set_reload_in_progress(true);
-    current_request.respond("== unloading old third-party configuration\n", from_shell);
+    current_request->respond("== unloading old third-party configuration\n", from_shell);
     main_broadcast_command(new ACThirdPartyAppIdContextUnload(*inspector, old_ctxt,
         current_request, from_shell), from_shell);
     return 0;
@@ -340,17 +340,17 @@ static void clear_dynamic_host_cache_services()
 static int reload_detectors(lua_State* L)
 {
     bool from_shell = ( L != nullptr );
-    Request& current_request = get_current_request();
+    SharedRequest current_request = get_current_request();
     if (Swapper::get_reload_in_progress())
     {
-        current_request.respond("== reload pending; retry\n", from_shell);
+        current_request->respond("== reload pending; retry\n", from_shell);
         return 0;
     }
-    current_request.respond(".. reloading detectors\n", from_shell);
+    current_request->respond(".. reloading detectors\n", from_shell);
     AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME);
     if (!inspector)
     {
-        current_request.respond("== reload detectors failed - appid not enabled\n", from_shell);
+        current_request->respond("== reload detectors failed - appid not enabled\n", from_shell);
         return 0;
     }
     Swapper::set_reload_in_progress(true);
@@ -372,7 +372,7 @@ static int reload_detectors(lua_State* L)
     odp_thread_local_ctxt->initialize(ctxt, true, true);
     odp_ctxt.initialize();
 
-    current_request.respond("== swapping detectors configuration\n", from_shell);
+    current_request->respond("== swapping detectors configuration\n", from_shell);
     main_broadcast_command(new ACOdpContextSwap(*inspector, old_odp_ctxt,
         current_request, from_shell), from_shell);
     return 0;
index e3adcdf77611294d6ac5fb0e6e76262a5eeab78a..6e735bc1fd821173448569e726612970a823e39c 100644 (file)
@@ -94,8 +94,8 @@ static int purge_data(lua_State* L)
 
         host_cache.invalidate();
 
-        auto& request = get_dispatched_request();
-        request.respond("data purge done\n", false, true);
+        SharedRequest request = get_dispatched_request();
+        request->respond("data purge done\n", false, true);
         LogMessage("data purge done\n");
     }
 
index 6b95de8949cbf8729d8d76eb246a19bbc80f2f18..aeddbade061d0f709b462eb0f7b101a076d50899 100644 (file)
@@ -110,9 +110,9 @@ private:
 
 } // end of namespace snort
 
-static Request mock_request;
+static SharedRequest mock_request = std::make_shared<Request>();
 void Request::respond(const char*, bool, bool) { }
-Request& get_dispatched_request() { return mock_request; }
+SharedRequest get_dispatched_request() { return mock_request; }
 
 HostCacheMac* get_host_cache_mac() { return nullptr; }