From: Oleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) Date: Fri, 2 May 2025 20:37:21 +0000 (+0000) Subject: Pull request #4718: mp_data_bus: Adding stats and CLI commands to MPDataBus X-Git-Tag: 3.7.4.0~4 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=68f15b26ab9dd7a56c1d30931eaccd8d029c56b1;p=thirdparty%2Fsnort3.git Pull request #4718: mp_data_bus: Adding stats and CLI commands to MPDataBus Merge in SNORT/snort3 from ~OSTEPANO/snort3:cli_stats_mp to master Squashed commit of the following: commit 8160a86149c4b0030e74b6a04a6919ce55bf3913 Author: Oleksandr Stepanov Date: Mon Apr 28 06:02:08 2025 -0400 mp_data_bus: Adding peg stats and socket commands for MPDataBus --- diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.cc b/src/connectors/unixdomain_connector/unixdomain_connector.cc index 24a75779b..461df406c 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.cc +++ b/src/connectors/unixdomain_connector/unixdomain_connector.cc @@ -85,7 +85,7 @@ static bool attempt_connection(int& sfd, const char* path) { // Function to handle connection retries static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) { if(update_handler) - update_handler(nullptr, (cfg.conn_retries > 0)); + update_handler(nullptr, ( (cfg.conn_retries > 0) and (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) )); else ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr); diff --git a/src/framework/mp_data_bus.cc b/src/framework/mp_data_bus.cc index 8e6477050..8197afeb4 100644 --- a/src/framework/mp_data_bus.cc +++ b/src/framework/mp_data_bus.cc @@ -32,19 +32,45 @@ #include "utils/stats.h" #include "main/snort_types.h" #include "log/messages.h" +#include "log/log_stats.h" #include "helpers/ring.h" #include "managers/mp_transport_manager.h" +#include "managers/module_manager.h" +#include "main/snort.h" +#include "framework/module.h" using namespace snort; +void MPDataBusLog(const char* msg, ...); + std::condition_variable MPDataBus::queue_cv; std::mutex MPDataBus::queue_mutex; uint32_t MPDataBus::mp_max_eventq_size = DEFAULT_MAX_EVENTQ_SIZE; std::string MPDataBus::transport = DEFAULT_TRANSPORT; bool MPDataBus::enable_debug = false; MPTransport* MPDataBus::transport_layer = nullptr; +MPDataBusStats MPDataBus::mp_global_stats = {}; +#ifdef REG_TEST +bool MPDataBus::hold_events = false; +#endif static std::unordered_map mp_pub_ids; +static std::mutex mp_stats_mutex; +static uint mp_current_process_id = 0; + +void MPDataBusLog(const char* msg, ...) +{ + if (!MPDataBus::enable_debug) + return; + + char buf[256]; + va_list args; + va_start(args, msg); + vsnprintf(buf, sizeof(buf), msg, args); + va_end(args); + + LogMessage("MPDataBusDbg ID=%d %s", mp_current_process_id, buf); +} //-------------------------------------------------------------------------- // public methods @@ -60,7 +86,7 @@ MPDataBus::~MPDataBus() { stop_worker_thread(); - for (auto& [key, sublist] : mp_pub_sub) + for (auto& [_, sublist] : mp_pub_sub) { for (auto* handler : sublist) { @@ -87,6 +113,8 @@ unsigned MPDataBus::init(int max_procs) return 1; } + mp_current_process_id = Snort::get_process_id(); + transport_layer = MPTransportManager::get_transport(transport); if (transport_layer == nullptr) { @@ -136,6 +164,18 @@ unsigned MPDataBus::get_id(const PubKey& key) return mp_pub_ids[key.publisher]; } +const char* MPDataBus::get_name_from_id(unsigned id) +{ + for (const auto& [name, pub_id] : mp_pub_ids) + { + if (pub_id == id) + { + return name.c_str(); + } + } + return nullptr; +} + void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h) { if(! SnortConfig::get_conf()->mp_dbus) @@ -145,7 +185,7 @@ void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h) } SnortConfig::get_conf()->mp_dbus->_subscribe(key, eid, h); - MP_DATABUS_LOG("MPDataBus: Subscribed to event ID %u\n", eid); + MPDataBusLog("Subscribed to event ID %u\n", eid); } bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr e, Flow*) @@ -166,12 +206,9 @@ bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr lock(queue_mutex); - queue_cv.notify_one(); - } + queue_cv.notify_one(); - MP_DATABUS_LOG("MPDataBus: Event published for publisher ID %u and event ID %u\n", pub_id, evt_id); + MPDataBusLog("Event published for publisher ID %u and event ID %u\n", pub_id, evt_id); return true; } @@ -189,7 +226,7 @@ void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSer MPHelperFunctions helpers(mp_serializer_helper, mp_deserializer_helper); SnortConfig::get_conf()->mp_dbus->transport_layer->register_event_helpers(pub_id, evt_id, helpers); - MP_DATABUS_LOG("MPDataBus: Registered event helpers for event ID %u\n", evt_id); + MPDataBusLog("Registered event helpers for event ID %u\n", evt_id); } // API for receiving the DataEvent and Event type from transport layer @@ -199,9 +236,18 @@ void MPDataBus::receive_message(const MPEventInfo& event_info) unsigned evt_id = event_info.type; unsigned pub_id = event_info.pub_id; - MP_DATABUS_LOG("MPDataBus: Received message for publisher ID %u and event ID %u\n", pub_id, evt_id); + MPDataBusLog("Received message for publisher ID %u and event ID %u\n", pub_id, evt_id); + + auto pub_res = _publish(pub_id, evt_id, *e, nullptr); - _publish(pub_id, evt_id, *e, nullptr); + { + std::lock_guard lock(mp_stats_mutex); + mp_pub_stats[pub_id].total_messages_received++; + if(pub_res) + { + mp_pub_stats[pub_id].total_messages_delivered++; + } + } } @@ -210,25 +256,42 @@ void MPDataBus::receive_message(const MPEventInfo& event_info) //-------------------------------------------------------------------------- void MPDataBus::process_event_queue() { +#ifdef REG_TEST + if (hold_events) + { + return; + } +#endif if (!mp_event_queue) { return; } - std::unique_lock lock(queue_mutex); + std::unique_lock u_lock(queue_mutex); - queue_cv.wait_for(lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() { + queue_cv.wait_for(u_lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() { return mp_event_queue != nullptr && !mp_event_queue->empty(); }); - lock.unlock(); - while (!mp_event_queue->empty()) { std::shared_ptr event_info = mp_event_queue->get(nullptr); if (event_info) { - MP_DATABUS_LOG("MPDataBus: Processing event for publisher ID %u \n", + MPDataBusLog("Processing event for publisher ID %u \n", event_info->pub_id); - transport_layer->send_to_transport(*event_info); + auto send_res = transport_layer->send_to_transport(*event_info); + + { + std::lock_guard lock(mp_stats_mutex); + mp_pub_stats[event_info->pub_id].total_messages_published++; + if (!send_res) + { + mp_pub_stats[event_info->pub_id].total_messages_dropped++; + } + else + { + mp_pub_stats[event_info->pub_id].total_messages_sent++; + } + } } } } @@ -270,7 +333,164 @@ static bool compare(DataHandler* a, DataHandler* b) return false; } -void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler* h) +void snort::MPDataBus::set_debug_enabled(bool flag) +{ + enable_debug = flag; + if(transport_layer) + { + if(flag) + { + transport_layer->enable_logging(); + } + else + { + transport_layer->disable_logging(); + } + } +} + +void MPDataBus::sum_stats() +{ + std::lock_guard lock(mp_stats_mutex); + + mp_global_stats.total_messages_sent = 0; + mp_global_stats.total_messages_received = 0; + mp_global_stats.total_messages_dropped = 0; + mp_global_stats.total_messages_published = 0; + mp_global_stats.total_messages_delivered = 0; + + for(auto& [_, pub_stats] : mp_pub_stats) + { + mp_global_stats.total_messages_dropped += pub_stats.total_messages_dropped; + mp_global_stats.total_messages_published += pub_stats.total_messages_published; + mp_global_stats.total_messages_received += pub_stats.total_messages_received; + mp_global_stats.total_messages_sent += pub_stats.total_messages_sent; + mp_global_stats.total_messages_delivered += pub_stats.total_messages_delivered; + } +} + +void MPDataBus::dump_stats(ControlConn *ctrlconn, const char *module_name) +{ + set_log_conn(ctrlconn); + if (module_name) + { + auto mod_id = mp_pub_ids.find(module_name); + if (mod_id == mp_pub_ids.end()) + { + return; + } + std::lock_guard lock(mp_stats_mutex); + auto mod_stats = mp_pub_stats[mod_id->second]; + + LogMessage("MPDataBus Stats for %s\n", module_name); + show_stats((PegCount*)&mod_stats, mp_databus_pegs, array_size(mp_databus_pegs)-1); + } + else + { + sum_stats(); + + show_stats((PegCount*)&mp_global_stats, mp_databus_pegs, array_size(mp_databus_pegs)-1); + + auto transport_module = ModuleManager::get_module(transport.c_str()); + if(transport_module) + { + auto transport_pegs = transport_module->get_pegs(); + if(transport_pegs) + { + uint size = 0; + while(transport_pegs[size].type != CountType::END) + { + size++; + } + show_stats(transport_module->get_counts(), transport_pegs, size); + } + } + } + set_log_conn(nullptr); +} + +void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name) +{ + int current_read_idx = 0; + uint ring_items = mp_event_queue->count(); + if(ring_items == 0) + { + if (ctrlconn) + { + ctrlconn->respond("No events in the event queue\n"); + } + else + { + LogMessage("No events in the event queue\n"); + } + return; + } + auto event_queue_store = mp_event_queue->grab_store(current_read_idx); + + if (current_read_idx == 0) + { + current_read_idx = mp_max_eventq_size - 1; + } + else + { + current_read_idx--; + } + + for (uint i = current_read_idx; i <= ring_items; i++) + { + if(i >= mp_max_eventq_size) + { + i = 0; + ring_items -= mp_max_eventq_size; + } + auto event_info = event_queue_store[i]; + if (event_info) + { + if (module_name) + { + if (event_info->pub_id != mp_pub_ids[module_name]) + { + continue; + } + } + if (ctrlconn) + { + ctrlconn->respond("Publisher module: %s, Event ID: %u\n", get_name_from_id(event_info->pub_id), event_info->type); + } + else + { + LogMessage("Publisher module: %s, Event ID: %u\n", get_name_from_id(event_info->pub_id), event_info->type); + } + } + } +} + +void snort::MPDataBus::show_channel_status(ControlConn *ctrlconn) +{ + if(!transport_layer or !ctrlconn) + { + return; + } + + uint size = 0; + auto transport_status = transport_layer->get_channel_status(size); + if (size == 0) + { + ctrlconn->respond("No active connections\n"); + return; + } + std::string response; + for (uint i = 0; i < size; i++) + { + const auto& channel = transport_status[i]; + response += "Channel ID: " + std::to_string(channel.id) + ", Name: " + channel.name + ", Status: " + channel.get_status_string() + "\n"; + } + + ctrlconn->respond("%s", response.c_str()); + delete[] transport_status; +} + +void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler *h) { std::pair key = {pid, eid}; @@ -287,15 +507,15 @@ void MPDataBus::_subscribe(const PubKey& key, unsigned eid, DataHandler* h) } -void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f) +bool MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f) { std::pair key = {pid, eid}; auto it = mp_pub_sub.find(key); if (it == mp_pub_sub.end()) { - MP_DATABUS_LOG("MPDataBus: No subscribers for publisher ID %u and event ID %u\n", pid, eid); - return; + MPDataBusLog("No subscribers for publisher ID %u and event ID %u\n", pid, eid); + return false; } const SubList& subs = it->second; @@ -303,5 +523,7 @@ void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f) { handler->handle(e, f); } + + return true; } diff --git a/src/framework/mp_data_bus.h b/src/framework/mp_data_bus.h index ee1ce340e..06c9efe9b 100644 --- a/src/framework/mp_data_bus.h +++ b/src/framework/mp_data_bus.h @@ -40,24 +40,18 @@ #include #include #include +#include +#include "control/control.h" +#include "framework/mp_transport.h" +#include "framework/counts.h" #include "main/snort_types.h" #include "data_bus.h" -#include "framework/mp_transport.h" -#include -#include "framework/mp_transport.h" #define DEFAULT_TRANSPORT "unix_transport" #define DEFAULT_MAX_EVENTQ_SIZE 1000 #define WORKER_THREAD_SLEEP 100 -#define MP_DATABUS_LOG(msg, ...) do { \ - if (!MPDataBus::enable_debug) \ - break; \ - LogMessage(msg, __VA_ARGS__); \ - } while (0) - - template class Ring; @@ -67,6 +61,33 @@ class Flow; struct Packet; struct SnortConfig; +struct MPDataBusStats +{ + MPDataBusStats() : + total_messages_sent(0), + total_messages_received(0), + total_messages_dropped(0), + total_messages_published(0), + total_messages_delivered(0) + { } + + PegCount total_messages_sent; + PegCount total_messages_received; + PegCount total_messages_dropped; + PegCount total_messages_published; + PegCount total_messages_delivered; +}; + +static const PegInfo mp_databus_pegs[] = +{ + { CountType::SUM, "total_messages_sent", "total messages sent" }, + { CountType::SUM, "total_messages_received", "total messages received" }, + { CountType::SUM, "total_messages_dropped", "total messages dropped" }, + { CountType::SUM, "total_messages_published", "total messages published" }, + { CountType::SUM, "total_messages_delivered", "total messages delivered" }, + { CountType::END, nullptr, nullptr }, +}; + typedef bool (*MPSerializeFunc)(DataEvent* event, char*& buffer, uint16_t* length); typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent*& event); @@ -115,12 +136,17 @@ public: static uint32_t mp_max_eventq_size; static std::string transport; static bool enable_debug; +#ifdef REG_TEST + static bool hold_events; +#endif static MPTransport * transport_layer; + static MPDataBusStats mp_global_stats; unsigned init(int); void clone(MPDataBus& from, const char* exclude_name = nullptr); static unsigned get_id(const PubKey& key); + static const char* get_name_from_id(unsigned id); static bool valid(unsigned pub_id) { return pub_id != 0; } @@ -142,11 +168,19 @@ public: Ring>* get_event_queue() { return mp_event_queue; } + void set_debug_enabled(bool flag); + + void sum_stats(); + + void dump_stats(ControlConn* ctrlconn, const char* module_name); + void dump_events(ControlConn* ctrlconn, const char* module_name); + void show_channel_status(ControlConn* ctrlconn); + private: void _subscribe(unsigned pid, unsigned eid, DataHandler* h); void _subscribe(const PubKey& key, unsigned eid, DataHandler* h); - void _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f); + bool _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f); private: typedef std::vector SubList; @@ -161,6 +195,8 @@ private: static std::condition_variable queue_cv; static std::mutex queue_mutex; + std::unordered_map mp_pub_stats; + void start_worker_thread(); void stop_worker_thread(); void worker_thread_func(); diff --git a/src/framework/mp_transport.h b/src/framework/mp_transport.h index 4b2decb74..882fb7fd5 100644 --- a/src/framework/mp_transport.h +++ b/src/framework/mp_transport.h @@ -24,6 +24,7 @@ #include "framework/base_api.h" #include +#include namespace snort { @@ -36,6 +37,32 @@ struct MPHelperFunctions; typedef std::function TransportReceiveEventHandler; +enum MPTransportChannelStatus +{ + DISCONNECTED = 0, + CONNECTING, + CONNECTED, + MAX +}; + +struct MPTransportChannelStatusHandle +{ + int id = 0; + std::string name; + MPTransportChannelStatus status = DISCONNECTED; + + const char* get_status_string() const + { + switch (status) + { + case DISCONNECTED: return "DISCONNECTED"; + case CONNECTING: return "CONNECTING"; + case CONNECTED: return "CONNECTED"; + default: return "UNKNOWN"; + } + } +}; + class MPTransport { public: @@ -54,6 +81,7 @@ class MPTransport virtual void enable_logging() = 0; virtual void disable_logging() = 0; virtual bool is_logging_enabled() = 0; + virtual MPTransportChannelStatusHandle* get_channel_status(uint& size) = 0; }; diff --git a/src/framework/test/mp_data_bus_test.cc b/src/framework/test/mp_data_bus_test.cc index 658d0f894..0dad5aae1 100644 --- a/src/framework/test/mp_data_bus_test.cc +++ b/src/framework/test/mp_data_bus_test.cc @@ -26,6 +26,8 @@ #include "../main/snort_config.h" #include "utils/stats.h" #include "helpers/ring.h" +#include "main/snort.h" +#include "managers/module_manager.h" #include #include @@ -55,7 +57,30 @@ SnortConfig::SnortConfig(const SnortConfig* const, const char*) SnortConfig::~SnortConfig() { } + +void set_log_conn(ControlConn*) { } + +unsigned Snort::get_process_id() +{ + return 0; +} + +Module* ModuleManager::get_module(const char*) +{ + return nullptr; } +} + +void show_stats(PegCount*, const PegInfo*, unsigned, const char*) +{ + mock().actualCall("show_stats"); +} +void show_stats(PegCount*, const PegInfo*, const std::vector&, const char*, FILE*) { } +void show_stats(unsigned long*, PegInfo const*, char const*) {} + +bool ControlConn::respond(const char*, ...) { return true; } + +static bool test_transport_send_result = true; class MockMPTransport : public MPTransport { @@ -68,6 +93,11 @@ public: return count; } + static void reset_count() + { + count = 0; + } + static int get_test_register_helpers_calls() { return test_register_helpers_calls; @@ -76,7 +106,7 @@ public: bool send_to_transport(MPEventInfo&) override { count++; - return true; + return test_transport_send_result; } void register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&) override @@ -129,6 +159,13 @@ public: { return true; } + + MPTransportChannelStatusHandle* get_channel_status(uint& size) override + { + size = 0; + return nullptr; + } + private: inline static int count = 0; inline static int test_register_helpers_calls = 0; @@ -224,6 +261,7 @@ TEST_GROUP(mp_data_bus_pub) MPDataBus* mp_dbus = nullptr; void setup() override { + MockMPTransport::reset_count(); mp_dbus = new MPDataBus(); mp_dbus->init(2); pub_id1 = MPDataBus::get_id(pub_key1); @@ -247,9 +285,44 @@ TEST(mp_data_bus_pub, publish) std::this_thread::sleep_for(std::chrono::milliseconds(150)); + mp_dbus->sum_stats(); + CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_published); + CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_sent); + CHECK_EQUAL(0, MPDataBus::mp_global_stats.total_messages_dropped); + + mock().expectNCalls(2, "show_stats"); + + mp_dbus->dump_stats(nullptr, nullptr); + mp_dbus->dump_stats(nullptr, "mp_ut1"); + + mock().checkExpectations(); + + delete mp_dbus; + + CHECK_EQUAL(1, MockMPTransport::get_count()); +} + +TEST(mp_data_bus_pub, publish_fail_to_send) +{ + CHECK_TRUE(mp_dbus->get_event_queue()->empty()); + CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0); + + test_transport_send_result = false; + + std::shared_ptr event = std::make_shared(100); + + mp_dbus->publish(pub_id1, DbUtIds::EVENT1, event); + + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + + mp_dbus->sum_stats(); + CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_published); + CHECK_EQUAL(0, MPDataBus::mp_global_stats.total_messages_sent); + CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_dropped); + delete mp_dbus; - CHECK(1 == MockMPTransport::get_count()); + test_transport_send_result = true; } TEST_GROUP(mp_data_bus) @@ -331,6 +404,9 @@ TEST(mp_data_bus, subscribe_and_receive) MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1); SnortConfig::get_conf()->mp_dbus->receive_message(event_info1); + SnortConfig::get_conf()->mp_dbus->sum_stats(); + CHECK_EQUAL(2, MPDataBus::mp_global_stats.total_messages_received); + CHECK_EQUAL(200, h1->evt_msg); } diff --git a/src/helpers/ring.h b/src/helpers/ring.h index 914b78689..57f18d6b6 100644 --- a/src/helpers/ring.h +++ b/src/helpers/ring.h @@ -43,6 +43,8 @@ public: T get(T); bool put(T); + T* grab_store(int& ix); + int count(); bool full(); bool empty(); @@ -112,6 +114,16 @@ bool Ring::put(T v) return true; } +template +T* Ring::grab_store(int& ix) +{ + int i = logic.read(); + if ( i < 0 ) + return nullptr; + ix = i; + return store; +} + template int Ring::count() { diff --git a/src/main/modules.cc b/src/main/modules.cc index dbaa053fd..767f6e21b 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -47,6 +47,7 @@ #include "js_norm/js_norm_module.h" #include "latency/latency_module.h" #include "log/messages.h" +#include "lua/lua.h" #include "managers/module_manager.h" #include "managers/plugin_manager.h" #include "memory/memory_module.h" @@ -406,14 +407,64 @@ static const Parameter mp_data_bus_params[] = { "debug", Parameter::PT_BOOL, nullptr, "false", "enable debugging" }, +#ifdef REG_TEST + { "hold_events", Parameter::PT_BOOL, nullptr, "false", + "hold events from publishing" }, +#endif { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } }; +static int dump_mp_stats(lua_State* L) +{ + if (!L) + return 0; + + ControlConn* ctrlconn = ControlConn::query_from_lua(L); + + auto mod_name = luaL_optstring(L, 1, nullptr); + + if (mod_name and strlen(mod_name) == 0) + mod_name = nullptr; + + if(SnortConfig::get_conf()->mp_dbus) + SnortConfig::get_conf()->mp_dbus->dump_stats(ctrlconn, mod_name); + + return 0; +} + +static int dump_mp_events(lua_State* L) +{ + if (!L) + return 0; + + ControlConn* ctrlconn = ControlConn::query_from_lua(L); + + auto mod_name = luaL_optstring(L, 1, nullptr); + + if (mod_name and strlen(mod_name) == 0) + mod_name = nullptr; + + if(SnortConfig::get_conf()->mp_dbus) + SnortConfig::get_conf()->mp_dbus->dump_events(ctrlconn, mod_name); + + return 0; +} + +static int show_mp_channel_status(lua_State* L) +{ + ControlConn* ctrlconn = ControlConn::query_from_lua(L); + + if(SnortConfig::get_conf()->mp_dbus) + SnortConfig::get_conf()->mp_dbus->show_channel_status(ctrlconn); + + return 0; +} + static int enable_debug(lua_State*) { if(SnortConfig::get_conf()->mp_dbus) - SnortConfig::get_conf()->mp_dbus->enable_debug = true; + SnortConfig::get_conf()->mp_dbus->set_debug_enabled(true); return 0; } @@ -421,7 +472,7 @@ static int enable_debug(lua_State*) static int disable_debug(lua_State*) { if(SnortConfig::get_conf()->mp_dbus) - SnortConfig::get_conf()->mp_dbus->enable_debug = false; + SnortConfig::get_conf()->mp_dbus->set_debug_enabled(false); return 0; } @@ -430,6 +481,9 @@ static const Command mp_dbus_cmds[] = { {"enable", enable_debug, nullptr, "enable multiprocess data bus debugging"}, {"disable", disable_debug, nullptr, "disable multiprocess data bus debugging"}, + { "dump_stats", dump_mp_stats, nullptr, "dump multiprocess data bus statistics" }, + { "dump_events", dump_mp_events, nullptr, "dump multiprocess data bus events" }, + { "show_channel_status", show_mp_channel_status, nullptr, "show multiprocess data bus channel status" }, {nullptr, nullptr, nullptr, nullptr} }; @@ -446,6 +500,8 @@ public: bool begin(const char*, int, SnortConfig*) override; bool end(const char*, int, SnortConfig*) override; const Command* get_commands() const override; + const PegInfo* get_pegs() const override; + PegCount* get_counts() const override; Usage get_usage() const override { return GLOBAL; } @@ -475,6 +531,12 @@ bool MPDataBusModule::set(const char*, Value& v, SnortConfig*) { MPDataBus::enable_debug = v.get_bool(); } +#ifdef REG_TEST + else if ( v.is("hold_events") ) + { + MPDataBus::hold_events = v.get_bool(); + } +#endif else { WarningMessage("MPDataBus: Unknown parameter '%s' in mp_data_bus module\n", v.get_name()); @@ -488,6 +550,17 @@ const Command* MPDataBusModule::get_commands() const return mp_dbus_cmds; } +const PegInfo* MPDataBusModule::get_pegs() const +{ + return mp_databus_pegs; +} + +PegCount* MPDataBusModule::get_counts() const +{ + if(SnortConfig::get_conf()->mp_dbus) + SnortConfig::get_conf()->mp_dbus->sum_stats(); + return (PegCount*)&MPDataBus::mp_global_stats; +} //------------------------------------------------------------------------- // reference module //------------------------------------------------------------------------- diff --git a/src/main/snort.cc b/src/main/snort.cc index b735a25b6..8d18aa83d 100644 --- a/src/main/snort.cc +++ b/src/main/snort.cc @@ -174,7 +174,6 @@ void Snort::init(int argc, char** argv) if (sc->max_procs > 1) { sc->mp_dbus = new MPDataBus(); - sc->mp_dbus->init(sc->max_procs); } PluginManager::load_so_plugins(sc); @@ -195,6 +194,11 @@ void Snort::init(int argc, char** argv) ModuleManager::init_stats(); ModuleManager::reset_stats(sc); + if (sc->mp_dbus) + { + sc->mp_dbus->init(sc->max_procs); + } + if (sc->alert_before_pass()) sc->rule_order = IpsAction::get_default_priorities(true); diff --git a/src/main/snort_config.cc b/src/main/snort_config.cc index 5db70d17b..1a863cefd 100644 --- a/src/main/snort_config.cc +++ b/src/main/snort_config.cc @@ -63,6 +63,7 @@ #include "managers/mpse_manager.h" #include "managers/plugin_manager.h" #include "managers/so_manager.h" +#include "managers/mp_transport_manager.h" #include "memory/memory_config.h" #include "packet_io/sfdaq.h" #include "packet_io/sfdaq_config.h" @@ -1133,6 +1134,7 @@ void SnortConfig::cleanup_fatal_error() const SnortConfig* sc = SnortConfig::get_conf(); if ( sc && !sc->dirty_pig ) { + MPTransportManager::term(); ModuleManager::term(); EventManager::release_plugins(); IpsManager::release_plugins(); diff --git a/src/managers/test/mp_transport_manager_test.cc b/src/managers/test/mp_transport_manager_test.cc index 5438ac529..d4900d3ca 100644 --- a/src/managers/test/mp_transport_manager_test.cc +++ b/src/managers/test/mp_transport_manager_test.cc @@ -56,6 +56,11 @@ class MockTransport : public MPTransport { } virtual bool is_logging_enabled() override { return false; } + MPTransportChannelStatusHandle* get_channel_status(uint& size) override + { + size = 0; + return nullptr; + } }; unsigned get_instance_id() { return 0; } @@ -65,10 +70,8 @@ unsigned ThreadConfig::get_instance_max() { return 1; } using namespace snort; -void show_stats(unsigned long*, PegInfo const*, std::vector > const&, char const*, _IO_FILE*) -{} -void show_stats(unsigned long*, PegInfo const*, unsigned int, char const*) -{} +void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { } +void show_stats(PegCount*, const PegInfo*, const std::vector&, const char*, FILE*) { } static void mock_transport_tinit(MPTransport* t) { diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc index 8e40562f7..783c69aa4 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc @@ -25,6 +25,7 @@ #include #include +#include #include #include #include @@ -37,17 +38,11 @@ #include "main/snort_config.h" static std::mutex _receive_mutex; -static std::mutex _update_connectors_mutex; +static std::mutex _send_mutex; +static std::mutex _read_mutex; #define UNIX_SOCKET_NAME_PREFIX "/snort_unix_connector_" - -#define MP_TRANSPORT_LOG_LABEL "MPUnixTransport" - -#define MP_TRANSPORT_LOG(msg, ...) do { \ - if (!this->is_logging_enabled_flag) \ - break; \ - LogMessage(msg, __VA_ARGS__); \ - } while (0) +#define MP_TRANSPORT_LOG_LABEL "MPUnixTransportDbg" namespace snort { @@ -80,7 +75,7 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg) { if (msg->content_length < sizeof(MPTransportMessage)) { - MP_TRANSPORT_LOG("%s: Incomplete message received\n", MP_TRANSPORT_LOG_LABEL); + MPTransportLog("Incomplete message received\n"); return; } @@ -88,14 +83,14 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg) if (transport_message_header->type >= MAX_TYPE) { - MP_TRANSPORT_LOG("%s: Invalid message type received\n", MP_TRANSPORT_LOG_LABEL); + MPTransportLog("Invalid message type received\n"); return; } auto deserialize_func = get_event_deserialization_function(transport_message_header->pub_id, transport_message_header->event_id); if (!deserialize_func) { - MP_TRANSPORT_LOG("%s: No deserialization function found for event: type %d, id %d\n", MP_TRANSPORT_LOG_LABEL, transport_message_header->type, transport_message_header->event_id); + MPTransportLog("No deserialization function found for event: type %d, id %d\n", transport_message_header->type, transport_message_header->event_id); return; } @@ -104,29 +99,33 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg) MPEventInfo event(std::shared_ptr (internal_event), transport_message_header->event_id, transport_message_header->pub_id); (transport_receive_handler)(event); - + transport_stats.received_events++; + transport_stats.received_bytes += sizeof(MPTransportMessageHeader) + transport_message_header->data_length; } delete msg; } -void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg) +void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg, const ushort& channel_id) { assert(connector); assert(cfg); - std::lock_guard guard(_update_connectors_mutex); + std::lock_guard guard_send(_send_mutex); + std::lock_guard guard_read(_read_mutex); + + transport_stats.successful_connections++; auto side_channel = new SideChannel(ScMsgFormat::BINARY); side_channel->connector_receive = connector; side_channel->connector_transmit = side_channel->connector_receive; side_channel->register_receive_handler(std::bind(&MPUnixDomainTransport::side_channel_receive_handler, this, std::placeholders::_1)); connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); - this->side_channels.push_back(new SideChannelHandle(side_channel, cfg)); + this->side_channels.push_back(new SideChannelHandle(side_channel, cfg, channel_id)); connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); } -MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c) : MPTransport(), - config(c) +MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c, MPUnixTransportStats& stats) : MPTransport(), + config(c), transport_stats(stats) { this->is_logging_enabled_flag = c->enable_logging; } @@ -142,7 +141,8 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event) if (!serialize_func) { - MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event.type); + transport_stats.send_errors++; + MPTransportLog("No serialize function found for event %d\n", event.type); return false; } @@ -153,15 +153,25 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event) (serialize_func)(event.event.get(), transport_message.data, &transport_message.header.data_length); - for (auto &&sc_handler : this->side_channels) { - auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length); - memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader)); - memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length); - auto send_result = sc_handler->side_channel->transmit_message(msg); - if (!send_result) + std::lock_guard guard(_send_mutex); + + for (auto &&sc_handler : this->side_channels) { - MP_TRANSPORT_LOG("%s: Failed to send message to side channel\n", MP_TRANSPORT_LOG_LABEL); + auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length); + memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader)); + memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length); + auto send_result = sc_handler->side_channel->transmit_message(msg); + if (!send_result) + { + MPTransportLog("Failed to send message to side channel\n"); + transport_stats.send_errors++; + } + else + { + transport_stats.sent_events++; + transport_stats.sent_bytes += sizeof(MPTransportMessageHeader) + transport_message.header.data_length; + } } } @@ -176,7 +186,7 @@ void MPUnixDomainTransport::register_event_helpers(const unsigned& pub_id, const assert(helper.serializer); this->event_helpers[pub_id] = SerializeFunctionHandle(); - this->event_helpers[pub_id].serialize_functions.insert({event_id, helper}); + this->event_helpers[pub_id].serialize_functions.insert({event_id, std::move(helper)}); } void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler& handler) @@ -201,7 +211,7 @@ void MPUnixDomainTransport::process_messages_from_side_channels() } { - std::lock_guard guard(_update_connectors_mutex); + std::lock_guard guard(_read_mutex); bool messages_left; do @@ -227,7 +237,8 @@ void MPUnixDomainTransport::notify_process_thread() void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel) { - std::lock_guard guard(_update_connectors_mutex); + std::lock_guard guard_send(_send_mutex); + std::lock_guard guard_read(_read_mutex); if (side_channel->connector_receive) { delete side_channel->connector_receive; @@ -236,13 +247,15 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec if (connector) { + connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); side_channel->connector_receive = side_channel->connector_transmit = connector; + this->transport_stats.successful_connections++; } else { if (is_recconecting == false) { - MP_TRANSPORT_LOG("%s: Accepted connection interrupted, removing handle\n", MP_TRANSPORT_LOG_LABEL); + MPTransportLog("Accepted connection interrupted, removing handle\n"); for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it) { if ((*it)->side_channel == side_channel) @@ -252,22 +265,41 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec break; } } + this->transport_stats.closed_connections++; + } + else + { + this->transport_stats.connection_retries++; } } } +void MPUnixDomainTransport::MPTransportLog(const char *msg, ...) +{ + if (!is_logging_enabled_flag) + return; + + char buf[256]; + va_list args; + va_start(args, msg); + vsnprintf(buf, sizeof(buf), msg, args); + va_end(args); + + LogMessage("%s ID=%d %s", MP_TRANSPORT_LOG_LABEL, mp_current_process_id, buf); +} + MPSerializeFunc MPUnixDomainTransport::get_event_serialization_function(unsigned pub_id, unsigned event_id) { auto helper_it = this->event_helpers.find(pub_id); if (helper_it == this->event_helpers.end()) { - MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id); + MPTransportLog("%s: No available helper functions is registered for %d\n", pub_id); return nullptr; } auto helper_functions = helper_it->second.get_function_set(event_id); if (!helper_functions) { - MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id); + MPTransportLog("%s: No serialize function found for event %d\n", event_id); return nullptr; } return helper_functions->serializer; @@ -278,13 +310,13 @@ MPDeserializeFunc MPUnixDomainTransport::get_event_deserialization_function(unsi auto helper_it = this->event_helpers.find(pub_id); if (helper_it == this->event_helpers.end()) { - MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id); + MPTransportLog("No available helper functions is registered for %d\n", pub_id); return nullptr; } auto helper_functions = helper_it->second.get_function_set(event_id); if (!helper_functions) { - MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id); + MPTransportLog("No serialize function found for event %d\n", event_id); return nullptr; } return helper_functions->deserializer; @@ -337,14 +369,26 @@ void MPUnixDomainTransport::init_side_channels() if (config->max_processes < 2) return; - auto instance_id = Snort::get_process_id();//Snort instance id + auto instance_id = mp_current_process_id = Snort::get_process_id();//Snort instance id auto max_processes = config->max_processes; this->is_running = true; + if ( std::filesystem::is_directory(config->unix_domain_socket_path) == false ) + { + std::error_code ec; + std::filesystem::create_directories(config->unix_domain_socket_path, ec); + if (ec) + { + MPTransportLog("Failed to create directory %s\n", config->unix_domain_socket_path.c_str()); + return; + } + } + for (ushort i = instance_id; i < max_processes; i++) { auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i); + auto unix_listener = new UnixDomainConnectorListener(listen_path.c_str()); UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig(); @@ -366,7 +410,7 @@ void MPUnixDomainTransport::init_side_channels() } unix_config->paths.push_back(listen_path); - unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2), unix_config); + unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2, instance_id + i), unix_config); auto unix_listener_handle = new UnixAcceptorHandle(); unix_listener_handle->connector_config = unix_config; @@ -379,7 +423,7 @@ void MPUnixDomainTransport::init_side_channels() auto side_channel = new SideChannel(ScMsgFormat::BINARY); side_channel->register_receive_handler([this](SCMessage* msg) { this->side_channel_receive_handler(msg); }); - auto send_path = config->unix_domain_socket_path + "/" + "snort_unix_connector_" + std::to_string(i); + auto send_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i); UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig(); connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL; @@ -390,26 +434,22 @@ void MPUnixDomainTransport::init_side_channels() connector_conf->connect_timeout_seconds = config->connect_timeout_seconds; connector_conf->paths.push_back(send_path); - auto connector = unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); - - if (connector) - connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); - - side_channel->connector_receive = connector; - side_channel->connector_transmit = side_channel->connector_receive; - this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf)); + unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); + + this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf, i)); } this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this); } + void MPUnixDomainTransport::cleanup_side_channels() { - std::lock_guard guard(_update_connectors_mutex); + std::lock_guard guard_send(_send_mutex); + std::lock_guard guard_read(_read_mutex); for (uint i = 0; i < this->side_channels.size(); i++) { - auto side_channel = this->side_channels[i]; - delete side_channel; + delete this->side_channels[i]; } this->side_channels.clear(); @@ -428,6 +468,7 @@ SideChannelHandle::~SideChannelHandle() if (connector_config) delete connector_config; } + void MPUnixDomainTransport::enable_logging() { this->is_logging_enabled_flag = true; @@ -443,4 +484,29 @@ bool MPUnixDomainTransport::is_logging_enabled() return this->is_logging_enabled_flag; } -}; +MPTransportChannelStatusHandle *MPUnixDomainTransport::get_channel_status(uint &size) +{ + std::lock_guard guard_send(_send_mutex); + std::lock_guard guard_read(_read_mutex); + if (this->side_channels.size() == 0) + { + size = 0; + return nullptr; + } + MPTransportChannelStatusHandle* result = new MPTransportChannelStatusHandle[this->side_channels.size()]; + + size = this->side_channels.size(); + uint it = 0; + + for (auto &&sc_handler : this->side_channels) + { + result[it].id = sc_handler->channel_id; + result[it].status = sc_handler->side_channel->connector_receive ? MPTransportChannelStatus::CONNECTED : MPTransportChannelStatus::CONNECTING; + result[it].name = "Snort connection to " + std::to_string(sc_handler->channel_id) + " instance"; + it++; + } + + return result; +} + +} diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.h b/src/mp_transport/mp_unix_transport/mp_unix_transport.h index 8208e030e..0c23eb087 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport.h +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.h @@ -32,6 +32,29 @@ namespace snort { +struct MPUnixTransportStats +{ + MPUnixTransportStats() : + sent_events(0), + sent_bytes(0), + received_events(0), + received_bytes(0), + send_errors(0), + successful_connections(0), + closed_connections(0), + connection_retries(0) + { } + + PegCount sent_events; + PegCount sent_bytes; + PegCount received_events; + PegCount received_bytes; + PegCount send_errors; + PegCount successful_connections; + PegCount closed_connections; + PegCount connection_retries; +}; + struct MPUnixDomainTransportConfig { std::string unix_domain_socket_path; @@ -60,14 +83,15 @@ struct SerializeFunctionHandle struct SideChannelHandle { - SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc) : - side_channel(sc), connector_config(cc) + SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc, const ushort& channel_id) : + side_channel(sc), connector_config(cc), channel_id(channel_id) { } ~SideChannelHandle(); SideChannel* side_channel; UnixDomainConnectorConfig* connector_config; + ushort channel_id; }; struct UnixAcceptorHandle @@ -80,7 +104,7 @@ class MPUnixDomainTransport : public MPTransport { public: - MPUnixDomainTransport(MPUnixDomainTransportConfig* c); + MPUnixDomainTransport(MPUnixDomainTransportConfig* c, MPUnixTransportStats& stats); ~MPUnixDomainTransport() override; bool configure(const SnortConfig*) override; @@ -95,6 +119,7 @@ class MPUnixDomainTransport : public MPTransport void disable_logging() override; bool is_logging_enabled() override; void cleanup(); + MPTransportChannelStatusHandle* get_channel_status(uint& size) override; MPUnixDomainTransportConfig* get_config() { return config; } @@ -105,14 +130,17 @@ class MPUnixDomainTransport : public MPTransport void init_side_channels(); void cleanup_side_channels(); void side_channel_receive_handler(SCMessage* msg); - void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg); + void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg, const ushort& channel_id); void process_messages_from_side_channels(); void notify_process_thread(); void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel); + void MPTransportLog(const char* msg, ...); MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id); MPDeserializeFunc get_event_deserialization_function(unsigned pub_id, unsigned event_id); + uint mp_current_process_id = 0; + TransportReceiveEventHandler transport_receive_handler = nullptr; MPUnixDomainTransportConfig* config = nullptr; @@ -126,6 +154,8 @@ class MPUnixDomainTransport : public MPTransport std::thread* consume_thread = nullptr; std::condition_variable consume_thread_cv; + + MPUnixTransportStats& transport_stats; }; } diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc index b0c8c5f05..f01109f20 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc @@ -25,6 +25,7 @@ #include "main/snort_config.h" #include "log/messages.h" +#include "utils/stats.h" #define DEFAULT_UNIX_DOMAIN_SOCKET_PATH "/tmp/snort_unix_connectors" @@ -42,6 +43,19 @@ static const Parameter unix_transport_params[] = { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } }; +static const PegInfo mp_unix_transport_pegs[] = +{ + { CountType::SUM, "sent_events", "mp_transport events sent count" }, + { CountType::SUM, "sent_bytes", "mp_transport events bytes sent count" }, + { CountType::SUM, "receive_events", "mp_transport events received count" }, + { CountType::SUM, "receive_bytes", "mp_transport events bytes received count" }, + { CountType::SUM, "sent_errors", "mp_transport events errors count" }, + { CountType::SUM, "successful_connections", "successful mp_transport connections count" }, + { CountType::SUM, "closed_connections", "closed mp_transport connections count" }, + { CountType::SUM, "connection_retries", "mp_transport connection retries count" }, + { CountType::END, nullptr, nullptr }, +}; + MPUnixDomainTransportModule::MPUnixDomainTransportModule(): Module(MODULE_NAME, MODULE_HELP, unix_transport_params) { config = nullptr; @@ -96,6 +110,16 @@ bool MPUnixDomainTransportModule::set(const char *, Value & v, SnortConfig *) return true; } +const PegInfo *MPUnixDomainTransportModule::get_pegs() const +{ + return mp_unix_transport_pegs; +} + +PegCount *MPUnixDomainTransportModule::get_counts() const +{ + return (PegCount*)&unix_transport_stats; +} + static struct MPTransportApi mp_unixdomain_transport_api = { { diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h index 464b4f3e8..fae5bbc17 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h @@ -42,10 +42,14 @@ class MPUnixDomainTransportModule : public Module bool begin(const char*, int, SnortConfig*) override; bool set(const char*, Value&, SnortConfig*) override; + const PegInfo* get_pegs() const override; + PegCount* get_counts() const override; + Usage get_usage() const override { return GLOBAL; } MPUnixDomainTransportConfig* config; + MPUnixTransportStats unix_transport_stats; }; static Module* mod_ctor() @@ -61,7 +65,7 @@ static void mod_dtor(Module* m) static MPTransport* mp_unixdomain_transport_ctor(Module* m) { auto unix_tr_mod = (MPUnixDomainTransportModule*)m; - return new MPUnixDomainTransport(unix_tr_mod->config); + return new MPUnixDomainTransport(unix_tr_mod->config, unix_tr_mod->unix_transport_stats); } static void mp_unixdomain_transport_dtor(MPTransport* t) diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc index c00ff41a6..9f26f4d35 100644 --- a/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc @@ -42,11 +42,9 @@ namespace snort return 1; } - MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config) : - MPTransport() - { - this->config = config; - } + MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config, MPUnixTransportStats& stats) : + MPTransport(), config(config), transport_stats(stats) + { } MPUnixDomainTransport::~MPUnixDomainTransport() { destroy_cnt++; } void MPUnixDomainTransport::thread_init() @@ -73,6 +71,11 @@ namespace snort {} void MPUnixDomainTransport::disable_logging() {} + MPTransportChannelStatusHandle* MPUnixDomainTransport::get_channel_status(unsigned int& size) + { + size = 0; + return nullptr; + } char* snort_strdup(const char*) { diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc index 84cf84979..bad3074ac 100644 --- a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc @@ -204,6 +204,7 @@ UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorCo test_call_sock_created++; auto new_conn = new UnixDomainConnector(cfg, 0, idx); call_connector = new_conn; + update_handler(new_conn, false); return new_conn; } assert(false); @@ -269,6 +270,7 @@ bool deserialize_mock(const char* buffer, uint16_t length, DataEvent*& event) MPHelperFunctions mp_helper_functions_mock(serialize_mock, deserialize_mock); static MPUnixDomainTransportConfig test_config; +static MPUnixTransportStats test_stats; static MPUnixDomainTransport* test_transport = nullptr; static SnortConfig test_snort_config(nullptr, nullptr); @@ -278,7 +280,7 @@ TEST_GROUP(unix_transport_test_connectivity_group) void setup() override { test_snort_config.max_procs = 2; - test_transport = new MPUnixDomainTransport(&test_config); + test_transport = new MPUnixDomainTransport(&test_config, test_stats); test_transport->configure(&test_snort_config); } @@ -294,8 +296,8 @@ static MPUnixDomainTransportConfig test_config_message; static MPTransport* test_transport_message_1 = nullptr; static MPTransport* test_transport_message_2 = nullptr; -static int reciveved_1_msg_cnt = 0; -static int reciveved_2_msg_cnt = 0; +static int received_1_msg_cnt = 0; +static int received_2_msg_cnt = 0; TEST_GROUP(unix_transport_test_messaging) { @@ -305,31 +307,31 @@ TEST_GROUP(unix_transport_test_messaging) accept_cnt = 1; - test_config_message.unix_domain_socket_path = "/tmp"; + test_config_message.unix_domain_socket_path = "."; test_config_message.max_processes = 2; test_config_message.conn_retries = false; test_config_message.retry_interval_seconds = 0; test_config_message.max_retries = 0; test_config_message.connect_timeout_seconds = 30; - test_transport_message_1 = new MPUnixDomainTransport(&test_config_message); + test_transport_message_1 = new MPUnixDomainTransport(&test_config_message, test_stats); snort_instance_id = 1; test_transport_message_1->configure(&test_snort_config); test_transport_message_1->init_connection(); test_transport_message_1->register_receive_handler([](const snort::MPEventInfo& e) { - reciveved_1_msg_cnt++; + received_1_msg_cnt++; }); std::this_thread::sleep_for(std::chrono::milliseconds(100)); - test_transport_message_2 = new MPUnixDomainTransport(&test_config_message); + test_transport_message_2 = new MPUnixDomainTransport(&test_config_message, test_stats); snort_instance_id = 2; test_transport_message_2->configure(&test_snort_config); test_transport_message_2->init_connection(); test_transport_message_2->register_receive_handler([](const snort::MPEventInfo& e) { - reciveved_2_msg_cnt++; + received_2_msg_cnt++; }); } @@ -369,7 +371,7 @@ TEST(unix_transport_test_connectivity_group, set_logging_enabled_disabled) TEST(unix_transport_test_connectivity_group, init_connection_single_snort_instance) { clear_test_calls(); - test_config.unix_domain_socket_path = "/tmp"; + test_config.unix_domain_socket_path = "."; test_config.max_processes = 1; test_transport->init_connection(); @@ -389,7 +391,7 @@ TEST(unix_transport_test_connectivity_group, init_connection_first_snort_instanc clear_test_calls(); snort_instance_id = 1; - test_config.unix_domain_socket_path = "/tmp"; + test_config.unix_domain_socket_path = "."; test_config.max_processes = 2; accept_cnt = 1; @@ -406,7 +408,7 @@ TEST(unix_transport_test_connectivity_group, init_connection_second_snort_instan { clear_test_calls(); snort_instance_id = 2; - test_config.unix_domain_socket_path = "/tmp"; + test_config.unix_domain_socket_path = "."; test_config.max_processes = 2; test_transport->init_connection(); @@ -425,7 +427,7 @@ TEST(unix_transport_test_connectivity_group, connector_update_handler_call) { clear_test_calls(); - test_config.unix_domain_socket_path = "/tmp"; + test_config.unix_domain_socket_path = "."; test_config.max_processes = 2; accept_cnt = 1; @@ -470,10 +472,10 @@ TEST(unix_transport_test_messaging, send_to_transport_biderectional) std::this_thread::sleep_for(std::chrono::milliseconds(500)); - CHECK(test_deserialize_calls == 2); - CHECK(reciveved_1_msg_cnt == 1); - CHECK(reciveved_2_msg_cnt == 1); - CHECK(test_send_calls == 2); + CHECK_EQUAL(2, test_deserialize_calls); + CHECK_EQUAL(1 ,received_1_msg_cnt); + CHECK_EQUAL(1, received_2_msg_cnt); + CHECK_EQUAL(2, test_send_calls); }; TEST(unix_transport_test_messaging, send_to_transport_no_helpers) @@ -488,16 +490,16 @@ TEST(unix_transport_test_messaging, send_to_transport_no_helpers) CHECK(res == false); CHECK(test_serialize_calls == 0); CHECK(test_deserialize_calls == 0); - CHECK(reciveved_1_msg_cnt == 0); - CHECK(reciveved_2_msg_cnt == 0); + CHECK(received_1_msg_cnt == 0); + CHECK(received_2_msg_cnt == 0); CHECK(test_send_calls == 0); res = test_transport_message_2->send_to_transport(event); CHECK(res == false); CHECK(test_serialize_calls == 0); CHECK(test_deserialize_calls == 0); - CHECK(reciveved_1_msg_cnt == 0); - CHECK(reciveved_2_msg_cnt == 0); + CHECK(received_1_msg_cnt == 0); + CHECK(received_2_msg_cnt == 0); CHECK(test_send_calls == 0); } diff --git a/src/side_channel/side_channel.cc b/src/side_channel/side_channel.cc index 3fe5a0f80..b0fb87049 100644 --- a/src/side_channel/side_channel.cc +++ b/src/side_channel/side_channel.cc @@ -300,9 +300,15 @@ bool SideChannel::discard_message(SCMessage* msg) const bool SideChannel::transmit_message(SCMessage* msg) const { - if ( !connector_transmit or !msg ) + if(!msg) return false; + if ( !connector_transmit) + { + delete msg; + return false; + } + if ( msg_format == ScMsgFormat::TEXT ) { std::string text = sc_msg_data_to_text(msg->content, msg->content_length);