]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4757: mp_unix_transport: added reset stats handling
authorOleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) <ostepano@cisco.com>
Tue, 3 Jun 2025 20:49:54 +0000 (20:49 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Tue, 3 Jun 2025 20:49:54 +0000 (20:49 +0000)
Merge in SNORT/snort3 from ~OSTEPANO/snort3:transport_opt to master

Squashed commit of the following:

commit 85abeddb909fee7f7107f6ff049004c5713840d6
Author: Oleksandr Stepanov <ostepano@cisco.com>
Date:   Mon May 12 05:50:44 2025 -0400

    mp_unix_transport: use shared mutex in message processing

13 files changed:
src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc
src/connectors/unixdomain_connector/unixdomain_connector.cc
src/framework/mp_data_bus.cc
src/framework/mp_data_bus.h
src/framework/mp_transport.h
src/main.cc
src/main/analyzer_command.h
src/main/modules.cc
src/mp_transport/mp_unix_transport/mp_unix_transport.cc
src/mp_transport/mp_unix_transport/mp_unix_transport.h
src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc
src/mp_transport/mp_unix_transport/mp_unix_transport_module.h
src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc

index f9d5c2d17d1560a70395e280224960ad8108d686..81f466349309572126a137eb675bbf104769ba2f 100644 (file)
@@ -888,7 +888,9 @@ TEST(unixdomain_connector_reconnect_helper, connect_then_reconnect_call)
     auto tmp_test_connector = test_reconnect_connector;
 
     //trigger the reconnect
-    test_reconnect_connector->process_receive();
+    test_reconnect_connector->start_receive_thread();
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
     //collapse the reconnect_helper joining reconnect thread
     delete reconnect_helper;
index 97bb769d20ffca78edf17efb764acc13bbd709fb..ab276905997927133d94527d5a78967f2ee94dff 100644 (file)
@@ -295,6 +295,13 @@ void UnixDomainConnector::process_receive() {
     } 
     else if ((pfds[0].revents & (POLLHUP | POLLERR | POLLNVAL)) != 0) 
     {
+        if (run_thread.load() == false)
+        {
+            close(sock_fd);
+            sock_fd = -1;
+            return;
+        }
+
         ErrorMessage("UnixDomainC Input Thread: Undesirable return event while polling on socket %d: 0x%x\n",
                 pfds[0].fd, pfds[0].revents);
 
@@ -349,8 +356,8 @@ void UnixDomainConnector::start_receive_thread() {
 
 void UnixDomainConnector::stop_receive_thread() {
 
+    run_thread.store(false, std::memory_order_relaxed);
     if (receive_thread != nullptr) {
-        run_thread.store(false, std::memory_order_relaxed);
         if (receive_thread->joinable()) {
             receive_thread->join();
         }
index 0b9f2f37ff035ae9a3e507a21eb58e056d397436..ed907dd0f8c58bd4ff9f274bd0e9a8d958f6dccf 100644 (file)
@@ -375,6 +375,32 @@ void MPDataBus::sum_stats()
     }
 }
 
+void MPDataBus::reset_stats()
+{
+    std::lock_guard<std::mutex> lock(mp_stats_mutex);
+
+    for(auto& [_, pub_stats] : mp_pub_stats)
+    {
+        pub_stats.total_messages_sent = 0;
+        pub_stats.total_messages_received = 0;
+        pub_stats.total_messages_dropped = 0;
+        pub_stats.total_messages_published = 0;
+        pub_stats.total_messages_delivered = 0;
+    }
+    
+    mp_global_stats.total_messages_sent = 0;
+    mp_global_stats.total_messages_received = 0;
+    mp_global_stats.total_messages_dropped = 0;
+    mp_global_stats.total_messages_published = 0;
+    mp_global_stats.total_messages_delivered = 0;
+
+    auto transport_module = ModuleManager::get_module(transport.c_str());
+    if (transport_module)
+    {
+        transport_module->reset_stats();
+    }
+}
+
 void MPDataBus::dump_stats(ControlConn *ctrlconn, const char *module_name)
 {
     set_log_conn(ctrlconn);
index bdfcfb68faa409401b0b4cc6fb30d08636e3b3a7..d72a0101fd25b2a2f8d5976d1cc99dcc135f5a6d 100644 (file)
@@ -173,6 +173,7 @@ public:
     void set_debug_enabled(bool flag);
 
     void sum_stats();
+    void reset_stats();
 
     void dump_stats(ControlConn* ctrlconn, const char* module_name);
     void dump_events(ControlConn* ctrlconn, const char* module_name);
index 5e9cfb941e65fdb6805000df3050ec37d224880a..0ff1baa4399dd9b2b0d96a5eefd4c62b24a5bef1 100644 (file)
@@ -40,6 +40,7 @@ typedef std::function<void (const MPEventInfo& event_info)> TransportReceiveEven
 enum MPTransportChannelStatus
 {
     DISCONNECTED = 0,
+    ACCEPTING,
     CONNECTING,
     CONNECTED,
     MAX
@@ -58,6 +59,7 @@ struct MPTransportChannelStatusHandle
             case DISCONNECTED: return "DISCONNECTED";
             case CONNECTING: return "CONNECTING";
             case CONNECTED: return "CONNECTED";
+            case ACCEPTING: return "ACCEPTING";
             default: return "UNKNOWN";
         }
     }
index 4e5e0e91c81544ec8b7952b66e8b8dbef428e055..252d563e78b544cdabcba64f7c70788e5951035e 100644 (file)
@@ -96,6 +96,7 @@ static const std::map<std::string, clear_counter_type_t> counter_name_to_id =
        {"file_id", clear_counter_type_t::TYPE_FILE_ID},
        {"snort", clear_counter_type_t::TYPE_SNORT},
        {"ha", clear_counter_type_t::TYPE_HA},
+    {"messaging", clear_counter_type_t::TYPE_MESSAGING},
        {"all", clear_counter_type_t::TYPE_ALL}
 };
 
index df5c7ab261af3ad58e5c880f01d75da47d9dc51f..03e670e8f46a5d52b8a5e8b630cd1cc01dd17b1f 100644 (file)
@@ -85,6 +85,7 @@ typedef enum clear_counter_type
     TYPE_FILE_ID,
     TYPE_SNORT,
     TYPE_HA,
+    TYPE_MESSAGING,
        TYPE_ALL
 } clear_counter_type_t;
 
@@ -98,7 +99,8 @@ static std::vector<const char*> clear_counter_type_string_map
     "appid",
     "file_id",
     "snort",
-    "high_availability"
+    "high_availability",
+    "mp_data_bus"
 };
 
 class ACResetStats : public snort::AnalyzerCommand
index 767f6e21b72df970a6de974d0e1477de6444c137..ed9da58e11e756f009615c5dc056842b8f3aaa4b 100644 (file)
@@ -502,6 +502,7 @@ public:
     const Command* get_commands() const override;
     const PegInfo* get_pegs() const override;
     PegCount* get_counts() const override;
+    void reset_stats() override;
 
     Usage get_usage() const override
     { return GLOBAL; }
@@ -561,6 +562,12 @@ PegCount* MPDataBusModule::get_counts() const
         SnortConfig::get_conf()->mp_dbus->sum_stats();
     return (PegCount*)&MPDataBus::mp_global_stats;
 }
+
+void MPDataBusModule::reset_stats()
+{
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->reset_stats();
+}
 //-------------------------------------------------------------------------
 // reference module
 //-------------------------------------------------------------------------
index c9c8354ef9b2c0e11d78cd071aadee089464e1e8..23390435720e7a422320055f90badc3dbdb4ace8 100644 (file)
@@ -31,6 +31,7 @@
 #include <sys/un.h>
 #include <sys/stat.h>
 #include <unistd.h>
+#include <shared_mutex>
 
 #include "framework/mp_data_bus.h"
 #include "log/messages.h"
@@ -38,8 +39,7 @@
 #include "main/snort_config.h"
 
 static std::mutex _receive_mutex;
-static std::mutex _send_mutex;
-static std::mutex _read_mutex;
+static std::shared_mutex _connection_update_mutex;
 
 #define UNIX_SOCKET_NAME_PREFIX "/snort_unix_connector_"
 #define MP_TRANSPORT_LOG_LABEL "MPUnixTransportDbg"
@@ -99,8 +99,8 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
         MPEventInfo event(std::shared_ptr<DataEvent> (internal_event), transport_message_header->event_id, transport_message_header->pub_id);
 
         (transport_receive_handler)(event);
-        transport_stats.received_events++;
-        transport_stats.received_bytes += sizeof(MPTransportMessageHeader) + transport_message_header->data_length;
+        dynamic_transport_stats.received_events++;
+        dynamic_transport_stats.received_bytes += sizeof(MPTransportMessageHeader) + transport_message_header->data_length;
     }
     delete msg;
 }
@@ -110,20 +110,19 @@ void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector
     assert(connector);
     assert(cfg);
 
-    std::lock_guard<std::mutex> guard_send(_send_mutex);
-    std::lock_guard<std::mutex> guard_read(_read_mutex);
+    std::lock_guard<std::shared_mutex> guard(_connection_update_mutex);
 
     if(!this->is_running.load())
         return;
 
-    transport_stats.successful_connections++;
+    dynamic_transport_stats.successful_connections++;
 
     auto side_channel = new SideChannel(ScMsgFormat::BINARY);
     side_channel->connector_receive = connector;
     side_channel->connector_transmit = side_channel->connector_receive;
     side_channel->register_receive_handler(std::bind(&MPUnixDomainTransport::side_channel_receive_handler, this, std::placeholders::_1));
     connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
-    this->side_channels.push_back(new SideChannelHandle(side_channel, cfg, channel_id));
+    this->side_channels.push_back(new SideChannelHandle(side_channel, cfg, channel_id+this->side_channels.size()-mp_current_process_id+1));
     connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
     connector->start_receive_thread();
 }
@@ -145,7 +144,7 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
 
     if (!serialize_func)
     {
-        transport_stats.send_errors++;
+        dynamic_transport_stats.send_errors++;
         MPTransportLog("No serialize function found for event %d\n", event.type);
         return false;
     }
@@ -158,7 +157,7 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
 
     (serialize_func)(event.event.get(), transport_message.data, &transport_message.header.data_length);
     {
-        std::lock_guard<std::mutex> guard(_send_mutex);
+        std::shared_lock<std::shared_mutex> guard(_connection_update_mutex);
 
         for (auto &&sc_handler : this->side_channels)
         {
@@ -169,12 +168,12 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
             if (!send_result)
             {
                 MPTransportLog("Failed to send message to side channel\n");
-                transport_stats.send_errors++;
+                dynamic_transport_stats.send_errors++;
             }
             else
             {
-                transport_stats.sent_events++;
-                transport_stats.sent_bytes += sizeof(MPTransportMessageHeader) + transport_message.header.data_length;
+                dynamic_transport_stats.sent_events++;
+                dynamic_transport_stats.sent_bytes += sizeof(MPTransportMessageHeader) + transport_message.header.data_length;
             }
         }
     }
@@ -215,7 +214,7 @@ void MPUnixDomainTransport::process_messages_from_side_channels()
         }
 
         {
-            std::lock_guard<std::mutex> guard(_read_mutex);
+            std::shared_lock<std::shared_mutex> guard(_connection_update_mutex);
             bool messages_left;
 
             do
@@ -241,12 +240,10 @@ void MPUnixDomainTransport::notify_process_thread()
 
 void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_reconecting, SideChannel *side_channel)
 {
-    std::lock_guard<std::mutex> guard_send(_send_mutex);
-    std::lock_guard<std::mutex> guard_read(_read_mutex);
-
-    if(!this->is_running.load())
+    if (this->is_running == false)
         return;
-
+    std::lock_guard<std::shared_mutex> guard(_connection_update_mutex);
+    
     if (side_channel->connector_receive)
     {
         delete side_channel->connector_receive;
@@ -258,7 +255,7 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec
         connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
         side_channel->connector_receive = side_channel->connector_transmit = connector;
         connector->start_receive_thread();
-        this->transport_stats.successful_connections++;
+        this->dynamic_transport_stats.successful_connections++;
     }
     else
     {
@@ -274,11 +271,11 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec
                     break;
                 }
             }
-            this->transport_stats.closed_connections++;
+            this->dynamic_transport_stats.closed_connections++;
         }
         else
         {
-            this->transport_stats.connection_retries++;
+            this->dynamic_transport_stats.connection_retries++;
         }
     }
 }
@@ -302,13 +299,13 @@ MPSerializeFunc MPUnixDomainTransport::get_event_serialization_function(unsigned
     auto helper_it = this->event_helpers.find(pub_id);
     if (helper_it == this->event_helpers.end())
     {
-        MPTransportLog("%s: No available helper functions is registered for %d\n", pub_id);
+        MPTransportLog("No available helper functions is registered for %d\n", pub_id);
         return nullptr;
     }
     auto helper_functions = helper_it->second.get_function_set(event_id);
     if (!helper_functions)
     {
-        MPTransportLog("%s: No serialize function found for event %d\n", event_id);
+        MPTransportLog("No serialize function found for event %d\n", event_id);
         return nullptr;
     }
     return helper_functions->serializer;
@@ -382,7 +379,6 @@ void MPUnixDomainTransport::init_side_channels()
         return;
 
     auto instance_id = mp_current_process_id = Snort::get_process_id();//Snort instance id
-    auto max_processes = config->max_processes;
 
     this->is_running = true;
 
@@ -396,9 +392,9 @@ void MPUnixDomainTransport::init_side_channels()
         }
     }
 
-    for (unsigned short i = instance_id; i < max_processes; i++)
+    if (instance_id < config->max_processes)
     {
-        auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i);
+        auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(instance_id);
 
         auto unix_listener = new UnixDomainConnectorListener(listen_path.c_str());
         
@@ -421,8 +417,8 @@ void MPUnixDomainTransport::init_side_channels()
         }
         unix_config->paths.push_back(listen_path);
 
-        unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2, instance_id + i), unix_config);
-        
+        unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2, instance_id+1), unix_config);
+
         auto unix_listener_handle = new UnixAcceptorHandle();
         unix_listener_handle->connector_config = unix_config;
         unix_listener_handle->listener = unix_listener;
@@ -459,8 +455,7 @@ void MPUnixDomainTransport::init_side_channels()
 
 void MPUnixDomainTransport::cleanup_side_channels()
 {
-    std::lock_guard<std::mutex> guard_send(_send_mutex);
-    std::lock_guard<std::mutex> guard_read(_read_mutex);
+    std::lock_guard<std::shared_mutex> guard(_connection_update_mutex);
 
     for (uint32_t i = 0; i < this->side_channels.size(); i++)
     {
@@ -505,26 +500,41 @@ bool MPUnixDomainTransport::is_logging_enabled()
     return this->is_logging_enabled_flag;
 }
 
+void MPUnixDomainTransport::sum_stats()
+{
+    std::lock_guard<std::shared_mutex> _guard(_connection_update_mutex);
+    this->transport_stats = this->dynamic_transport_stats;
+}
+
+void MPUnixDomainTransport::reset_stats()
+{
+    std::lock_guard<std::shared_mutex> _guard(_connection_update_mutex);
+    this->dynamic_transport_stats = MPUnixTransportStats();
+    this->transport_stats = MPUnixTransportStats();
+}
+
 MPTransportChannelStatusHandle *MPUnixDomainTransport::get_channel_status(unsigned& size)
 {
-    std::lock_guard<std::mutex> guard_send(_send_mutex);
-    std::lock_guard<std::mutex> guard_read(_read_mutex);
-    if (this->side_channels.size() == 0)
-    {
-        size = 0;
-        return nullptr;
-    }
-    MPTransportChannelStatusHandle* result = new MPTransportChannelStatusHandle[this->side_channels.size()];
+    std::shared_lock<std::shared_mutex> _guard(_connection_update_mutex);
+    MPTransportChannelStatusHandle* result = new MPTransportChannelStatusHandle[this->config->max_processes-1];
 
-    size = this->side_channels.size();
-    unsigned int it = 0;
+    size = this->config->max_processes-1;
 
     for (auto &&sc_handler : this->side_channels)
     {
-        result[it].id = sc_handler->channel_id;
-        result[it].status = sc_handler->side_channel->connector_receive ? MPTransportChannelStatus::CONNECTED : MPTransportChannelStatus::CONNECTING;
-        result[it].name = "Snort connection to " + std::to_string(sc_handler->channel_id) + " instance";
-        it++;
+        unsigned short idx = sc_handler->channel_id > mp_current_process_id ? sc_handler->channel_id-2 : sc_handler->channel_id-1;
+        result[idx].id = sc_handler->channel_id;
+        result[idx].status = sc_handler->side_channel->connector_receive ? MPTransportChannelStatus::CONNECTED : MPTransportChannelStatus::CONNECTING;
+        result[idx].name = "Snort connection to instance " + std::to_string(sc_handler->channel_id);
+    }
+
+    for(uint16_t it = 0; it < this->config->max_processes-1; it++)
+    {
+        if(result[it].id != 0)
+            continue;
+        result[it].id = it + 2 > mp_current_process_id ? it + 2 : it + 1;
+        result[it].status = result[it].id > mp_current_process_id ? MPTransportChannelStatus::ACCEPTING : MPTransportChannelStatus::DISCONNECTED;
+        result[it].name = "Snort connection to instance " + std::to_string(result[it].id);
     }
 
     return result;
index 5d69dbe1077b8e5c6a5098dcd62dbd6bd157210a..199fbf9fee7d9e5551efbdf10c8af001f4320e13 100644 (file)
@@ -44,6 +44,31 @@ struct MPUnixTransportStats
         closed_connections(0),
         connection_retries(0)
     { }
+
+    MPUnixTransportStats(const MPUnixTransportStats& other) :
+        sent_events(other.sent_events),
+        sent_bytes(other.sent_bytes),
+        received_events(other.received_events),
+        received_bytes(other.received_bytes),
+        send_errors(other.send_errors),
+        successful_connections(other.successful_connections),
+        closed_connections(other.closed_connections),
+        connection_retries(other.connection_retries)
+    { }
+
+    MPUnixTransportStats& operator=(const MPUnixTransportStats& other)
+    {
+        sent_events = other.sent_events;
+        sent_bytes = other.sent_bytes;
+        received_events = other.received_events;
+        received_bytes = other.received_bytes;
+        send_errors = other.send_errors;
+        successful_connections = other.successful_connections;
+        closed_connections = other.closed_connections;
+        connection_retries = other.connection_retries;
+
+        return *this;
+    }
     
     PegCount sent_events;
     PegCount sent_bytes;
@@ -127,6 +152,9 @@ class MPUnixDomainTransport : public MPTransport
     { return config; }
 
 
+    void sum_stats();
+    void reset_stats();
+
     private:
 
     void init_side_channels();
@@ -141,7 +169,7 @@ class MPUnixDomainTransport : public MPTransport
     MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id);
     MPDeserializeFunc get_event_deserialization_function(unsigned pub_id, unsigned event_id);
 
-    uint32_t mp_current_process_id = 0;
+    uint16_t mp_current_process_id = 0;
 
     TransportReceiveEventHandler transport_receive_handler = nullptr;
     MPUnixDomainTransportConfig* config = nullptr;
@@ -157,6 +185,7 @@ class MPUnixDomainTransport : public MPTransport
     std::thread* consume_thread = nullptr;
     std::condition_variable consume_thread_cv;
 
+    MPUnixTransportStats dynamic_transport_stats;
     MPUnixTransportStats& transport_stats;
 };
 
index f01109f207b3b429b3193ef2ef9bee5cd15f66da..454830dae81d776ddbe98054ccbe7931b00befbd 100644 (file)
@@ -65,7 +65,7 @@ bool MPUnixDomainTransportModule::begin(const char *, int, SnortConfig *sc)
 {
     assert(sc);
     assert(!config);
-    config = new MPUnixDomainTransportConfig;
+    config = new MPUnixDomainTransportConfig;    
     config->max_processes = sc->max_procs;
     return true;
 }
@@ -117,9 +117,24 @@ const PegInfo *MPUnixDomainTransportModule::get_pegs() const
 
 PegCount *MPUnixDomainTransportModule::get_counts() const
 {
+    if (transport_handle)
+    {
+        transport_handle->sum_stats();
+    }
+    
     return (PegCount*)&unix_transport_stats;
 }
 
+void MPUnixDomainTransportModule::reset_stats()
+{
+    unix_transport_stats = MPUnixTransportStats();
+    if (transport_handle)
+    {
+        transport_handle->reset_stats();
+    }
+    Module::reset_stats();
+}
+
 static struct MPTransportApi mp_unixdomain_transport_api =
 {
     {
index fae5bbc17b2bb7b1d0e9c0b8c62856629f819faa..0c98ff7a34d5c54f4182aa3cd1c4248bc2a157d1 100644 (file)
@@ -44,12 +44,15 @@ class MPUnixDomainTransportModule : public Module
 
     const PegInfo* get_pegs() const override;
     PegCount* get_counts() const override;
+    bool global_stats() const override { return true; }
+    void reset_stats() override;
 
     Usage get_usage() const override
     { return GLOBAL; }
 
     MPUnixDomainTransportConfig* config;
     MPUnixTransportStats unix_transport_stats;
+    MPUnixDomainTransport* transport_handle = nullptr;
 };
 
 static Module* mod_ctor()
@@ -65,7 +68,8 @@ static void mod_dtor(Module* m)
 static MPTransport* mp_unixdomain_transport_ctor(Module* m)
 {
     auto unix_tr_mod = (MPUnixDomainTransportModule*)m;
-    return new MPUnixDomainTransport(unix_tr_mod->config, unix_tr_mod->unix_transport_stats);
+    unix_tr_mod->transport_handle = new MPUnixDomainTransport(unix_tr_mod->config, unix_tr_mod->unix_transport_stats);
+    return unix_tr_mod->transport_handle;
 }
 
 static void mp_unixdomain_transport_dtor(MPTransport* t)
index 9f26f4d35d714346361ddbc1ac8acf9a4654610b..d069929eb062b88dcfff514af0d8e8f444043cc5 100644 (file)
@@ -76,6 +76,14 @@ namespace snort
         size = 0;
         return nullptr;
     }
+    void MPUnixDomainTransport::reset_stats()
+    {
+        transport_stats = MPUnixTransportStats();
+    }
+    void MPUnixDomainTransport::sum_stats()
+    {
+        
+    }
 
     char* snort_strdup(const char*)
     {