using namespace snort;
std::vector<std::string> ControlConn::log_exclusion_list;
+unsigned ControlConn::pending_cmds_count = 0;
ControlConn* ControlConn::query_from_lua(const lua_State* L)
{
void ControlConn::shutdown()
{
+ if (blocked)
+ blocked = false;
+
if (is_closed())
return;
if (!local)
while (!is_closed() && !blocked && !pending_commands.empty())
{
const std::string& command = pending_commands.front();
+ if (pending_cmds_count && !ModuleManager::is_parallel_cmd(command))
+ break;
std::string rsp;
shell->execute(command.c_str(), rsp);
if (!rsp.empty())
void unblock();
void remove();
bool show_prompt();
-
bool is_blocked() const { return blocked; }
bool is_closed() const { return (fd == -1); }
bool is_removed() const { return removed; }
SO_PUBLIC static ControlConn* query_from_lua(const lua_State*);
static void log_command(const std::string& module, bool log);
+ static unsigned increment_pending_cmds_count() { return ++pending_cmds_count; }
+ static unsigned decrement_pending_cmds_count() { return --pending_cmds_count; }
private:
void touch();
class Shell *shell;
int fd;
bool local = false;
- bool blocked = false;
+ bool blocked = false; //block any new commands from executing before current command in control connection is complete
bool removed = false;
time_t touched;
static std::vector<std::string> log_exclusion_list;
+ static unsigned pending_cmds_count; //counter to serialize commands across control connections
};
#define LogRespond(cn, ...) do { if (cn) cn->respond(__VA_ARGS__); else LogMessage(__VA_ARGS__); } while(0)
delete_control(iter);
}
+static int execute_control_commands(ControlConn *ctrlcon)
+{
+ int executed = 0;
+ if (!ctrlcon)
+ return executed;
+
+ executed = ctrlcon->execute_commands();
+ if (executed > 0)
+ {
+ if (ctrlcon->is_local())
+ proc_stats.local_commands += executed;
+ else
+ proc_stats.remote_commands += executed;
+ }
+ return executed;
+}
+
+static void process_pending_control_commands()
+{
+ for (auto it : controls)
+ {
+ if (it.second->has_pending_command())
+ {
+ ControlConn* ctrlcon = it.second;
+ execute_control_commands(ctrlcon);
+ }
+ }
+}
+
static bool process_control_commands(int fd)
{
const auto iter = controls.find(fd);
return false;
}
- int executed = ctrlcon->execute_commands();
- if (executed > 0)
- {
- if (ctrlcon->is_local())
- proc_stats.local_commands += executed;
- else
- proc_stats.remote_commands += executed;
- }
+ int executed = execute_control_commands(ctrlcon);
if (ctrlcon->is_closed())
delete_control(iter);
static FdEvents event[MAX_CONTROL_FDS];
unsigned nevent;
+ process_pending_control_commands();
+
if (!poll_control_fds(event, nevent))
return false;
LuaCFunction func;
const Parameter* params;
const char* help;
+ // the flag determines if the command is allowed to run in parallel with other control commands
+ bool can_run_in_parallel = false;
std::string get_arg_list() const;
};
if (ctrlcon)
ctrlcon->block();
+ ControlConn::increment_pending_cmds_count();
}
bool ACShellCmd::execute(Analyzer& analyzer, void** state)
ACShellCmd::~ACShellCmd()
{
delete ac;
+ ControlConn::decrement_pending_cmds_count();
if (ctrlcon)
{
// FIXIT-M rewrite trough to permit updates on the fly
//{ "process", main_process, nullptr, "process given pcap" },
- { "pause", main_pause, nullptr, "suspend packet processing" },
+ { "pause", main_pause, nullptr, "suspend packet processing", true },
{ "resume", main_resume, s_pktnum, "continue packet processing. "
- "If number of packets is specified, will resume for n packets and pause" },
+ "If number of packets is specified, will resume for n packets and pause", true },
- { "detach", main_detach, nullptr, "detach from control shell (without shutting down)" },
- { "quit", main_quit, nullptr, "shutdown and dump-stats" },
- { "help", main_help, nullptr, "this output" },
+ { "detach", main_detach, nullptr, "detach from control shell (without shutting down)", true },
+ { "quit", main_quit, nullptr, "shutdown and dump-stats", true },
+ { "help", main_help, nullptr, "this output", true },
{ nullptr, nullptr, nullptr, nullptr }
};
static string s_aliased_type;
static string s_ips_includer;
static string s_file_id_includer;
+static std::unordered_set<string> s_parallel_cmds;
// for callbacks from Lua
static SnortConfig* s_config = nullptr;
// would be out of date, out of sync, etc. QED
reg = new luaL_Reg[++n];
unsigned k = 0;
-
+ std::string cmd_name;
+ const char* dot = ".";
while ( k < n )
{
reg[k].name = c[k].name;
reg[k].func = c[k].func;
+ if (c[k].can_run_in_parallel)
+ {
+ cmd_name = mod->get_name();
+ cmd_name = cmd_name + dot + c[k].name;
+ s_parallel_cmds.insert(cmd_name);
+ }
+
k++;
}
}
json.close_array();
}
+bool ModuleManager::is_parallel_cmd(std::string control_cmd)
+{
+ control_cmd = remove_whitespace(control_cmd);
+
+ std::string mod_cmd;
+
+ size_t dotPos = control_cmd.find('.');
+ size_t openParenthesisPos = control_cmd.find("(");
+
+ if (dotPos == std::string::npos)
+ mod_cmd = "snort.";
+
+ if (openParenthesisPos != std::string::npos)
+ mod_cmd = mod_cmd + control_cmd.substr(0,openParenthesisPos);
+
+ return 1 == s_parallel_cmds.count(mod_cmd);
+}
+
+std::string ModuleManager::remove_whitespace(std::string& control_cmd)
+{
+ control_cmd.erase(std::remove_if(control_cmd.begin(), control_cmd.end(), ::isspace), control_cmd.end());
+ return control_cmd;
+}
+
#ifdef UNIT_TEST
#include <catch/snort_catch.h>
static void reset_stats(clear_counter_type_t);
static void clear_global_active_counters();
+ static bool is_parallel_cmd(std::string control_cmd);
+ static std::string remove_whitespace(std::string& control_cmd);
static std::set<uint32_t> gids;