]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4019: control: blocking control connections
authorRAGHURAAM CONJEEVARAM UDAYANAN -X (rconjeev - XORIANT CORPORATION at Cisco) <rconjeev@cisco.com>
Mon, 9 Oct 2023 14:33:54 +0000 (14:33 +0000)
committerOleksii Shumeiko -X (oshumeik - SOFTSERVE INC at Cisco) <oshumeik@cisco.com>
Mon, 9 Oct 2023 14:33:54 +0000 (14:33 +0000)
Merge in SNORT/snort3 from ~RCONJEEV/snort3:control_conn_cmd_block_fix to master

Squashed commit of the following:

commit b1ad1e27d0f38286ac99594af11eb7d1c0cb94f8
Author: RAGHURAAM CONJEEVARAM UDAYANAN -X (rconjeev - XORIANT CORPORATION at Cisco) <rconjeev@cisco.com>
Date:   Mon Sep 25 04:25:11 2023 -0400

    control: allow one command at a time

src/control/control.cc
src/control/control.h
src/control/control_mgmt.cc
src/framework/module.h
src/main/ac_shell_cmd.cc
src/main/snort_module.cc
src/managers/module_manager.cc
src/managers/module_manager.h

index 0204e82ac538046355595257f06b6bee11d5688d..da861dd0e58f92b7f3be00e22b374f652afdb8f8 100644 (file)
@@ -35,6 +35,7 @@
 using namespace snort;
 
 std::vector<std::string> ControlConn::log_exclusion_list;
+unsigned ControlConn::pending_cmds_count = 0;
 
 ControlConn* ControlConn::query_from_lua(const lua_State* L)
 {
@@ -66,6 +67,9 @@ ControlConn::~ControlConn()
 
 void ControlConn::shutdown()
 {
+    if (blocked)
+        blocked = false;
+
     if (is_closed())
         return;
     if (!local)
@@ -168,6 +172,8 @@ int ControlConn::execute_commands()
     while (!is_closed() && !blocked && !pending_commands.empty())
     {
         const std::string& command = pending_commands.front();
+        if (pending_cmds_count && !ModuleManager::is_parallel_cmd(command))
+            break;
         std::string rsp;
         shell->execute(command.c_str(), rsp);
         if (!rsp.empty())
index 8c4c67c49d0598213973bd0ca1be8f5ec1fb7277..71fea8bff19b2d35e935577c3a29bbd09717b33d 100644 (file)
@@ -49,7 +49,6 @@ public:
     void unblock();
     void remove();
     bool show_prompt();
-
     bool is_blocked() const { return blocked; }
     bool is_closed() const { return (fd == -1); }
     bool is_removed() const { return removed; }
@@ -70,6 +69,8 @@ public:
     SO_PUBLIC static ControlConn* query_from_lua(const lua_State*);
 
     static void log_command(const std::string& module, bool log);
+    static unsigned increment_pending_cmds_count() { return ++pending_cmds_count; }
+    static unsigned decrement_pending_cmds_count() { return --pending_cmds_count; }
 
 private:
     void touch();
@@ -81,11 +82,12 @@ private:
     class Shell *shell;
     int fd;
     bool local = false;
-    bool blocked = false;
+    bool blocked = false; //block any new commands from executing before current command in control connection is complete
     bool removed = false;
     time_t touched;
 
     static std::vector<std::string> log_exclusion_list;
+    static unsigned pending_cmds_count; //counter to serialize commands across control connections
 };
 
 #define LogRespond(cn, ...)       do { if (cn) cn->respond(__VA_ARGS__); else LogMessage(__VA_ARGS__); } while(0)
index 5c3aed10e39ba81e72c6f0076828511d565a4857..54f5ba09af28e50458f501b26540c514893e2b56 100644 (file)
@@ -337,6 +337,35 @@ static void delete_control(int fd)
         delete_control(iter);
 }
 
+static int execute_control_commands(ControlConn *ctrlcon)
+{
+    int executed = 0;
+    if (!ctrlcon)
+        return executed;
+
+    executed = ctrlcon->execute_commands();
+    if (executed > 0)
+    {
+        if (ctrlcon->is_local())
+            proc_stats.local_commands += executed;
+        else
+            proc_stats.remote_commands += executed;
+    }
+    return executed;
+}
+
+static void process_pending_control_commands()
+{
+    for (auto it : controls)
+    {
+        if (it.second->has_pending_command())
+        {
+            ControlConn* ctrlcon = it.second;
+            execute_control_commands(ctrlcon);
+        }
+    }
+}
+
 static bool process_control_commands(int fd)
 {
     const auto iter = controls.find(fd);
@@ -353,14 +382,7 @@ static bool process_control_commands(int fd)
         return false;
     }
 
-    int executed = ctrlcon->execute_commands();
-    if (executed > 0)
-    {
-        if (ctrlcon->is_local())
-            proc_stats.local_commands += executed;
-        else
-            proc_stats.remote_commands += executed;
-    }
+    int executed = execute_control_commands(ctrlcon);
 
     if (ctrlcon->is_closed())
         delete_control(iter);
@@ -490,6 +512,8 @@ bool ControlMgmt::service_users()
     static FdEvents event[MAX_CONTROL_FDS];
     unsigned nevent;
 
+    process_pending_control_commands();
+
     if (!poll_control_fds(event, nevent))
         return false;
 
index 8eb3b75b162a98d5573654260270835fc1123e7f..e1c3001f3271d3ddc8e7052cd5b240001b7e47db 100644 (file)
@@ -66,6 +66,8 @@ struct Command
     LuaCFunction func;
     const Parameter* params;
     const char* help;
+    // the flag determines if the command is allowed to run in parallel with other control commands
+    bool can_run_in_parallel = false;
 
     std::string get_arg_list() const;
 };
index a4003a2ae61a8d5e13c972f089dc12467d0f5459..195f59a263f5e0572d3df183d38420ac2bd6b141 100644 (file)
@@ -33,6 +33,7 @@ ACShellCmd::ACShellCmd(ControlConn* conn, AnalyzerCommand* ac) : AnalyzerCommand
 
     if (ctrlcon)
         ctrlcon->block();
+    ControlConn::increment_pending_cmds_count();
 }
 
 bool ACShellCmd::execute(Analyzer& analyzer, void** state)
@@ -44,6 +45,7 @@ bool ACShellCmd::execute(Analyzer& analyzer, void** state)
 ACShellCmd::~ACShellCmd()
 {
     delete ac;
+    ControlConn::decrement_pending_cmds_count();
 
     if (ctrlcon)
     {
index 7b154673dcea3423467a5177c168713c4d07a314..b73a68b69ecf84f58e0db15b2dc3bbfbecaec02f 100644 (file)
@@ -147,14 +147,14 @@ static const Command snort_cmds[] =
     // FIXIT-M rewrite trough to permit updates on the fly
     //{ "process", main_process, nullptr, "process given pcap" },
 
-    { "pause", main_pause, nullptr, "suspend packet processing" },
+    { "pause", main_pause, nullptr, "suspend packet processing", true },
 
     { "resume", main_resume, s_pktnum, "continue packet processing. "
-      "If number of packets is specified, will resume for n packets and pause" },
+      "If number of packets is specified, will resume for n packets and pause", true },
 
-    { "detach", main_detach, nullptr, "detach from control shell (without shutting down)" },
-    { "quit", main_quit, nullptr, "shutdown and dump-stats" },
-    { "help", main_help, nullptr, "this output" },
+    { "detach", main_detach, nullptr, "detach from control shell (without shutting down)", true },
+    { "quit", main_quit, nullptr, "shutdown and dump-stats", true },
+    { "help", main_help, nullptr, "this output", true },
 
     { nullptr, nullptr, nullptr, nullptr }
 };
index 2ef30a6cc88ce8c755acc6613254a493af229e4d..80cf9fe143805d3d2b817d6be25f6a3ea51e8d3e 100644 (file)
@@ -83,6 +83,7 @@ static string s_aliased_name;
 static string s_aliased_type;
 static string s_ips_includer;
 static string s_file_id_includer;
+static std::unordered_set<string> s_parallel_cmds;
 
 // for callbacks from Lua
 static SnortConfig* s_config = nullptr;
@@ -157,11 +158,19 @@ void ModHook::init()
     // would be out of date, out of sync, etc. QED
     reg = new luaL_Reg[++n];
     unsigned k = 0;
-
+    std::string cmd_name;
+    const char* dot = ".";
     while ( k < n )
     {
         reg[k].name = c[k].name;
         reg[k].func = c[k].func;
+        if (c[k].can_run_in_parallel)
+        {
+            cmd_name = mod->get_name();
+            cmd_name = cmd_name + dot + c[k].name;
+            s_parallel_cmds.insert(cmd_name);
+        }
+
         k++;
     }
 }
@@ -1942,6 +1951,30 @@ void ModuleManager::show_modules_json()
     json.close_array();
 }
 
+bool ModuleManager::is_parallel_cmd(std::string control_cmd)
+{
+    control_cmd = remove_whitespace(control_cmd);
+
+    std::string mod_cmd;
+
+    size_t dotPos = control_cmd.find('.');
+    size_t openParenthesisPos = control_cmd.find("(");
+
+    if (dotPos == std::string::npos)
+        mod_cmd = "snort.";
+
+    if (openParenthesisPos != std::string::npos)
+        mod_cmd = mod_cmd + control_cmd.substr(0,openParenthesisPos);
+
+    return 1 == s_parallel_cmds.count(mod_cmd);
+}
+
+std::string ModuleManager::remove_whitespace(std::string& control_cmd)
+{
+    control_cmd.erase(std::remove_if(control_cmd.begin(), control_cmd.end(), ::isspace), control_cmd.end());
+    return control_cmd;
+}
+
 #ifdef UNIT_TEST
 
 #include <catch/snort_catch.h>
index c09d1084859730b6f47130852a4a4a4eb4a4baeb..850ed38ba58458d4a4ccb757a2240134c1a9152c 100644 (file)
@@ -94,6 +94,8 @@ public:
     static void reset_stats(clear_counter_type_t);
 
     static void clear_global_active_counters();
+    static bool is_parallel_cmd(std::string control_cmd);
+    static std::string remove_whitespace(std::string& control_cmd);
 
 
     static std::set<uint32_t> gids;