]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #3131 in SNORT/snort3 from ~SHRARANG/snort3:appid_lua_init_mem_opt...
authorShravan Rangarajuvenkata (shrarang) <shrarang@cisco.com>
Wed, 27 Oct 2021 02:31:41 +0000 (02:31 +0000)
committerShravan Rangarajuvenkata (shrarang) <shrarang@cisco.com>
Wed, 27 Oct 2021 02:31:41 +0000 (02:31 +0000)
Squashed commit of the following:

commit 3463c2fe5d7af7e5b54790e31164c5ec834be778
Author: Shravan Rangaraju <shrarang@cisco.com>
Date:   Tue Oct 26 15:11:48 2021 -0400

    appid: during initialization, skip loading of Lua detectors that don't have validate function

src/network_inspectors/appid/appid_config.cc
src/network_inspectors/appid/lua_detector_module.cc
src/network_inspectors/appid/lua_detector_module.h

index 4ba519d42166916d65a7713f55621a544552eeed..3d07f290e005c4a2092479a0325f869a8dfc0a1d 100644 (file)
@@ -212,7 +212,7 @@ AppId OdpContext::get_port_service_id(IpProtocol proto, uint16_t port)
     AppId appId;
 
     if (proto == IpProtocol::TCP)
-      appId = tcp_port_only[port];
+        appId = tcp_port_only[port];
     else
         appId = udp_port_only[port];
 
@@ -229,7 +229,7 @@ void OdpThreadContext::initialize(AppIdContext& ctxt, bool is_control, bool relo
     if (!is_control and reload_odp)
         LuaDetectorManager::init_thread_manager(ctxt);
     else
-        LuaDetectorManager::initialize(ctxt, is_control? 1 : 0, reload_odp);
+        LuaDetectorManager::initialize(ctxt, is_control, reload_odp);
 }
 
 OdpThreadContext::~OdpThreadContext()
index c8d21819f8c32d434355954a85ea11a15ff19d72..31af1003b968f785280351acd95c8268b6122dab 100644 (file)
@@ -50,9 +50,10 @@ using namespace std;
 #define AVG_LUA_TRACKER_SIZE_IN_BYTES 740
 #define MAX_MEMORY_FOR_LUA_DETECTORS (512 * 1024 * 1024)
 
-static std::vector<LuaDetectorManager*> lua_detector_mgr_list;
+static vector<LuaDetectorManager*> lua_detector_mgr_list;
+static unordered_set<string> lua_detectors_w_validate;
 
-bool get_lua_field(lua_State* L, int table, const char* field, std::string& out)
+bool get_lua_field(lua_State* L, int table, const char* field, string& out)
 {
     lua_getfield(L, table, field);
     bool result = lua_isstring(L, -1);
@@ -89,13 +90,13 @@ bool get_lua_field(lua_State* L, int table, const char* field, IpProtocol& out)
     return result;
 }
 
-inline void set_control(lua_State* L, int is_control)
+inline void set_control(lua_State* L, bool is_control)
 {
-    lua_pushboolean (L, is_control); // push flag to stack
+    lua_pushboolean (L, is_control ? 1 : 0); // push flag to stack
     lua_setglobal(L, "is_control"); // create global key to store value
 }
 
-static lua_State* create_lua_state(const AppIdConfig& config, int is_control)
+static lua_State* create_lua_state(const AppIdConfig& config, bool is_control)
 {
     auto L = luaL_newstate();
 
@@ -150,13 +151,13 @@ static lua_State* create_lua_state(const AppIdConfig& config, int is_control)
     return L;
 }
 
-LuaDetectorManager::LuaDetectorManager(AppIdContext& ctxt, int is_control) :
+LuaDetectorManager::LuaDetectorManager(AppIdContext& ctxt, bool is_control) :
     ctxt(ctxt)
 {
     allocated_objects.clear();
     cb_detectors.clear();
     L = create_lua_state(ctxt.config, is_control);
-    if (is_control == 1)
+    if (is_control)
         init_chp_glossary();
 }
 
@@ -180,7 +181,7 @@ LuaDetectorManager::~LuaDetectorManager()
             lua_getfield(L, -1, lsd->package_info.cleanFunctionName.c_str());
             if ( lua_isfunction(L, -1) )
             {
-                std::string name = lsd->package_info.name + "_";
+                string name = lsd->package_info.name + "_";
                 lua_getglobal(L, name.c_str());
 
                 if ( lua_pcall(L, 1, 1, 0) )
@@ -201,7 +202,7 @@ LuaDetectorManager::~LuaDetectorManager()
     cb_detectors.clear(); // do not free Lua objects in cb_detectors
 }
 
-void LuaDetectorManager::initialize(AppIdContext& ctxt, int is_control, bool reload)
+void LuaDetectorManager::initialize(AppIdContext& ctxt, bool is_control, bool reload)
 {
     LuaDetectorManager* lua_detector_mgr = new LuaDetectorManager(ctxt, is_control);
     odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr);
@@ -224,7 +225,7 @@ void LuaDetectorManager::initialize(AppIdContext& ctxt, int is_control, bool rel
         }
     }
 
-    lua_detector_mgr->initialize_lua_detectors(reload);
+    lua_detector_mgr->initialize_lua_detectors(is_control, reload);
     lua_detector_mgr->activate_lua_detectors();
 
     if (ctxt.config.list_odp_detectors)
@@ -327,7 +328,7 @@ static inline uint32_t compute_lua_tracker_size(uint64_t rnaMemory, uint32_t num
 LuaObject* LuaDetectorManager::create_lua_detector(const char* detector_name,
     bool is_custom, const char* detector_filename, bool& has_validate)
 {
-    std::string log_name;
+    string log_name;
     IpProtocol proto = IpProtocol::PROTO_NOT_SET;
 
     has_validate = false;
@@ -404,12 +405,12 @@ LuaObject* LuaDetectorManager::create_lua_detector(const char* detector_name,
 
 static int dump(lua_State*, const void* buf,size_t size, void* data)
 {
-    std::string* s = static_cast<std::string*>(data);
+    string* s = static_cast<string*>(data);
     s->append(static_cast<const char*>(buf), size);
     return 0;
 }
 
-bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, bool reload, std::string& buf)
+bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, bool is_control, bool reload, string& buf)
 {
     if (reload and !buf.empty())
     {
@@ -423,6 +424,13 @@ bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom,
     }
     else
     {
+        if (!is_control)
+        {
+            auto iter = lua_detectors_w_validate.find(detector_filename);
+            if (iter == lua_detectors_w_validate.end())
+                return false;
+        }
+
         if (luaL_loadfile(L, detector_filename))
         {
             if (init(L))
@@ -476,7 +484,7 @@ bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom,
     return has_validate;
 }
 
-void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bool reload)
+void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bool is_control, bool reload)
 {
     char pattern[PATH_MAX];
     snprintf(pattern, sizeof(pattern), "%s/*", path);
@@ -490,7 +498,7 @@ void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bo
             WarningMessage("appid: leak of %d lua stack elements before detector load\n",
                 lua_gettop(L));
 
-        std::string buf;
+        string buf;
         for (unsigned n = 0; n < globs.gl_pathc; n++)
         {
             ifstream file(globs.gl_pathv[n], ios::ate);
@@ -510,17 +518,36 @@ void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bo
             }
             file.close();
 
-            bool has_validate = load_detector(globs.gl_pathv[n], is_custom, reload, buf);
+            // In the packet threads, we do not need to load Lua detectors that don't have validate
+            // function such as payload_group_*, ssl_group_*, etc. That's because the patterns they
+            // register are stored in global tables only in control thread. In packet threads, they
+            // do nothing. Skipping loading of these detectors in packet threads saves on the memory
+            // used by LuaJIT.
+
+            // Because the code flow for loading Lua detectors is different for initialization vs
+            // reload, the LuaJIT memory saving is achieved differently in these two cases.
+
+            // During initialization, load_lua_detectors() gets called for all the threads - first
+            // for the control thread and then for the packet threads. Control thread stores the
+            // detectors that have validate in lua_detectors_w_validate. Packet thread loads a
+            // detector in load_detector() only if it finds the detector in lua_detectors_w_validate.
+
+            // During reload, load_lua_detectors() gets called only for control thread. This
+            // function loads detectors for all the packet threads too during reload. It skips
+            // loading detectors that don't have validate for packet threads.
+            bool has_validate = load_detector(globs.gl_pathv[n], is_custom, is_control, reload, buf);
 
             if (reload)
             {
                 for (auto& lua_detector_mgr : lua_detector_mgr_list)
                 {
                     if (has_validate)
-                        lua_detector_mgr->load_detector(globs.gl_pathv[n], is_custom, reload, buf);
+                        lua_detector_mgr->load_detector(globs.gl_pathv[n], is_custom, is_control, reload, buf);
                 }
                 buf.clear();
             }
+            else if (is_control and has_validate)
+                lua_detectors_w_validate.insert(globs.gl_pathv[n]);
             lua_settop(L, 0);
         }
 
@@ -534,7 +561,7 @@ void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bo
             pattern, rval);
 }
 
-void LuaDetectorManager::initialize_lua_detectors(bool reload)
+void LuaDetectorManager::initialize_lua_detectors(bool is_control, bool reload)
 {
     char path[PATH_MAX];
     const char* dir = ctxt.config.app_detector_dir;
@@ -543,7 +570,7 @@ void LuaDetectorManager::initialize_lua_detectors(bool reload)
         return;
 
     snprintf(path, sizeof(path), "%s/odp/lua", dir);
-    load_lua_detectors(path, false, reload);
+    load_lua_detectors(path, false, is_control, reload);
     num_odp_detectors = allocated_objects.size();
 
     if (reload)
@@ -552,14 +579,14 @@ void LuaDetectorManager::initialize_lua_detectors(bool reload)
             lua_detector_mgr->num_odp_detectors = lua_detector_mgr->allocated_objects.size();
     }
     snprintf(path, sizeof(path), "%s/custom/lua", dir);
-    load_lua_detectors(path, true, reload);
+    load_lua_detectors(path, true, is_control, reload);
 }
 
 void LuaDetectorManager::activate_lua_detectors()
 {
     uint32_t lua_tracker_size = compute_lua_tracker_size(MAX_MEMORY_FOR_LUA_DETECTORS,
         allocated_objects.size());
-    std::list<LuaObject*>::iterator lo = allocated_objects.begin();
+    list<LuaObject*>::iterator lo = allocated_objects.begin();
 
     if (lua_gettop(L))
         WarningMessage("appid: leak of %d lua stack elements before detector activate\n",
@@ -584,7 +611,7 @@ void LuaDetectorManager::activate_lua_detectors()
         }
 
         /*first parameter is DetectorUserData */
-        std::string name = lsd->package_info.name + "_";
+        string name = lsd->package_info.name + "_";
         lua_getglobal(L, name.c_str());
 
         /*second parameter is a table containing configuration stuff. */
index 615edc633864eba490edfccc83848cedb80a0c12..f85e1cfb7ba35ae20fa00b6444d766ba0053ab04 100644 (file)
@@ -48,9 +48,9 @@ bool get_lua_field(lua_State* L, int table, const char* field, IpProtocol& out);
 class LuaDetectorManager
 {
 public:
-    LuaDetectorManager(AppIdContext&, int);
+    LuaDetectorManager(AppIdContext&, bool);
     ~LuaDetectorManager();
-    static void initialize(AppIdContext&, int is_control=0, bool reload=false);
+    static void initialize(AppIdContext&, bool is_control=false, bool reload=false);
     static void init_thread_manager(const AppIdContext&);
     static void clear_lua_detector_mgrs();
 
@@ -69,11 +69,11 @@ public:
     LuaObject* get_cb_detector(AppId app_id);
 
 private:
-    void initialize_lua_detectors(bool reload = false);
+    void initialize_lua_detectors(bool is_control, bool reload = false);
     void activate_lua_detectors();
     void list_lua_detectors();
-    bool load_detector(char* detector_name, bool is_custom, bool reload, std::string& buf);
-    void load_lua_detectors(const char* path, bool is_custom, bool reload = false);
+    bool load_detector(char* detector_name, bool is_custom, bool is_control, bool reload, std::string& buf);
+    void load_lua_detectors(const char* path, bool is_custom, bool is_control, bool reload = false);
     LuaObject* create_lua_detector(const char* detector_name, bool is_custom,
         const char* detector_filename, bool& has_validate);