]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4718: mp_data_bus: Adding stats and CLI commands to MPDataBus
authorOleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) <ostepano@cisco.com>
Fri, 2 May 2025 20:37:21 +0000 (20:37 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Fri, 2 May 2025 20:37:21 +0000 (20:37 +0000)
Merge in SNORT/snort3 from ~OSTEPANO/snort3:cli_stats_mp to master

Squashed commit of the following:

commit 8160a86149c4b0030e74b6a04a6919ce55bf3913
Author: Oleksandr Stepanov <ostepano@cisco.com>
Date:   Mon Apr 28 06:02:08 2025 -0400

    mp_data_bus: Adding peg stats and socket commands for MPDataBus

17 files changed:
src/connectors/unixdomain_connector/unixdomain_connector.cc
src/framework/mp_data_bus.cc
src/framework/mp_data_bus.h
src/framework/mp_transport.h
src/framework/test/mp_data_bus_test.cc
src/helpers/ring.h
src/main/modules.cc
src/main/snort.cc
src/main/snort_config.cc
src/managers/test/mp_transport_manager_test.cc
src/mp_transport/mp_unix_transport/mp_unix_transport.cc
src/mp_transport/mp_unix_transport/mp_unix_transport.h
src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc
src/mp_transport/mp_unix_transport/mp_unix_transport_module.h
src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc
src/mp_transport/mp_unix_transport/test/unix_transport_test.cc
src/side_channel/side_channel.cc

index 24a75779bf396280d14a72e20ef568da2f4c0475..461df406cdccc0d1ce884b573031dd47579fef0a 100644 (file)
@@ -85,7 +85,7 @@ static bool attempt_connection(int& sfd, const char* path) {
 // Function to handle connection retries
 static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
     if(update_handler)
-        update_handler(nullptr, (cfg.conn_retries > 0));
+        update_handler(nullptr, ( (cfg.conn_retries > 0) and (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) ));
     else
         ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr);
 
index 8e6477050e02a9f17d9fe4ad118e62113efc424b..8197afeb4c5306e368fe3f2249c5ab6e1d55aa44 100644 (file)
 #include "utils/stats.h"
 #include "main/snort_types.h"
 #include "log/messages.h"
+#include "log/log_stats.h"
 #include "helpers/ring.h"
 #include "managers/mp_transport_manager.h"
+#include "managers/module_manager.h"
+#include "main/snort.h"
+#include "framework/module.h"
 
 using namespace snort;
 
+void MPDataBusLog(const char* msg, ...);
+
 std::condition_variable MPDataBus::queue_cv;
 std::mutex MPDataBus::queue_mutex;
 uint32_t MPDataBus::mp_max_eventq_size = DEFAULT_MAX_EVENTQ_SIZE;
 std::string MPDataBus::transport = DEFAULT_TRANSPORT;
 bool MPDataBus::enable_debug = false;
 MPTransport* MPDataBus::transport_layer = nullptr;
+MPDataBusStats MPDataBus::mp_global_stats = {};
+#ifdef REG_TEST
+bool MPDataBus::hold_events = false;
+#endif
 
 static std::unordered_map<std::string, unsigned> mp_pub_ids;
+static std::mutex mp_stats_mutex;
+static uint mp_current_process_id = 0;
+
+void MPDataBusLog(const char* msg, ...)
+{
+    if (!MPDataBus::enable_debug)
+        return;
+
+    char buf[256];
+    va_list args;
+    va_start(args, msg);
+    vsnprintf(buf, sizeof(buf), msg, args);
+    va_end(args);
+
+    LogMessage("MPDataBusDbg ID=%d %s", mp_current_process_id, buf);
+}
 
 //--------------------------------------------------------------------------
 // public methods
@@ -60,7 +86,7 @@ MPDataBus::~MPDataBus()
 {
     stop_worker_thread();
 
-    for (auto& [key, sublist] : mp_pub_sub)
+    for (auto& [_, sublist] : mp_pub_sub)
     {
         for (auto* handler : sublist)
         {
@@ -87,6 +113,8 @@ unsigned MPDataBus::init(int max_procs)
         return 1;
     }
 
+    mp_current_process_id = Snort::get_process_id();
+
     transport_layer = MPTransportManager::get_transport(transport);
     if (transport_layer == nullptr)
     {
@@ -136,6 +164,18 @@ unsigned MPDataBus::get_id(const PubKey& key)
     return mp_pub_ids[key.publisher];
 }
 
+const char* MPDataBus::get_name_from_id(unsigned id)
+{
+    for (const auto& [name, pub_id] : mp_pub_ids)
+    {
+        if (pub_id == id)
+        {
+            return name.c_str();
+        }
+    }
+    return nullptr;
+}
+
 void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h)
 {
     if(! SnortConfig::get_conf()->mp_dbus)
@@ -145,7 +185,7 @@ void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h)
     }
 
     SnortConfig::get_conf()->mp_dbus->_subscribe(key, eid, h);
-    MP_DATABUS_LOG("MPDataBus: Subscribed to event ID %u\n", eid);
+    MPDataBusLog("Subscribed to event ID %u\n", eid);
 }
 
 bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr<DataEvent> e, Flow*)
@@ -166,12 +206,9 @@ bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr<DataEv
         return false;
     }
 
-    {
-        std::lock_guard<std::mutex> lock(queue_mutex);
-        queue_cv.notify_one();
-    }
+    queue_cv.notify_one();
 
-    MP_DATABUS_LOG("MPDataBus: Event published for publisher ID %u and event ID %u\n", pub_id, evt_id);
+    MPDataBusLog("Event published for publisher ID %u and event ID %u\n", pub_id, evt_id);
 
     return true;
 }
@@ -189,7 +226,7 @@ void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSer
     MPHelperFunctions helpers(mp_serializer_helper, mp_deserializer_helper);
     
     SnortConfig::get_conf()->mp_dbus->transport_layer->register_event_helpers(pub_id, evt_id, helpers);
-    MP_DATABUS_LOG("MPDataBus: Registered event helpers for event ID %u\n", evt_id);
+    MPDataBusLog("Registered event helpers for event ID %u\n", evt_id);
 }
 
 // API for receiving the DataEvent and Event type from transport layer
@@ -199,9 +236,18 @@ void MPDataBus::receive_message(const MPEventInfo& event_info)
     unsigned evt_id = event_info.type;
     unsigned pub_id = event_info.pub_id;
 
-    MP_DATABUS_LOG("MPDataBus: Received message for publisher ID %u and event ID %u\n", pub_id, evt_id);
+    MPDataBusLog("Received message for publisher ID %u and event ID %u\n", pub_id, evt_id);
+
+    auto pub_res = _publish(pub_id, evt_id, *e, nullptr);
 
-    _publish(pub_id, evt_id, *e, nullptr);
+    {
+        std::lock_guard<std::mutex> lock(mp_stats_mutex);
+        mp_pub_stats[pub_id].total_messages_received++;
+        if(pub_res)
+        {
+            mp_pub_stats[pub_id].total_messages_delivered++;
+        }
+    }
 }
 
 
@@ -210,25 +256,42 @@ void MPDataBus::receive_message(const MPEventInfo& event_info)
 //--------------------------------------------------------------------------
 void MPDataBus::process_event_queue()
 {
+#ifdef REG_TEST
+    if (hold_events)
+    {
+        return;
+    }
+#endif
     if (!mp_event_queue) {
         return;
     }
 
-    std::unique_lock<std::mutex> lock(queue_mutex);
+    std::unique_lock<std::mutex> u_lock(queue_mutex);
 
-    queue_cv.wait_for(lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() {
+    queue_cv.wait_for(u_lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() {
         return mp_event_queue != nullptr && !mp_event_queue->empty();
     });
 
-    lock.unlock();
-
     while (!mp_event_queue->empty()) {
         std::shared_ptr<MPEventInfo> event_info = mp_event_queue->get(nullptr);
         if (event_info) {
-            MP_DATABUS_LOG("MPDataBus: Processing event for publisher ID %u \n",
+            MPDataBusLog("Processing event for publisher ID %u \n",
                         event_info->pub_id);
 
-            transport_layer->send_to_transport(*event_info);
+            auto send_res = transport_layer->send_to_transport(*event_info);
+
+            {
+                std::lock_guard<std::mutex> lock(mp_stats_mutex);
+                mp_pub_stats[event_info->pub_id].total_messages_published++;
+                if (!send_res)
+                {
+                    mp_pub_stats[event_info->pub_id].total_messages_dropped++;
+                }
+                else
+                {
+                    mp_pub_stats[event_info->pub_id].total_messages_sent++;
+                }
+            }
         }
     }
 }
@@ -270,7 +333,164 @@ static bool compare(DataHandler* a, DataHandler* b)
     return false;
 }
 
-void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler* h)
+void snort::MPDataBus::set_debug_enabled(bool flag)
+{
+    enable_debug = flag;
+    if(transport_layer)
+    {
+        if(flag)
+        {
+            transport_layer->enable_logging();
+        }
+        else
+        {
+            transport_layer->disable_logging();
+        }
+    }
+}
+
+void MPDataBus::sum_stats()
+{
+    std::lock_guard<std::mutex> lock(mp_stats_mutex);
+
+    mp_global_stats.total_messages_sent = 0;
+    mp_global_stats.total_messages_received = 0;
+    mp_global_stats.total_messages_dropped = 0;
+    mp_global_stats.total_messages_published = 0;
+    mp_global_stats.total_messages_delivered = 0;
+
+    for(auto& [_, pub_stats] : mp_pub_stats)
+    {
+        mp_global_stats.total_messages_dropped += pub_stats.total_messages_dropped;
+        mp_global_stats.total_messages_published += pub_stats.total_messages_published;
+        mp_global_stats.total_messages_received += pub_stats.total_messages_received;
+        mp_global_stats.total_messages_sent += pub_stats.total_messages_sent;
+        mp_global_stats.total_messages_delivered += pub_stats.total_messages_delivered;
+    }
+}
+
+void MPDataBus::dump_stats(ControlConn *ctrlconn, const char *module_name)
+{
+    set_log_conn(ctrlconn);
+    if (module_name)
+    {
+        auto mod_id = mp_pub_ids.find(module_name);
+        if (mod_id == mp_pub_ids.end())
+        {
+            return;
+        }
+        std::lock_guard<std::mutex> lock(mp_stats_mutex);
+        auto mod_stats = mp_pub_stats[mod_id->second];
+
+        LogMessage("MPDataBus Stats for %s\n", module_name);
+        show_stats((PegCount*)&mod_stats, mp_databus_pegs, array_size(mp_databus_pegs)-1);
+    }
+    else
+    {
+        sum_stats();
+        
+        show_stats((PegCount*)&mp_global_stats, mp_databus_pegs, array_size(mp_databus_pegs)-1);
+
+        auto transport_module = ModuleManager::get_module(transport.c_str());
+        if(transport_module)
+        {
+            auto transport_pegs = transport_module->get_pegs();
+            if(transport_pegs)
+            {
+                uint size = 0;
+                while(transport_pegs[size].type != CountType::END)
+                {
+                    size++;
+                }
+                show_stats(transport_module->get_counts(), transport_pegs, size);
+            }
+        }
+    }
+    set_log_conn(nullptr);
+}
+
+void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name)
+{
+    int current_read_idx = 0;
+    uint ring_items = mp_event_queue->count();
+    if(ring_items == 0)
+    {
+        if (ctrlconn)
+        {
+            ctrlconn->respond("No events in the event queue\n");
+        }
+        else
+        {
+            LogMessage("No events in the event queue\n");
+        }
+        return;
+    }
+    auto event_queue_store = mp_event_queue->grab_store(current_read_idx);
+
+    if (current_read_idx == 0)
+    {
+        current_read_idx = mp_max_eventq_size - 1;
+    }
+    else
+    {
+        current_read_idx--;
+    }
+
+    for (uint i = current_read_idx; i <= ring_items; i++)
+    {
+        if(i >= mp_max_eventq_size)
+        {
+            i = 0;
+            ring_items -= mp_max_eventq_size;
+        }
+        auto event_info = event_queue_store[i];
+        if (event_info)
+        {
+            if (module_name)
+            {
+                if (event_info->pub_id != mp_pub_ids[module_name])
+                {
+                    continue;
+                }
+            }
+            if (ctrlconn)
+            {
+                ctrlconn->respond("Publisher module: %s, Event ID: %u\n", get_name_from_id(event_info->pub_id), event_info->type);
+            }
+            else
+            {
+                LogMessage("Publisher module: %s, Event ID: %u\n", get_name_from_id(event_info->pub_id), event_info->type);
+            }
+        }
+    }
+}
+
+void snort::MPDataBus::show_channel_status(ControlConn *ctrlconn)
+{
+    if(!transport_layer or !ctrlconn)
+    {
+        return;
+    }
+
+    uint size = 0;
+    auto transport_status = transport_layer->get_channel_status(size);
+    if (size == 0)
+    {
+        ctrlconn->respond("No active connections\n");
+        return;
+    }
+    std::string response;
+    for (uint i = 0; i < size; i++)
+    {
+        const auto& channel = transport_status[i];
+        response += "Channel ID: " + std::to_string(channel.id) + ", Name: " + channel.name + ", Status: " + channel.get_status_string() + "\n";
+    }
+
+    ctrlconn->respond("%s", response.c_str());
+    delete[] transport_status;
+}
+
+void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler *h)
 {
     std::pair<unsigned, unsigned> key = {pid, eid};
 
@@ -287,15 +507,15 @@ void MPDataBus::_subscribe(const PubKey& key, unsigned eid, DataHandler* h)
 }
 
 
-void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f)
+bool MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f)
 {
     std::pair<unsigned, unsigned> key = {pid, eid};
 
     auto it = mp_pub_sub.find(key);
     if (it == mp_pub_sub.end())
     {
-        MP_DATABUS_LOG("MPDataBus: No subscribers for publisher ID %u and event ID %u\n", pid, eid);
-        return;
+        MPDataBusLog("No subscribers for publisher ID %u and event ID %u\n", pid, eid);
+        return false;
     }
     const SubList& subs = it->second;
 
@@ -303,5 +523,7 @@ void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f)
     {
         handler->handle(e, f);
     }
+
+    return true;
 }
 
index ee1ce340e25005f0fa18ff96810a5c22db9e5506..06c9efe9b7af139c2ca7b8c2266b7c2745b539fb 100644 (file)
 #include <queue>
 #include <atomic>
 #include <thread>
+#include <bitset>
 
+#include "control/control.h"
+#include "framework/mp_transport.h"
+#include "framework/counts.h"
 #include "main/snort_types.h"
 #include "data_bus.h"
-#include "framework/mp_transport.h"
-#include <bitset>
-#include "framework/mp_transport.h"
 
 #define DEFAULT_TRANSPORT "unix_transport"
 #define DEFAULT_MAX_EVENTQ_SIZE 1000
 #define WORKER_THREAD_SLEEP 100
 
-#define MP_DATABUS_LOG(msg, ...) do { \
-        if (!MPDataBus::enable_debug) \
-            break; \
-        LogMessage(msg, __VA_ARGS__); \
-    } while (0)
-
-
 template <typename T>
 class Ring;
 
@@ -67,6 +61,33 @@ class Flow;
 struct Packet;
 struct SnortConfig;
 
+struct MPDataBusStats
+{
+    MPDataBusStats() :
+        total_messages_sent(0),
+        total_messages_received(0),
+        total_messages_dropped(0),
+        total_messages_published(0),
+        total_messages_delivered(0)
+    { }
+
+    PegCount total_messages_sent;
+    PegCount total_messages_received;
+    PegCount total_messages_dropped;
+    PegCount total_messages_published;
+    PegCount total_messages_delivered;
+};
+
+static const PegInfo mp_databus_pegs[] =
+{
+    { CountType::SUM, "total_messages_sent", "total messages sent" },
+    { CountType::SUM, "total_messages_received", "total messages received" },
+    { CountType::SUM, "total_messages_dropped", "total messages dropped" },
+    { CountType::SUM, "total_messages_published", "total messages published" },
+    { CountType::SUM, "total_messages_delivered", "total messages delivered" },
+    { CountType::END, nullptr, nullptr },
+};
+
 typedef bool (*MPSerializeFunc)(DataEvent* event, char*& buffer, uint16_t* length);
 typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent*& event);
 
@@ -115,12 +136,17 @@ public:
     static uint32_t mp_max_eventq_size;
     static std::string transport;
     static bool enable_debug;
+#ifdef REG_TEST
+    static bool hold_events;
+#endif
 
     static MPTransport * transport_layer;
+    static MPDataBusStats mp_global_stats;
     unsigned init(int);
     void clone(MPDataBus& from, const char* exclude_name = nullptr);
 
     static unsigned get_id(const PubKey& key);
+    static const char* get_name_from_id(unsigned id);
 
     static bool valid(unsigned pub_id)
     { return pub_id != 0; }
@@ -142,11 +168,19 @@ public:
     Ring<std::shared_ptr<MPEventInfo>>* get_event_queue()
     { return mp_event_queue; }
 
+    void set_debug_enabled(bool flag);
+
+    void sum_stats();
+
+    void dump_stats(ControlConn* ctrlconn, const char* module_name);
+    void dump_events(ControlConn* ctrlconn, const char* module_name);
+    void show_channel_status(ControlConn* ctrlconn);
+
 private: 
     void _subscribe(unsigned pid, unsigned eid, DataHandler* h);
     void _subscribe(const PubKey& key, unsigned eid, DataHandler* h);
 
-    void _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f);
+    bool _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f);
 
 private:
     typedef std::vector<DataHandler*> SubList;
@@ -161,6 +195,8 @@ private:
     static std::condition_variable queue_cv;
     static std::mutex queue_mutex;
 
+    std::unordered_map<unsigned, MPDataBusStats> mp_pub_stats;
+
     void start_worker_thread();
     void stop_worker_thread();
     void worker_thread_func();
index 4b2decb7406a41c72f19eb829f5c2b5866306e7c..882fb7fd5b942dd15103d8268ee56d8f26efbc96 100644 (file)
@@ -24,6 +24,7 @@
 #include "framework/base_api.h"
 
 #include <functional>
+#include <string>
 
 namespace snort
 {
@@ -36,6 +37,32 @@ struct MPHelperFunctions;
 
 typedef std::function<void (const MPEventInfo& event_info)> TransportReceiveEventHandler;
 
+enum MPTransportChannelStatus
+{
+    DISCONNECTED = 0,
+    CONNECTING,
+    CONNECTED,
+    MAX
+};
+
+struct MPTransportChannelStatusHandle
+{
+    int id = 0;
+    std::string name;
+    MPTransportChannelStatus status = DISCONNECTED;
+
+    const char* get_status_string() const
+    {
+        switch (status)
+        {
+            case DISCONNECTED: return "DISCONNECTED";
+            case CONNECTING: return "CONNECTING";
+            case CONNECTED: return "CONNECTED";
+            default: return "UNKNOWN";
+        }
+    }
+};
+
 class MPTransport
 {
     public:
@@ -54,6 +81,7 @@ class MPTransport
     virtual void enable_logging() = 0;
     virtual void disable_logging() = 0;
     virtual bool is_logging_enabled() = 0;
+    virtual MPTransportChannelStatusHandle* get_channel_status(uint& size) = 0;
 };
 
 
index 658d0f8949584bca2fb7142133b476efcea38871..0dad5aae1ef0d6ab54fedf5f113ca26ebc462181 100644 (file)
@@ -26,6 +26,8 @@
 #include "../main/snort_config.h"
 #include "utils/stats.h"
 #include "helpers/ring.h"
+#include "main/snort.h"
+#include "managers/module_manager.h"
 #include <condition_variable>
 
 #include <CppUTest/CommandLineTestRunner.h>
@@ -55,7 +57,30 @@ SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 
 SnortConfig::~SnortConfig()
 { }
+
+void set_log_conn(ControlConn*) { }
+
+unsigned Snort::get_process_id()
+{
+    return 0;
+}
+
+Module* ModuleManager::get_module(const char*)
+{
+    return nullptr;
 }
+}
+
+void show_stats(PegCount*, const PegInfo*, unsigned, const char*) 
+{
+    mock().actualCall("show_stats");
+}
+void show_stats(PegCount*, const PegInfo*, const std::vector<unsigned>&, const char*, FILE*) { }
+void show_stats(unsigned long*, PegInfo const*, char const*) {}
+
+bool ControlConn::respond(const char*, ...) { return true; }
+
+static bool test_transport_send_result = true;
 
 class MockMPTransport : public MPTransport
 {
@@ -68,6 +93,11 @@ public:
         return count;
     }
 
+    static void reset_count()
+    {
+        count = 0;
+    }
+
     static int get_test_register_helpers_calls()
     {
         return test_register_helpers_calls;
@@ -76,7 +106,7 @@ public:
     bool send_to_transport(MPEventInfo&) override
     {
         count++;
-        return true;
+        return test_transport_send_result;
     }
 
     void register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&) override
@@ -129,6 +159,13 @@ public:
     {
         return true;
     }
+
+    MPTransportChannelStatusHandle* get_channel_status(uint& size) override
+    {
+        size = 0;
+        return nullptr;
+    }
+
 private:
     inline static int count = 0;
     inline static int test_register_helpers_calls = 0;
@@ -224,6 +261,7 @@ TEST_GROUP(mp_data_bus_pub)
     MPDataBus* mp_dbus = nullptr;
     void setup() override
     {
+        MockMPTransport::reset_count();
         mp_dbus = new MPDataBus();
         mp_dbus->init(2);
         pub_id1 = MPDataBus::get_id(pub_key1);
@@ -247,9 +285,44 @@ TEST(mp_data_bus_pub, publish)
 
     std::this_thread::sleep_for(std::chrono::milliseconds(150));
 
+    mp_dbus->sum_stats();
+    CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_published);
+    CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_sent);
+    CHECK_EQUAL(0, MPDataBus::mp_global_stats.total_messages_dropped);
+
+    mock().expectNCalls(2, "show_stats");
+
+    mp_dbus->dump_stats(nullptr, nullptr);
+    mp_dbus->dump_stats(nullptr, "mp_ut1");
+
+    mock().checkExpectations();
+
+    delete mp_dbus;
+
+    CHECK_EQUAL(1, MockMPTransport::get_count());
+}
+
+TEST(mp_data_bus_pub, publish_fail_to_send)
+{
+    CHECK_TRUE(mp_dbus->get_event_queue()->empty());
+    CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0);
+
+    test_transport_send_result = false;
+
+    std::shared_ptr<UTestEvent> event = std::make_shared<UTestEvent>(100);
+
+    mp_dbus->publish(pub_id1, DbUtIds::EVENT1, event);
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(150));
+
+    mp_dbus->sum_stats();
+    CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_published);
+    CHECK_EQUAL(0, MPDataBus::mp_global_stats.total_messages_sent);
+    CHECK_EQUAL(1, MPDataBus::mp_global_stats.total_messages_dropped);
+
     delete mp_dbus;
 
-    CHECK(1 == MockMPTransport::get_count());
+    test_transport_send_result = true;
 }
 
 TEST_GROUP(mp_data_bus)
@@ -331,6 +404,9 @@ TEST(mp_data_bus, subscribe_and_receive)
     MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
     SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
     
+    SnortConfig::get_conf()->mp_dbus->sum_stats();
+    CHECK_EQUAL(2, MPDataBus::mp_global_stats.total_messages_received);
+
     CHECK_EQUAL(200, h1->evt_msg);
 }
 
index 914b786898ec341c0f10ba33a76bc67df5e94ebb..57f18d6b67ba387a4a5a3eaa3a891a34016e4ef2 100644 (file)
@@ -43,6 +43,8 @@ public:
     T get(T);
     bool put(T);
 
+    T* grab_store(int& ix);
+
     int count();
     bool full();
     bool empty();
@@ -112,6 +114,16 @@ bool Ring<T>::put(T v)
     return true;
 }
 
+template <typename T>
+T* Ring<T>::grab_store(int& ix)
+{
+    int i = logic.read();
+    if ( i < 0 )
+        return nullptr;
+    ix = i;
+    return store;
+}
+
 template <typename T>
 int Ring<T>::count()
 {
index dbaa053fdd09bceb5ef25f5b8e7d1bd75d93eb16..767f6e21b72df970a6de974d0e1477de6444c137 100644 (file)
@@ -47,6 +47,7 @@
 #include "js_norm/js_norm_module.h"
 #include "latency/latency_module.h"
 #include "log/messages.h"
+#include "lua/lua.h"
 #include "managers/module_manager.h"
 #include "managers/plugin_manager.h"
 #include "memory/memory_module.h"
@@ -406,14 +407,64 @@ static const Parameter mp_data_bus_params[] =
 
     { "debug", Parameter::PT_BOOL, nullptr, "false",
       "enable debugging" },
+#ifdef REG_TEST
+    { "hold_events", Parameter::PT_BOOL, nullptr, "false",
+      "hold events from publishing" },
+#endif
 
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
+static int dump_mp_stats(lua_State* L)
+{
+    if (!L)
+        return 0;
+
+    ControlConn* ctrlconn = ControlConn::query_from_lua(L);
+    
+    auto mod_name = luaL_optstring(L, 1, nullptr);
+
+    if (mod_name and strlen(mod_name) == 0)
+        mod_name = nullptr;
+
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->dump_stats(ctrlconn, mod_name);
+
+    return 0;
+}
+
+static int dump_mp_events(lua_State* L)
+{
+    if (!L)
+        return 0;
+
+    ControlConn* ctrlconn = ControlConn::query_from_lua(L);
+    
+    auto mod_name = luaL_optstring(L, 1, nullptr);
+
+    if (mod_name and strlen(mod_name) == 0)
+        mod_name = nullptr;
+
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->dump_events(ctrlconn, mod_name);
+
+    return 0;
+}
+
+static int show_mp_channel_status(lua_State* L)
+{
+    ControlConn* ctrlconn = ControlConn::query_from_lua(L);
+
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->show_channel_status(ctrlconn);
+
+    return 0;
+}
+
 static int enable_debug(lua_State*)
 {
     if(SnortConfig::get_conf()->mp_dbus)
-        SnortConfig::get_conf()->mp_dbus->enable_debug = true;
+        SnortConfig::get_conf()->mp_dbus->set_debug_enabled(true);
 
     return 0;
 }
@@ -421,7 +472,7 @@ static int enable_debug(lua_State*)
 static int disable_debug(lua_State*)
 {
     if(SnortConfig::get_conf()->mp_dbus)
-        SnortConfig::get_conf()->mp_dbus->enable_debug = false;
+        SnortConfig::get_conf()->mp_dbus->set_debug_enabled(false);
 
     return 0;
 }
@@ -430,6 +481,9 @@ static const Command mp_dbus_cmds[] =
 {
     {"enable", enable_debug, nullptr, "enable multiprocess data bus debugging"},
     {"disable", disable_debug, nullptr, "disable multiprocess data bus debugging"},
+    { "dump_stats", dump_mp_stats, nullptr, "dump multiprocess data bus statistics" },
+    { "dump_events", dump_mp_events, nullptr, "dump multiprocess data bus events" },
+    { "show_channel_status", show_mp_channel_status, nullptr, "show multiprocess data bus channel status" },
     {nullptr, nullptr, nullptr, nullptr}
 };
 
@@ -446,6 +500,8 @@ public:
     bool begin(const char*, int, SnortConfig*) override;
     bool end(const char*, int, SnortConfig*) override;
     const Command* get_commands() const override;
+    const PegInfo* get_pegs() const override;
+    PegCount* get_counts() const override;
 
     Usage get_usage() const override
     { return GLOBAL; }
@@ -475,6 +531,12 @@ bool MPDataBusModule::set(const char*, Value& v, SnortConfig*)
     {
         MPDataBus::enable_debug = v.get_bool();
     }
+#ifdef REG_TEST
+    else if ( v.is("hold_events") )
+    {
+        MPDataBus::hold_events = v.get_bool();
+    }
+#endif
     else 
     {
         WarningMessage("MPDataBus: Unknown parameter '%s' in mp_data_bus module\n", v.get_name());
@@ -488,6 +550,17 @@ const Command* MPDataBusModule::get_commands() const
     return mp_dbus_cmds;
 }
 
+const PegInfo* MPDataBusModule::get_pegs() const
+{
+    return mp_databus_pegs;
+}
+
+PegCount* MPDataBusModule::get_counts() const
+{
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->sum_stats();
+    return (PegCount*)&MPDataBus::mp_global_stats;
+}
 //-------------------------------------------------------------------------
 // reference module
 //-------------------------------------------------------------------------
index b735a25b607e9bf15782646fe13cee4d967559a5..8d18aa83ddd65408101d5b2f1d129f111a48c9d4 100644 (file)
@@ -174,7 +174,6 @@ void Snort::init(int argc, char** argv)
     if (sc->max_procs > 1)
     {
         sc->mp_dbus = new MPDataBus();
-        sc->mp_dbus->init(sc->max_procs);
     }
 
     PluginManager::load_so_plugins(sc);
@@ -195,6 +194,11 @@ void Snort::init(int argc, char** argv)
     ModuleManager::init_stats();
     ModuleManager::reset_stats(sc);
 
+    if (sc->mp_dbus)
+    {
+        sc->mp_dbus->init(sc->max_procs);
+    }
+
     if (sc->alert_before_pass())
         sc->rule_order = IpsAction::get_default_priorities(true);
 
index 5db70d17bbef0fa4111dcfa301ec192b9591145d..1a863cefd26e75b636550ddb5667d8295a34a241 100644 (file)
@@ -63,6 +63,7 @@
 #include "managers/mpse_manager.h"
 #include "managers/plugin_manager.h"
 #include "managers/so_manager.h"
+#include "managers/mp_transport_manager.h"
 #include "memory/memory_config.h"
 #include "packet_io/sfdaq.h"
 #include "packet_io/sfdaq_config.h"
@@ -1133,6 +1134,7 @@ void SnortConfig::cleanup_fatal_error()
     const SnortConfig* sc = SnortConfig::get_conf();
     if ( sc && !sc->dirty_pig )
     {
+        MPTransportManager::term();
         ModuleManager::term();
         EventManager::release_plugins();
         IpsManager::release_plugins();
index 5438ac5294fbc1f96798d48eaba1b2b414f028b0..d4900d3ca3c2479b525a691c5726aaa4d5f0e5f3 100644 (file)
@@ -56,6 +56,11 @@ class MockTransport : public MPTransport
     { }
     virtual bool is_logging_enabled() override
     { return false; }
+    MPTransportChannelStatusHandle* get_channel_status(uint& size) override
+    {
+        size = 0;
+        return nullptr;
+    }
 };
 
 unsigned get_instance_id() { return 0; }
@@ -65,10 +70,8 @@ unsigned ThreadConfig::get_instance_max() { return 1; }
 
 using namespace snort;
 
-void show_stats(unsigned long*, PegInfo const*, std::vector<unsigned int, std::allocator<unsigned int> > const&, char const*, _IO_FILE*)
-{}
-void show_stats(unsigned long*, PegInfo const*, unsigned int, char const*)
-{}
+void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
+void show_stats(PegCount*, const PegInfo*, const std::vector<unsigned>&, const char*, FILE*) { }
 
 static void mock_transport_tinit(MPTransport* t)
 {
index 8e40562f71b8a982faafd2ff444af524f515720b..783c69aa47a3ec1b073d249249efc3d716260716 100644 (file)
@@ -25,6 +25,7 @@
 
 #include <cstring>
 #include <fcntl.h>
+#include <filesystem>
 #include <iostream>
 #include <poll.h>
 #include <sys/socket.h>
 #include "main/snort_config.h"
 
 static std::mutex _receive_mutex;
-static std::mutex _update_connectors_mutex;
+static std::mutex _send_mutex;
+static std::mutex _read_mutex;
 
 #define UNIX_SOCKET_NAME_PREFIX "/snort_unix_connector_"
-
-#define MP_TRANSPORT_LOG_LABEL "MPUnixTransport"
-
-#define MP_TRANSPORT_LOG(msg, ...) do { \
-        if (!this->is_logging_enabled_flag) \
-            break; \
-        LogMessage(msg, __VA_ARGS__); \
-    } while (0)
+#define MP_TRANSPORT_LOG_LABEL "MPUnixTransportDbg"
 
 namespace snort
 {
@@ -80,7 +75,7 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
     {
         if (msg->content_length < sizeof(MPTransportMessage))
         {
-            MP_TRANSPORT_LOG("%s: Incomplete message received\n", MP_TRANSPORT_LOG_LABEL);
+            MPTransportLog("Incomplete message received\n");
             return;
         }
 
@@ -88,14 +83,14 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
         
         if (transport_message_header->type >= MAX_TYPE)
         {
-            MP_TRANSPORT_LOG("%s: Invalid message type received\n", MP_TRANSPORT_LOG_LABEL);
+            MPTransportLog("Invalid message type received\n");
             return;
         }
 
         auto deserialize_func = get_event_deserialization_function(transport_message_header->pub_id, transport_message_header->event_id);
         if (!deserialize_func)
         {
-            MP_TRANSPORT_LOG("%s: No deserialization function found for event: type %d, id %d\n", MP_TRANSPORT_LOG_LABEL, transport_message_header->type, transport_message_header->event_id);
+            MPTransportLog("No deserialization function found for event: type %d, id %d\n", transport_message_header->type, transport_message_header->event_id);
             return;
         }
 
@@ -104,29 +99,33 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
         MPEventInfo event(std::shared_ptr<DataEvent> (internal_event), transport_message_header->event_id, transport_message_header->pub_id);
 
         (transport_receive_handler)(event);
-
+        transport_stats.received_events++;
+        transport_stats.received_bytes += sizeof(MPTransportMessageHeader) + transport_message_header->data_length;
     }
     delete msg;
 }
 
-void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg)
+void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg, const ushort& channel_id)
 {
     assert(connector);
     assert(cfg);
 
-    std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+    std::lock_guard<std::mutex> guard_send(_send_mutex);
+    std::lock_guard<std::mutex> guard_read(_read_mutex);
+
+    transport_stats.successful_connections++;
 
     auto side_channel = new SideChannel(ScMsgFormat::BINARY);
     side_channel->connector_receive = connector;
     side_channel->connector_transmit = side_channel->connector_receive;
     side_channel->register_receive_handler(std::bind(&MPUnixDomainTransport::side_channel_receive_handler, this, std::placeholders::_1));
     connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
-    this->side_channels.push_back(new SideChannelHandle(side_channel, cfg));
+    this->side_channels.push_back(new SideChannelHandle(side_channel, cfg, channel_id));
     connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
 }
 
-MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c) : MPTransport(), 
-    config(c)
+MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c, MPUnixTransportStats& stats) : MPTransport(), 
+    config(c), transport_stats(stats)
 {
     this->is_logging_enabled_flag = c->enable_logging;
 }
@@ -142,7 +141,8 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
 
     if (!serialize_func)
     {
-        MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event.type);
+        transport_stats.send_errors++;
+        MPTransportLog("No serialize function found for event %d\n", event.type);
         return false;
     }
 
@@ -153,15 +153,25 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
     
 
     (serialize_func)(event.event.get(), transport_message.data, &transport_message.header.data_length);
-    for (auto &&sc_handler : this->side_channels)
     {
-        auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length);
-        memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader));
-        memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length);
-        auto send_result = sc_handler->side_channel->transmit_message(msg);
-        if (!send_result)
+        std::lock_guard<std::mutex> guard(_send_mutex);
+
+        for (auto &&sc_handler : this->side_channels)
         {
-            MP_TRANSPORT_LOG("%s: Failed to send message to side channel\n", MP_TRANSPORT_LOG_LABEL);
+            auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length);
+            memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader));
+            memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length);
+            auto send_result = sc_handler->side_channel->transmit_message(msg);
+            if (!send_result)
+            {
+                MPTransportLog("Failed to send message to side channel\n");
+                transport_stats.send_errors++;
+            }
+            else
+            {
+                transport_stats.sent_events++;
+                transport_stats.sent_bytes += sizeof(MPTransportMessageHeader) + transport_message.header.data_length;
+            }
         }
     }
 
@@ -176,7 +186,7 @@ void MPUnixDomainTransport::register_event_helpers(const unsigned& pub_id, const
     assert(helper.serializer);
     
     this->event_helpers[pub_id] = SerializeFunctionHandle();
-    this->event_helpers[pub_id].serialize_functions.insert({event_id, helper});
+    this->event_helpers[pub_id].serialize_functions.insert({event_id, std::move(helper)});
 }
 
 void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler& handler)
@@ -201,7 +211,7 @@ void MPUnixDomainTransport::process_messages_from_side_channels()
         }
 
         {
-            std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+            std::lock_guard<std::mutex> guard(_read_mutex);
             bool messages_left;
 
             do
@@ -227,7 +237,8 @@ void MPUnixDomainTransport::notify_process_thread()
 
 void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel)
 {
-    std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+    std::lock_guard<std::mutex> guard_send(_send_mutex);
+    std::lock_guard<std::mutex> guard_read(_read_mutex);
     if (side_channel->connector_receive)
     {
         delete side_channel->connector_receive;
@@ -236,13 +247,15 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec
 
     if (connector)
     {
+        connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
         side_channel->connector_receive = side_channel->connector_transmit = connector;
+        this->transport_stats.successful_connections++;
     }
     else
     {
         if (is_recconecting == false)
         {
-            MP_TRANSPORT_LOG("%s: Accepted connection interrupted, removing handle\n", MP_TRANSPORT_LOG_LABEL);
+            MPTransportLog("Accepted connection interrupted, removing handle\n");
             for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it)
             {
                 if ((*it)->side_channel == side_channel)
@@ -252,22 +265,41 @@ void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connec
                     break;
                 }
             }
+            this->transport_stats.closed_connections++;
+        }
+        else
+        {
+            this->transport_stats.connection_retries++;
         }
     }
 }
 
+void MPUnixDomainTransport::MPTransportLog(const char *msg, ...)
+{
+    if (!is_logging_enabled_flag)
+        return;
+
+    char buf[256];
+    va_list args;
+    va_start(args, msg);
+    vsnprintf(buf, sizeof(buf), msg, args);
+    va_end(args);
+
+    LogMessage("%s ID=%d %s", MP_TRANSPORT_LOG_LABEL, mp_current_process_id, buf);
+}
+
 MPSerializeFunc MPUnixDomainTransport::get_event_serialization_function(unsigned pub_id, unsigned event_id)
 {
     auto helper_it = this->event_helpers.find(pub_id);
     if (helper_it == this->event_helpers.end())
     {
-        MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id);
+        MPTransportLog("%s: No available helper functions is registered for %d\n", pub_id);
         return nullptr;
     }
     auto helper_functions = helper_it->second.get_function_set(event_id);
     if (!helper_functions)
     {
-        MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id);
+        MPTransportLog("%s: No serialize function found for event %d\n", event_id);
         return nullptr;
     }
     return helper_functions->serializer;
@@ -278,13 +310,13 @@ MPDeserializeFunc MPUnixDomainTransport::get_event_deserialization_function(unsi
     auto helper_it = this->event_helpers.find(pub_id);
     if (helper_it == this->event_helpers.end())
     {
-        MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id);
+        MPTransportLog("No available helper functions is registered for %d\n", pub_id);
         return nullptr;
     }
     auto helper_functions = helper_it->second.get_function_set(event_id);
     if (!helper_functions)
     {
-        MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id);
+        MPTransportLog("No serialize function found for event %d\n", event_id);
         return nullptr;
     }
     return helper_functions->deserializer;
@@ -337,14 +369,26 @@ void MPUnixDomainTransport::init_side_channels()
     if (config->max_processes < 2)
         return;
 
-    auto instance_id = Snort::get_process_id();//Snort instance id
+    auto instance_id = mp_current_process_id = Snort::get_process_id();//Snort instance id
     auto max_processes = config->max_processes;
 
     this->is_running = true;
 
+    if ( std::filesystem::is_directory(config->unix_domain_socket_path) == false )
+    {
+        std::error_code ec;
+        std::filesystem::create_directories(config->unix_domain_socket_path, ec);
+        if (ec)
+        {
+            MPTransportLog("Failed to create directory %s\n", config->unix_domain_socket_path.c_str());
+            return;
+        }
+    }
+
     for (ushort i = instance_id; i < max_processes; i++)
     {
         auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i);
+
         auto unix_listener = new UnixDomainConnectorListener(listen_path.c_str());
         
         UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig();
@@ -366,7 +410,7 @@ void MPUnixDomainTransport::init_side_channels()
         }
         unix_config->paths.push_back(listen_path);
 
-        unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2), unix_config);
+        unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2, instance_id + i), unix_config);
         
         auto unix_listener_handle = new UnixAcceptorHandle();
         unix_listener_handle->connector_config = unix_config;
@@ -379,7 +423,7 @@ void MPUnixDomainTransport::init_side_channels()
         auto side_channel = new SideChannel(ScMsgFormat::BINARY);
         side_channel->register_receive_handler([this](SCMessage* msg) { this->side_channel_receive_handler(msg); });
 
-        auto send_path = config->unix_domain_socket_path + "/" + "snort_unix_connector_" + std::to_string(i);
+        auto send_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i);
 
         UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig();
         connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL;
@@ -390,26 +434,22 @@ void MPUnixDomainTransport::init_side_channels()
         connector_conf->connect_timeout_seconds = config->connect_timeout_seconds;
         connector_conf->paths.push_back(send_path);
 
-        auto connector = unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
-
-        if (connector)
-            connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
-
-        side_channel->connector_receive = connector;
-        side_channel->connector_transmit = side_channel->connector_receive;
-        this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf));
+        unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+        
+        this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf, i));
     }
 
     this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this);
 }
+
 void MPUnixDomainTransport::cleanup_side_channels()
 {
-    std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+    std::lock_guard<std::mutex> guard_send(_send_mutex);
+    std::lock_guard<std::mutex> guard_read(_read_mutex);
 
     for (uint i = 0; i < this->side_channels.size(); i++)
     {
-        auto side_channel = this->side_channels[i];
-        delete side_channel;
+        delete this->side_channels[i];
     }
 
     this->side_channels.clear();
@@ -428,6 +468,7 @@ SideChannelHandle::~SideChannelHandle()
     if (connector_config)
         delete connector_config;
 }
+
 void MPUnixDomainTransport::enable_logging()
 {
     this->is_logging_enabled_flag = true;
@@ -443,4 +484,29 @@ bool MPUnixDomainTransport::is_logging_enabled()
     return this->is_logging_enabled_flag;
 }
 
-};
+MPTransportChannelStatusHandle *MPUnixDomainTransport::get_channel_status(uint &size)
+{
+    std::lock_guard<std::mutex> guard_send(_send_mutex);
+    std::lock_guard<std::mutex> guard_read(_read_mutex);
+    if (this->side_channels.size() == 0)
+    {
+        size = 0;
+        return nullptr;
+    }
+    MPTransportChannelStatusHandle* result = new MPTransportChannelStatusHandle[this->side_channels.size()];
+
+    size = this->side_channels.size();
+    uint it = 0;
+
+    for (auto &&sc_handler : this->side_channels)
+    {
+        result[it].id = sc_handler->channel_id;
+        result[it].status = sc_handler->side_channel->connector_receive ? MPTransportChannelStatus::CONNECTED : MPTransportChannelStatus::CONNECTING;
+        result[it].name = "Snort connection to " + std::to_string(sc_handler->channel_id) + " instance";
+        it++;
+    }
+
+    return result;
+}
+
+}
index 8208e030e2fcc12a1296e388877070ae1bcf4a7b..0c23eb087314d0b993b1876a2b7d45ddbcea481a 100644 (file)
 namespace snort
 {
 
+struct MPUnixTransportStats
+{
+    MPUnixTransportStats() :
+        sent_events(0),
+        sent_bytes(0),
+        received_events(0),
+        received_bytes(0),
+        send_errors(0),
+        successful_connections(0),
+        closed_connections(0),
+        connection_retries(0)
+    { }
+    
+    PegCount sent_events;
+    PegCount sent_bytes;
+    PegCount received_events;
+    PegCount received_bytes;
+    PegCount send_errors;
+    PegCount successful_connections;
+    PegCount closed_connections;
+    PegCount connection_retries;
+};
+
 struct MPUnixDomainTransportConfig
 {
     std::string unix_domain_socket_path;
@@ -60,14 +83,15 @@ struct SerializeFunctionHandle
 
 struct SideChannelHandle
 {
-    SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc) :
-        side_channel(sc), connector_config(cc)
+    SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc, const ushort& channel_id) :
+        side_channel(sc), connector_config(cc), channel_id(channel_id)
     { }
 
     ~SideChannelHandle();
 
     SideChannel* side_channel;
     UnixDomainConnectorConfig* connector_config;
+    ushort channel_id;
 };
 
 struct UnixAcceptorHandle
@@ -80,7 +104,7 @@ class MPUnixDomainTransport : public MPTransport
 {
     public:
 
-    MPUnixDomainTransport(MPUnixDomainTransportConfig* c);
+    MPUnixDomainTransport(MPUnixDomainTransportConfig* c, MPUnixTransportStats& stats);
     ~MPUnixDomainTransport() override;
 
     bool configure(const SnortConfig*) override;
@@ -95,6 +119,7 @@ class MPUnixDomainTransport : public MPTransport
     void disable_logging() override;
     bool is_logging_enabled() override;
     void cleanup();
+    MPTransportChannelStatusHandle* get_channel_status(uint& size) override;
 
     MPUnixDomainTransportConfig* get_config()
     { return config; }
@@ -105,14 +130,17 @@ class MPUnixDomainTransport : public MPTransport
     void init_side_channels();
     void cleanup_side_channels();
     void side_channel_receive_handler(SCMessage* msg);
-    void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg);
+    void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg, const ushort& channel_id);
     void process_messages_from_side_channels();
     void notify_process_thread();
     void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel);
+    void MPTransportLog(const char* msg, ...);
 
     MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id);
     MPDeserializeFunc get_event_deserialization_function(unsigned pub_id, unsigned event_id);
 
+    uint mp_current_process_id = 0;
+
     TransportReceiveEventHandler transport_receive_handler = nullptr;
     MPUnixDomainTransportConfig* config = nullptr;
 
@@ -126,6 +154,8 @@ class MPUnixDomainTransport : public MPTransport
 
     std::thread* consume_thread = nullptr;
     std::condition_variable consume_thread_cv;
+
+    MPUnixTransportStats& transport_stats;
 };
 
 }
index b0c8c5f056c30daf4ff024402bdefcb3cb9c4dd3..f01109f207b3b429b3193ef2ef9bee5cd15f66da 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "main/snort_config.h"
 #include "log/messages.h"
+#include "utils/stats.h"
 
 #define DEFAULT_UNIX_DOMAIN_SOCKET_PATH "/tmp/snort_unix_connectors"
 
@@ -42,6 +43,19 @@ static const Parameter unix_transport_params[] =
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
+static const PegInfo mp_unix_transport_pegs[] =
+{
+    { CountType::SUM, "sent_events", "mp_transport events sent count" },
+    { CountType::SUM, "sent_bytes", "mp_transport events bytes sent count" },
+    { CountType::SUM, "receive_events", "mp_transport events received count" },
+    { CountType::SUM, "receive_bytes", "mp_transport events bytes received count" },
+    { CountType::SUM, "sent_errors", "mp_transport events errors count" },
+    { CountType::SUM, "successful_connections", "successful mp_transport connections count" },
+    { CountType::SUM, "closed_connections", "closed mp_transport connections count" },
+    { CountType::SUM, "connection_retries", "mp_transport connection retries count" },
+    { CountType::END, nullptr, nullptr },
+};
+
 MPUnixDomainTransportModule::MPUnixDomainTransportModule(): Module(MODULE_NAME, MODULE_HELP, unix_transport_params)
 { 
     config = nullptr;
@@ -96,6 +110,16 @@ bool MPUnixDomainTransportModule::set(const char *, Value & v, SnortConfig *)
     return true;
 }
 
+const PegInfo *MPUnixDomainTransportModule::get_pegs() const
+{
+    return mp_unix_transport_pegs;
+}
+
+PegCount *MPUnixDomainTransportModule::get_counts() const
+{
+    return (PegCount*)&unix_transport_stats;
+}
+
 static struct MPTransportApi mp_unixdomain_transport_api =
 {
     {
index 464b4f3e852ed77d4684a8c81fd838885c4e042e..fae5bbc17b2bb7b1d0e9c0b8c62856629f819faa 100644 (file)
@@ -42,10 +42,14 @@ class MPUnixDomainTransportModule : public Module
     bool begin(const char*, int, SnortConfig*) override;
     bool set(const char*, Value&, SnortConfig*) override;
 
+    const PegInfo* get_pegs() const override;
+    PegCount* get_counts() const override;
+
     Usage get_usage() const override
     { return GLOBAL; }
 
     MPUnixDomainTransportConfig* config;
+    MPUnixTransportStats unix_transport_stats;
 };
 
 static Module* mod_ctor()
@@ -61,7 +65,7 @@ static void mod_dtor(Module* m)
 static MPTransport* mp_unixdomain_transport_ctor(Module* m)
 {
     auto unix_tr_mod = (MPUnixDomainTransportModule*)m;
-    return new MPUnixDomainTransport(unix_tr_mod->config);
+    return new MPUnixDomainTransport(unix_tr_mod->config, unix_tr_mod->unix_transport_stats);
 }
 
 static void mp_unixdomain_transport_dtor(MPTransport* t)
index c00ff41a61463a3d048ec7176371c65915037061..9f26f4d35d714346361ddbc1ac8acf9a4654610b 100644 (file)
@@ -42,11 +42,9 @@ namespace snort
         return 1;
     }
 
-    MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config) :
-        MPTransport()
-    {
-        this->config = config;
-    }
+    MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config, MPUnixTransportStats& stats) :
+        MPTransport(), config(config), transport_stats(stats)
+    { }
     MPUnixDomainTransport::~MPUnixDomainTransport()
     { destroy_cnt++; }
     void MPUnixDomainTransport::thread_init()
@@ -73,6 +71,11 @@ namespace snort
     {}
     void MPUnixDomainTransport::disable_logging()
     {}
+    MPTransportChannelStatusHandle* MPUnixDomainTransport::get_channel_status(unsigned int& size)
+    {
+        size = 0;
+        return nullptr;
+    }
 
     char* snort_strdup(const char*)
     {
index 84cf8497999dd4a73f2d43feff2af44280c58c82..bad3074ac125d6b87379c1e789c4d7549a8dead2 100644 (file)
@@ -204,6 +204,7 @@ UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorCo
         test_call_sock_created++;
         auto new_conn = new UnixDomainConnector(cfg, 0, idx);
         call_connector = new_conn;
+        update_handler(new_conn, false);
         return new_conn;
     }
     assert(false);
@@ -269,6 +270,7 @@ bool deserialize_mock(const char* buffer, uint16_t length, DataEvent*& event)
 MPHelperFunctions mp_helper_functions_mock(serialize_mock, deserialize_mock);
 
 static MPUnixDomainTransportConfig test_config;
+static MPUnixTransportStats test_stats;
 static MPUnixDomainTransport* test_transport = nullptr;
 
 static SnortConfig test_snort_config(nullptr, nullptr);
@@ -278,7 +280,7 @@ TEST_GROUP(unix_transport_test_connectivity_group)
     void setup() override
     {
         test_snort_config.max_procs = 2;
-        test_transport = new MPUnixDomainTransport(&test_config);
+        test_transport = new MPUnixDomainTransport(&test_config, test_stats);
         test_transport->configure(&test_snort_config);
     }
 
@@ -294,8 +296,8 @@ static MPUnixDomainTransportConfig test_config_message;
 static MPTransport* test_transport_message_1 = nullptr;
 static MPTransport* test_transport_message_2 = nullptr;
 
-static int reciveved_1_msg_cnt = 0;
-static int reciveved_2_msg_cnt = 0;
+static int received_1_msg_cnt = 0;
+static int received_2_msg_cnt = 0;
 
 TEST_GROUP(unix_transport_test_messaging)
 {
@@ -305,31 +307,31 @@ TEST_GROUP(unix_transport_test_messaging)
 
         accept_cnt = 1;
 
-        test_config_message.unix_domain_socket_path = "/tmp";
+        test_config_message.unix_domain_socket_path = ".";
         test_config_message.max_processes = 2;
         test_config_message.conn_retries = false;
         test_config_message.retry_interval_seconds = 0;
         test_config_message.max_retries = 0;
         test_config_message.connect_timeout_seconds = 30;
 
-        test_transport_message_1 = new MPUnixDomainTransport(&test_config_message);
+        test_transport_message_1 = new MPUnixDomainTransport(&test_config_message, test_stats);
         snort_instance_id = 1;
         test_transport_message_1->configure(&test_snort_config);
         test_transport_message_1->init_connection();
         test_transport_message_1->register_receive_handler([](const snort::MPEventInfo& e)
         {
-            reciveved_1_msg_cnt++;
+            received_1_msg_cnt++;
         });
         
         std::this_thread::sleep_for(std::chrono::milliseconds(100));
 
-        test_transport_message_2 = new MPUnixDomainTransport(&test_config_message);
+        test_transport_message_2 = new MPUnixDomainTransport(&test_config_message, test_stats);
         snort_instance_id = 2;
         test_transport_message_2->configure(&test_snort_config);
         test_transport_message_2->init_connection();
         test_transport_message_2->register_receive_handler([](const snort::MPEventInfo& e)
         {
-            reciveved_2_msg_cnt++;
+            received_2_msg_cnt++;
         });
     }
 
@@ -369,7 +371,7 @@ TEST(unix_transport_test_connectivity_group, set_logging_enabled_disabled)
 TEST(unix_transport_test_connectivity_group, init_connection_single_snort_instance)
 {
     clear_test_calls();
-    test_config.unix_domain_socket_path = "/tmp";
+    test_config.unix_domain_socket_path = ".";
     test_config.max_processes = 1;
 
     test_transport->init_connection();
@@ -389,7 +391,7 @@ TEST(unix_transport_test_connectivity_group, init_connection_first_snort_instanc
     clear_test_calls();
     snort_instance_id = 1;
 
-    test_config.unix_domain_socket_path = "/tmp";
+    test_config.unix_domain_socket_path = ".";
     test_config.max_processes = 2;
 
     accept_cnt = 1;
@@ -406,7 +408,7 @@ TEST(unix_transport_test_connectivity_group, init_connection_second_snort_instan
 {
     clear_test_calls();
     snort_instance_id = 2;
-    test_config.unix_domain_socket_path = "/tmp";
+    test_config.unix_domain_socket_path = ".";
     test_config.max_processes = 2;
 
     test_transport->init_connection();
@@ -425,7 +427,7 @@ TEST(unix_transport_test_connectivity_group, connector_update_handler_call)
 {
     clear_test_calls();
     
-    test_config.unix_domain_socket_path = "/tmp";
+    test_config.unix_domain_socket_path = ".";
     test_config.max_processes = 2;
 
     accept_cnt = 1;
@@ -470,10 +472,10 @@ TEST(unix_transport_test_messaging, send_to_transport_biderectional)
 
     std::this_thread::sleep_for(std::chrono::milliseconds(500));
 
-    CHECK(test_deserialize_calls == 2);
-    CHECK(reciveved_1_msg_cnt == 1);
-    CHECK(reciveved_2_msg_cnt == 1);
-    CHECK(test_send_calls == 2);
+    CHECK_EQUAL(2, test_deserialize_calls);
+    CHECK_EQUAL(1 ,received_1_msg_cnt);
+    CHECK_EQUAL(1, received_2_msg_cnt);
+    CHECK_EQUAL(2, test_send_calls);
 };
 
 TEST(unix_transport_test_messaging, send_to_transport_no_helpers)
@@ -488,16 +490,16 @@ TEST(unix_transport_test_messaging, send_to_transport_no_helpers)
     CHECK(res == false);
     CHECK(test_serialize_calls == 0);
     CHECK(test_deserialize_calls == 0);
-    CHECK(reciveved_1_msg_cnt == 0);
-    CHECK(reciveved_2_msg_cnt == 0);
+    CHECK(received_1_msg_cnt == 0);
+    CHECK(received_2_msg_cnt == 0);
     CHECK(test_send_calls == 0);
 
     res = test_transport_message_2->send_to_transport(event);
     CHECK(res == false);
     CHECK(test_serialize_calls == 0);
     CHECK(test_deserialize_calls == 0);
-    CHECK(reciveved_1_msg_cnt == 0);
-    CHECK(reciveved_2_msg_cnt == 0);
+    CHECK(received_1_msg_cnt == 0);
+    CHECK(received_2_msg_cnt == 0);
     CHECK(test_send_calls == 0);
 }
 
index 3fe5a0f80f7a822f48bf80b236c0f476efb21ffc..b0fb870492881213a27c342c26bc44682b7cdb7d 100644 (file)
@@ -300,9 +300,15 @@ bool SideChannel::discard_message(SCMessage* msg) const
 
 bool SideChannel::transmit_message(SCMessage* msg) const
 {
-    if ( !connector_transmit or !msg )
+    if(!msg)
         return false;
 
+    if ( !connector_transmit)
+    {
+        delete msg;
+        return false;
+    }
+
     if ( msg_format == ScMsgFormat::TEXT )
     {
         std::string text = sc_msg_data_to_text(msg->content, msg->content_length);