]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4692: mp_data_bus: core logic for mp databus
authorUmang Sharma (umasharm) <umasharm@cisco.com>
Sat, 26 Apr 2025 00:34:17 +0000 (00:34 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Sat, 26 Apr 2025 00:34:17 +0000 (00:34 +0000)
Merge in SNORT/snort3 from ~UMASHARM/snort3:mp_dbus to master

Squashed commit of the following:

commit 7fc8f62dac71aea14203346fe12d2d3bc9605f9c
Author: Umang Sharma <umasharm@cisco.com>
Date:   Thu Apr 24 15:29:53 2025 -0400

    mp_data_bus: core logic for mp databus

src/connectors/unixdomain_connector/unixdomain_connector.cc
src/framework/mp_data_bus.cc
src/framework/mp_data_bus.h
src/framework/test/CMakeLists.txt
src/framework/test/mp_data_bus_test.cc [new file with mode: 0644]
src/main/help.cc
src/main/modules.cc
src/main/snort.cc
src/main/snort_config.cc
src/mp_transport/mp_unix_transport/mp_unix_transport.cc
src/mp_transport/mp_unix_transport/test/unix_transport_test.cc

index f3e9b382b39b6bac5837357be980bddcfd51bed6..24a75779bf396280d14a72e20ef568da2f4c0475 100644 (file)
@@ -110,7 +110,7 @@ static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_
                 if (attempt_connection(sfd, path)) {
                     // Connection successful
                     UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
-                    LogMessage("UnixDomainC: Connected to %s", path);
+                    LogMessage("UnixDomainC: Connected to %s\n", path);
                     if(update_handler)
                     {
                         unixdomain_conn->set_update_handler(update_handler);
@@ -398,7 +398,7 @@ UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorCo
             return nullptr;
         }
     }
-    LogMessage("UnixDomainC: Connected to %s", path);
+    LogMessage("UnixDomainC: Connected to %s\n", path);
     UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
     unixdomain_conn->set_update_handler(update_handler);
     if(update_handler)
@@ -591,11 +591,15 @@ void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnecto
             {
                 ErrorMessage("UnixDomainC: Too many errors, stopping accept thread\n");
                 close(sock_fd);
+                sock_fd = -1;
+                should_accept = false;
                 return;
             }
             int peer_sfd = accept(sock_fd, nullptr, nullptr);
             if (peer_sfd == -1) 
             {
+                if (!should_accept)
+                    return;
                 error_count++;
                 ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno));
                 continue;
@@ -611,13 +615,16 @@ void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnecto
 
 void UnixDomainConnectorListener::stop_accepting_connections()
 {
-    if(should_accept)
+    should_accept = false;
+    if (sock_fd)
     {
-        should_accept = false;
+        shutdown(sock_fd, SHUT_RDWR);
         close(sock_fd);
-        if (accept_thread && accept_thread->joinable()) {
-            accept_thread->join();
-        }
+        sock_fd = -1;
+    }
+    
+    if (accept_thread && accept_thread->joinable()) {
+        accept_thread->join();
         delete accept_thread;
         accept_thread = nullptr;
     }
index 1263b2e864060e830158901269635982214c8e21..8e6477050e02a9f17d9fe4ad118e62113efc424b 100644 (file)
 #include "pub_sub/intrinsic_event_ids.h"
 #include "utils/stats.h"
 #include "main/snort_types.h"
-#include "managers/mp_transport_manager.h"
 #include "log/messages.h"
+#include "helpers/ring.h"
+#include "managers/mp_transport_manager.h"
 
 using namespace snort;
 
+std::condition_variable MPDataBus::queue_cv;
+std::mutex MPDataBus::queue_mutex;
+uint32_t MPDataBus::mp_max_eventq_size = DEFAULT_MAX_EVENTQ_SIZE;
+std::string MPDataBus::transport = DEFAULT_TRANSPORT;
+bool MPDataBus::enable_debug = false;
+MPTransport* MPDataBus::transport_layer = nullptr;
+
 static std::unordered_map<std::string, unsigned> mp_pub_ids;
 
 //--------------------------------------------------------------------------
 // public methods
 //-------------------------------------------------------------------------
 
-MPDataBus::MPDataBus() = default;
+MPDataBus::MPDataBus() : run_thread(true)
+{
+    mp_event_queue = new Ring<std::shared_ptr<MPEventInfo>>(mp_max_eventq_size);
+    start_worker_thread();
+}
 
 MPDataBus::~MPDataBus()
 {
-    // Clean up mp_pub_sub
-    for (auto& sublist : mp_pub_sub)
+    stop_worker_thread();
+
+    for (auto& [key, sublist] : mp_pub_sub)
     {
         for (auto* handler : sublist)
         {
             if (handler->cloned)
+            {
                 handler->cloned = false;
+            }
             else
-                delete handler;
+            {
+                delete handler; 
+            }
         }
         sublist.clear();
     }
     mp_pub_sub.clear();
+    delete mp_event_queue;
+    mp_event_queue = nullptr;
 }
 
 unsigned MPDataBus::init(int max_procs)
 {
-    UNUSED(max_procs);
+    if (max_procs <= 1)
+    {
+        return 1;
+    }
+
+    transport_layer = MPTransportManager::get_transport(transport);
+    if (transport_layer == nullptr)
+    {
+        ErrorMessage("MPDataBus: Failed to get transport layer\n");
+        return 0;
+    }
+
+    transport_layer->register_receive_handler(std::bind(&MPDataBus::receive_message, this, std::placeholders::_1));
+    transport_layer->init_connection();
+
     return 0;
 }
 
 void MPDataBus::clone(MPDataBus& from, const char* exclude_name)
-{ 
-    UNUSED(from);
-    UNUSED(exclude_name);
+{
+    from.stop_worker_thread();
+    for (const auto& [key, sublist] : from.mp_pub_sub)
+    {
+        unsigned pid = key.first; 
+        unsigned eid = key.second;
+
+        for (auto* h : sublist)
+        {
+            if (!exclude_name || strcmp(exclude_name, h->module_name) != 0)
+            {
+                h->cloned = true;
+                _subscribe(pid, eid, h);
+            }
+        }
+    }
+}
+
+unsigned MPDataBus::get_id(const PubKey& key)
+{
+    // Generate a unique hash for the publisher's name, 
+    std::hash<std::string> hasher;
+    unsigned unique_id = (hasher(key.publisher) % 10000);
+
+    auto it = mp_pub_ids.find(key.publisher);
+
+    if (it == mp_pub_ids.end())
+    {
+        // Map the unique hash to the publisher
+        mp_pub_ids[key.publisher] = unique_id;
+    }
+    // Return the unique ID for the publisher
+    return mp_pub_ids[key.publisher];
 }
 
-// module subscribes to an event from a peer snort process
 void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h)
 {
-    UNUSED(key);
-    UNUSED(eid);
-    UNUSED(h);
+    if(! SnortConfig::get_conf()->mp_dbus)
+    {
+        ErrorMessage("MPDataBus: MPDataBus not initialized\n");
+        return;
+    }
+
+    SnortConfig::get_conf()->mp_dbus->_subscribe(key, eid, h);
+    MP_DATABUS_LOG("MPDataBus: Subscribed to event ID %u\n", eid);
 }
 
-// publish event to all peer snort processes subscribed to the event
-bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, DataEvent& e, Flow* f) 
+bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr<DataEvent> e, Flow*)
 {
-    // Publish implementation
-    UNUSED(pub_id);
-    UNUSED(evt_id);
-    UNUSED(e);
-    UNUSED(f);
+    std::shared_ptr<MPEventInfo> event_info = 
+                std::make_shared<MPEventInfo>(std::move(e), MPEventType(evt_id), pub_id);
+
+    const SnortConfig *sc = SnortConfig::get_conf();
+
+    if (sc->mp_dbus == nullptr)
+    {
+        ErrorMessage("MPDataBus: MPDataBus not initialized\n");
+        return false;
+    }
+
+    if (sc->mp_dbus->mp_event_queue != nullptr and !sc->mp_dbus->mp_event_queue->full() and !sc->mp_dbus->mp_event_queue->put(event_info)) {
+        ErrorMessage("MPDataBus: Failed to enqueue event for publisher ID %u and event ID %u\n", pub_id, evt_id);
+        return false;
+    }
+
+    {
+        std::lock_guard<std::mutex> lock(queue_mutex);
+        queue_cv.notify_one();
+    }
+
+    MP_DATABUS_LOG("MPDataBus: Event published for publisher ID %u and event ID %u\n", pub_id, evt_id);
+
     return true;
 }
 
-// register event helpers for serialization and deserialization of msg events
-void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc* mp_serializer_helper, MPDeserializeFunc* mp_deserializer_helper)
+void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc& mp_serializer_helper, MPDeserializeFunc& mp_deserializer_helper)
 {
-    UNUSED(key);
-    UNUSED(evt_id);
-    UNUSED(mp_serializer_helper);
-    UNUSED(mp_deserializer_helper);
+    if (!SnortConfig::get_conf()->mp_dbus && !SnortConfig::get_conf()->mp_dbus->transport_layer)
+    {
+        ErrorMessage("MPDataBus: MPDataBus or transport layer not initialized\n");
+        return;
+    }
+
+    unsigned pub_id = get_id(key);
+
+    MPHelperFunctions helpers(mp_serializer_helper, mp_deserializer_helper);
+    
+    SnortConfig::get_conf()->mp_dbus->transport_layer->register_event_helpers(pub_id, evt_id, helpers);
+    MP_DATABUS_LOG("MPDataBus: Registered event helpers for event ID %u\n", evt_id);
 }
 
 // API for receiving the DataEvent and Event type from transport layer
 void MPDataBus::receive_message(const MPEventInfo& event_info)
 {
-    UNUSED(event_info);
+    DataEvent *e = event_info.event.get();
+    unsigned evt_id = event_info.type;
+    unsigned pub_id = event_info.pub_id;
+
+    MP_DATABUS_LOG("MPDataBus: Received message for publisher ID %u and event ID %u\n", pub_id, evt_id);
+
+    _publish(pub_id, evt_id, *e, nullptr);
 }
 
 
 //--------------------------------------------------------------------------
 // private methods
 //--------------------------------------------------------------------------
+void MPDataBus::process_event_queue()
+{
+    if (!mp_event_queue) {
+        return;
+    }
+
+    std::unique_lock<std::mutex> lock(queue_mutex);
+
+    queue_cv.wait_for(lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() {
+        return mp_event_queue != nullptr && !mp_event_queue->empty();
+    });
+
+    lock.unlock();
+
+    while (!mp_event_queue->empty()) {
+        std::shared_ptr<MPEventInfo> event_info = mp_event_queue->get(nullptr);
+        if (event_info) {
+            MP_DATABUS_LOG("MPDataBus: Processing event for publisher ID %u \n",
+                        event_info->pub_id);
+
+            transport_layer->send_to_transport(*event_info);
+        }
+    }
+}
+
+void MPDataBus::worker_thread_func()
+{
+    while (run_thread.load() ) {
+        process_event_queue();
+    }
+}
+
+void MPDataBus::start_worker_thread()
+{
+    run_thread.store(true);
+    worker_thread = std::make_unique<std::thread>(&MPDataBus::worker_thread_func, this);
+}
+
+void MPDataBus::stop_worker_thread()
+{
+    run_thread.store(false);
+    queue_cv.notify_one();
+
+    if (worker_thread && worker_thread->joinable())
+    {
+        worker_thread->join();
+    }
+
+    worker_thread.reset();
+}
+
+static bool compare(DataHandler* a, DataHandler* b)
+{
+    if ( a->order and b->order )
+        return a->order < b->order;
+
+    if ( a->order )
+        return true;
+
+    return false;
+}
 
 void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler* h)
 {
-    UNUSED(pid);
-    UNUSED(eid);
-    UNUSED(h);
+    std::pair<unsigned, unsigned> key = {pid, eid};
+
+    SubList& subs = mp_pub_sub[key];
+    subs.emplace_back(h);
+
+    std::sort(subs.begin(), subs.end(), compare);
 }
 
-void MPDataBus::_publish(unsigned int pid, unsigned int eid, DataEvent& e, Flow* f)
+void MPDataBus::_subscribe(const PubKey& key, unsigned eid, DataHandler* h)
 {
-    UNUSED(pid);
-    UNUSED(eid);
-    UNUSED(e);
-    UNUSED(f);
+    unsigned pid = get_id(key);
+    _subscribe(pid, eid, h);
+}
+
+
+void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f)
+{
+    std::pair<unsigned, unsigned> key = {pid, eid};
+
+    auto it = mp_pub_sub.find(key);
+    if (it == mp_pub_sub.end())
+    {
+        MP_DATABUS_LOG("MPDataBus: No subscribers for publisher ID %u and event ID %u\n", pid, eid);
+        return;
+    }
+    const SubList& subs = it->second;
+
+    for (auto* handler : subs)
+    {
+        handler->handle(e, f);
+    }
 }
 
index d69918c94b710e6050b53d01e7829c17f118a466..ee1ce340e25005f0fa18ff96810a5c22db9e5506 100644 (file)
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
+#include <mutex>
+#include <condition_variable>
 #include <queue>
+#include <atomic>
+#include <thread>
 
 #include "main/snort_types.h"
 #include "data_bus.h"
 #include "framework/mp_transport.h"
 #include <bitset>
+#include "framework/mp_transport.h"
+
+#define DEFAULT_TRANSPORT "unix_transport"
+#define DEFAULT_MAX_EVENTQ_SIZE 1000
+#define WORKER_THREAD_SLEEP 100
+
+#define MP_DATABUS_LOG(msg, ...) do { \
+        if (!MPDataBus::enable_debug) \
+            break; \
+        LogMessage(msg, __VA_ARGS__); \
+    } while (0)
+
+
+template <typename T>
+class Ring;
 
 namespace snort
 {
@@ -59,13 +78,13 @@ typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent
 // manner analogous to the approach used for intra-snort pub_sub.
 typedef unsigned MPEventType;
 
-struct MPEventInfo 
+struct MPEventInfo
 {
-    MPEventType type;
     unsigned pub_id;
-    DataEvent* event;
-    MPEventInfo(DataEvent* e, MPEventType t, unsigned id = 0)
-        : type(t), pub_id(id), event(e) {}
+    MPEventType type;
+    std::shared_ptr<DataEvent> event;
+    MPEventInfo(std::shared_ptr<DataEvent> e, MPEventType t, unsigned id = 0)
+        : pub_id(id), type(t), event(std::move(e)) {}
 };
 
 struct MPHelperFunctions {
@@ -76,39 +95,78 @@ struct MPHelperFunctions {
         : serializer(s), deserializer(d) {}
 };
 
+struct pair_hash
+{
+    template <class T1, class T2>
+    std::size_t operator()(const std::pair<T1, T2>& pair) const
+    {
+        std::hash<T1> hash1;
+        std::hash<T2> hash2;
+        return hash1(pair.first) ^ (hash2(pair.second) << 1);
+    }
+};
+
 class SO_PUBLIC MPDataBus
 { 
 public: 
-    MPDataBus(); 
+    MPDataBus();
     ~MPDataBus();
+    
+    static uint32_t mp_max_eventq_size;
+    static std::string transport;
+    static bool enable_debug;
 
-    static unsigned init(int);
+    static MPTransport * transport_layer;
+    unsigned init(int);
     void clone(MPDataBus& from, const char* exclude_name = nullptr);
 
-    unsigned get_id(const PubKey& key) 
-    { return DataBus::get_id(key); }
+    static unsigned get_id(const PubKey& key);
 
-    bool valid(unsigned pub_id)
+    static bool valid(unsigned pub_id)
     { return pub_id != 0; }
 
-    void subscribe(const PubKey& key, unsigned id, DataHandler* handler); 
+    static void subscribe(const PubKey& key, unsigned id, DataHandler* handler); 
 
-    bool publish(unsigned pub_id, unsigned evt_id, DataEvent& e, Flow* f = nullptr); 
+    // API for publishing the DataEvent to the peer Snort processes
+    // The user needs to pass a shared_ptr to the DataEvent object as the third argument
+    // This is to ensure that the DataEvent object is not deleted before it is published
+    // or consumed by the worker thread
+    // and the shared_ptr will handle the memory management by reference counting
+    static bool publish(unsigned pub_id, unsigned evt_id, std::shared_ptr<DataEvent> e, Flow* f = nullptr);
 
-    void register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc* mp_serializer_helper, MPDeserializeFunc* mp_deserializer_helper);
+    static void register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc& mp_serializer_helper, MPDeserializeFunc& mp_deserializer_helper);
 
     // API for receiving the DataEvent and Event type from transport layer using EventInfo
     void receive_message(const MPEventInfo& event_info);
 
+    Ring<std::shared_ptr<MPEventInfo>>* get_event_queue()
+    { return mp_event_queue; }
+
 private: 
     void _subscribe(unsigned pid, unsigned eid, DataHandler* h);
+    void _subscribe(const PubKey& key, unsigned eid, DataHandler* h);
+
     void _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f);
 
 private:
     typedef std::vector<DataHandler*> SubList;
-    std::vector<SubList> mp_pub_sub;
+
+    std::unordered_map<std::pair<unsigned, unsigned>, SubList, pair_hash> mp_pub_sub;
+
+    std::atomic<bool> run_thread;
+    std::unique_ptr<std::thread> worker_thread;
+
+    Ring<std::shared_ptr<MPEventInfo>>* mp_event_queue;
+
+    static std::condition_variable queue_cv;
+    static std::mutex queue_mutex;
+
+    void start_worker_thread();
+    void stop_worker_thread();
+    void worker_thread_func();
+    void process_event_queue();
+};
 };
-}
 
 #endif
 
index 0f554e0f6ad4edc6309b6c937990732a744c0438..5a16663a6752f8daeeac4fe34a99f41d87c1aa7c 100644 (file)
@@ -3,6 +3,10 @@ add_cpputest( data_bus_test
     SOURCES ../data_bus.cc
 )
 
+add_cpputest( mp_data_bus_test
+    SOURCES ../mp_data_bus.cc
+)
+
 # libapi_def.a is actually a text file with the preprocessed header source
 
 if ( ENABLE_UNIT_TESTS )
@@ -11,4 +15,4 @@ if ( ENABLE_UNIT_TESTS )
     install(TARGETS api_def)
 endif ()
 
-
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
diff --git a/src/framework/test/mp_data_bus_test.cc b/src/framework/test/mp_data_bus_test.cc
new file mode 100644 (file)
index 0000000..658d0f8
--- /dev/null
@@ -0,0 +1,445 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2019-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_data_bus_test.cc author Umang Sharma <umasharm@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+
+#include "../mp_data_bus.h"
+#include "../main/snort_config.h"
+#include "utils/stats.h"
+#include "helpers/ring.h"
+#include <condition_variable>
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+#include <CppUTestExt/MockSupport.h>
+
+#include <managers/mp_transport_manager.h>
+#include <framework/mp_transport.h>
+
+using namespace snort;
+
+namespace snort
+{
+static SnortConfig s_conf;
+
+THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
+
+void ErrorMessage(const char*, ...) { }
+void LogMessage(const char*, ...) { }
+
+const SnortConfig* SnortConfig::get_conf()
+{ return snort_conf; }
+
+SnortConfig::SnortConfig(const SnortConfig* const, const char*)
+: daq_config(nullptr), thread_config(nullptr)
+{ }
+
+SnortConfig::~SnortConfig()
+{ }
+}
+
+class MockMPTransport : public MPTransport
+{
+public:
+    MockMPTransport() = default;
+    ~MockMPTransport() override = default;
+
+    static int get_count()
+    {
+        return count;
+    }
+
+    static int get_test_register_helpers_calls()
+    {
+        return test_register_helpers_calls;
+    }
+
+    bool send_to_transport(MPEventInfo&) override
+    {
+        count++;
+        return true;
+    }
+
+    void register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&) override
+    {
+        test_register_helpers_calls++;
+        return;
+    }
+
+    void init_connection() override
+    {
+        return;
+    }
+
+    void register_receive_handler(const TransportReceiveEventHandler&) override
+    {
+        return;
+    }
+
+    void unregister_receive_handler() override
+    {
+        return;
+    }
+
+    void thread_init() override
+    {
+        return;
+    }
+
+    void thread_term() override
+    {
+        return;
+    }
+
+    bool configure(const SnortConfig*) override
+    {
+        return true;
+    }
+
+    void enable_logging() override
+    {
+        return;
+    }
+
+    void disable_logging() override
+    {
+        return;
+    }
+
+    bool is_logging_enabled() override
+    {
+        return true;
+    }
+private:
+    inline static int count = 0;
+    inline static int test_register_helpers_calls = 0;
+};
+
+static MockMPTransport mp_transport_pointer;
+
+MPTransport* MPTransportManager::get_transport(const std::string&)
+{
+    return &mp_transport_pointer;
+}
+
+class UTestEvent : public DataEvent
+{
+public:
+    UTestEvent(int m) : msg(m) { }
+
+    int get_message() const
+    { return msg; }
+
+private:
+    int msg;
+};
+
+bool serialize_mock(DataEvent*, char*& buffer, uint16_t* length)
+{
+    buffer = new char[9];
+    *length = 9;
+    memcpy(buffer, "test_data", 9);
+    return true;
+}
+
+bool deserialize_mock(const char*, uint16_t length, DataEvent*& event)
+{
+    event = new UTestEvent(length);
+    return true;
+}
+
+class UTestHandler1 : public DataHandler
+{
+public:
+    UTestHandler1(unsigned u = 0) : DataHandler("unit_test1")
+    { if (u) order = u; }
+
+    void handle(DataEvent& event, Flow*) override;
+
+    int evt_msg = 0;
+};
+
+class UTestHandler2 : public DataHandler
+{
+public:
+    UTestHandler2(unsigned u = 0) : DataHandler("unit_test2")
+    { if (u) order = u; }
+
+    void handle(DataEvent& event, Flow*) override;
+
+    int evt_msg = 1;
+};
+
+void UTestHandler1::handle(DataEvent& event, Flow*)
+{
+    UTestEvent* evt = static_cast<UTestEvent*>(&event);
+    if (evt)
+    {
+        evt_msg = evt->get_message();
+    }
+}
+
+void UTestHandler2::handle(DataEvent& event, Flow*)
+{
+    UTestEvent* evt = static_cast<UTestEvent*>(&event);
+    if (evt)
+    {
+        evt_msg = evt->get_message();
+    }
+}
+//--------------------------------------------------------------------------
+
+
+struct DbUtIds { enum : unsigned { EVENT1, EVENT2, num_ids }; };
+
+const PubKey pub_key1 { "mp_ut1", DbUtIds::num_ids };
+const PubKey pub_key2 { "mp_ut2", DbUtIds::num_ids };
+
+//--------------------------------------------------------------------------
+// Test Group
+//--------------------------------------------------------------------------
+
+TEST_GROUP(mp_data_bus_pub)
+{
+    unsigned pub_id1 = 0;  // cppcheck-suppress variableScope
+    MPDataBus* mp_dbus = nullptr;
+    void setup() override
+    {
+        mp_dbus = new MPDataBus();
+        mp_dbus->init(2);
+        pub_id1 = MPDataBus::get_id(pub_key1);
+        CHECK(MPDataBus::valid(pub_id1));
+
+        snort_conf->mp_dbus = mp_dbus;
+    }
+
+    void teardown() override
+    { }
+};
+
+TEST(mp_data_bus_pub, publish)
+{
+    CHECK_TRUE(mp_dbus->get_event_queue()->empty());
+    CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0);
+
+    std::shared_ptr<UTestEvent> event = std::make_shared<UTestEvent>(100);
+
+    mp_dbus->publish(pub_id1, DbUtIds::EVENT1, event);
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(150));
+
+    delete mp_dbus;
+
+    CHECK(1 == MockMPTransport::get_count());
+}
+
+TEST_GROUP(mp_data_bus)
+{
+    unsigned pub_id1 = 0, pub_id2 = 0;  // cppcheck-suppress variableScope
+
+    void setup() override
+    {
+        unsigned max_procs_val = 2;
+        snort_conf->mp_dbus = new MPDataBus();
+        snort_conf->mp_dbus->init(max_procs_val);
+        pub_id1 = MPDataBus::get_id(pub_key1);
+        pub_id2 = MPDataBus::get_id(pub_key2);
+        CHECK(MPDataBus::valid(pub_id1));
+        CHECK(MPDataBus::valid(pub_id2));
+    }
+
+    void teardown() override
+    {
+        delete snort_conf->mp_dbus;
+    }
+};
+
+//--------------------------------------------------------------------------
+// Tests
+//--------------------------------------------------------------------------
+
+TEST(mp_data_bus, init)
+{
+    CHECK(SnortConfig::get_conf()->mp_dbus != nullptr);  
+    CHECK(SnortConfig::get_conf()->mp_dbus->get_event_queue() != nullptr);
+}
+
+TEST(mp_data_bus, no_subscribers_and_receive)
+{
+    UTestHandler1* h1 = new UTestHandler1();
+
+    std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+
+    MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+    SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+    CHECK_EQUAL(0, h1->evt_msg);
+    delete h1;
+    h1 = nullptr;
+}
+
+TEST(mp_data_bus, register_event_helpers)
+{
+    MPSerializeFunc serialize_func = serialize_mock;
+    MPDeserializeFunc deserialize_func = deserialize_mock;
+    CHECK(0 == MockMPTransport::get_test_register_helpers_calls());
+
+    MPDataBus::register_event_helpers(pub_key1, DbUtIds::EVENT1, serialize_func, deserialize_func);
+    CHECK(1 == MockMPTransport::get_test_register_helpers_calls());
+
+    MPDataBus::register_event_helpers(pub_key1, DbUtIds::EVENT1, serialize_func, deserialize_func);
+    CHECK(2 == MockMPTransport::get_test_register_helpers_calls());
+}
+
+TEST(mp_data_bus, subscribe_and_receive)
+{
+    // one snort subscribes to it
+    UTestHandler1* h1 = new UTestHandler1();
+    MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+
+    // publish event from other snort
+    // since we don't have a way to publish events, we will use receive_message to simulate the event
+    // from a different snort
+    std::shared_ptr<UTestEvent> event = std::make_shared<UTestEvent>(100);
+
+    MPEventInfo event_info(event, MPEventType(DbUtIds::EVENT1), pub_id1);
+    SnortConfig::get_conf()->mp_dbus->receive_message(event_info);
+    
+    CHECK_EQUAL(100, h1->evt_msg);
+
+    std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(200);
+    
+    MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+    SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+    
+    CHECK_EQUAL(200, h1->evt_msg);
+}
+
+TEST(mp_data_bus, two_subscribers_diff_event_and_receive)
+{
+    UTestHandler1* h1 = new UTestHandler1();
+    UTestHandler2* h2 = new UTestHandler2();
+
+    MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+    MPDataBus::subscribe(pub_key2, DbUtIds::EVENT2, h2);
+
+    std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+
+    MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+    SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+    CHECK_EQUAL(100, h1->evt_msg);
+    CHECK_EQUAL(1, h2->evt_msg);
+
+    std::shared_ptr<UTestEvent> event2 = std::make_shared<UTestEvent>(200);
+
+    MPEventInfo event_info2(event2, MPEventType(DbUtIds::EVENT2), pub_id2);
+    SnortConfig::get_conf()->mp_dbus->receive_message(event_info2);
+
+    CHECK_EQUAL(100, h1->evt_msg);
+    CHECK_EQUAL(200, h2->evt_msg);
+}
+
+TEST(mp_data_bus, two_subscribers_same_event_and_receive)
+{
+    UTestHandler1* h1 = new UTestHandler1();
+    UTestHandler2* h2 = new UTestHandler2();
+
+    MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+    MPDataBus::subscribe(pub_key2, DbUtIds::EVENT1, h2);
+
+    std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+
+    MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+    SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+    CHECK_EQUAL(100, h1->evt_msg);
+    CHECK_EQUAL(1, h2->evt_msg);
+}
+
+TEST_GROUP(mp_data_bus_clone)
+{
+    unsigned pub_id1 = 0, pub_id2 = 0;  // cppcheck-suppress variableScope
+    void setup() override
+    {
+        unsigned max_procs_val = 2;
+        snort_conf->mp_dbus = new MPDataBus();
+        snort_conf->mp_dbus->init(max_procs_val);
+        pub_id1 = MPDataBus::get_id(pub_key1);
+        pub_id2 = MPDataBus::get_id(pub_key2);
+        CHECK(MPDataBus::valid(pub_id1));
+        CHECK(MPDataBus::valid(pub_id2));
+    }
+
+    void teardown() override
+    {
+        delete snort_conf->mp_dbus;
+    }
+};
+//-------------------------------------------------------------------------
+
+TEST(mp_data_bus_clone, z_clone)
+{
+    unsigned pub_id1, pub_id2;
+    pub_id1 = MPDataBus::get_id(pub_key1);
+    pub_id2 = MPDataBus::get_id(pub_key2);
+    // subscribing to the events in the original mp_data_bus
+    // and then cloning the mp_data_bus
+    // and checking if the events are received in the cloned mp_data_bus
+    // and not in the original mp_data_bus
+    UTestHandler1* h1 = new UTestHandler1();
+    MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+
+    UTestHandler2* h2 = new UTestHandler2();
+    MPDataBus::subscribe(pub_key2, DbUtIds::EVENT2, h2);
+
+    // original mp_data_bus should be deleted with previous SnortConfig
+    // deleted with exit handlers of Test framework
+    MPDataBus* mp_data_bus_cloned = new MPDataBus();
+    mp_data_bus_cloned->clone(*SnortConfig::get_conf()->mp_dbus, nullptr);
+
+    std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+    MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+
+    mp_data_bus_cloned->receive_message(event_info1);
+
+    CHECK_EQUAL(100, h1->evt_msg);
+    CHECK_EQUAL(1, h2->evt_msg);
+
+    std::shared_ptr<UTestEvent> event2 = std::make_shared<UTestEvent>(200);
+
+    MPEventInfo event_info2(event2, MPEventType(DbUtIds::EVENT2), pub_id2);
+    mp_data_bus_cloned->receive_message(event_info2);
+    CHECK_EQUAL(100, h1->evt_msg);
+    CHECK_EQUAL(200, h2->evt_msg);
+}
+
+//-------------------------------------------------------------------------
+// main
+//-------------------------------------------------------------------------
+
+int main(int argc, char** argv)
+{
+    // event_map is not released until after cpputest gives up
+    MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
+    return CommandLineTestRunner::RunAllTests(argc, argv);
+}
\ No newline at end of file
index 8d6f7eb42dd26c00ac0c875726957b6cf8747a1c..02597c88afcf0140ee539ae0ed8f04519432c9f4 100644 (file)
@@ -209,10 +209,10 @@ enum HelpType
         PluginManager::list_plugins();
         break;
     }
+    MPTransportManager::term();
     ModuleManager::term();
     PluginManager::release_plugins();
     ScriptManager::release_scripts();
-    MPTransportManager::term();
     delete SnortConfig::get_conf();
     exit(0);
 }
index 9106c3238ce8e285dc8bb4247d2120eaf4a3f53b..dbaa053fdd09bceb5ef25f5b8e7d1bd75d93eb16 100644 (file)
@@ -392,6 +392,102 @@ bool ClassificationsModule::set(const char*, Value& v, SnortConfig*)
     return true;
 }
 
+//-------------------------------------------------------------------------
+// multiprocess data bus module
+//-------------------------------------------------------------------------
+
+static const Parameter mp_data_bus_params[] =
+{
+    { "max_eventq_size", Parameter::PT_INT, "100:65535", "1000",
+      "maximum events to queue" },
+
+    { "transport", Parameter::PT_STRING, nullptr, nullptr,
+      "transport to use for inter-process communication" },
+
+    { "debug", Parameter::PT_BOOL, nullptr, "false",
+      "enable debugging" },
+
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+static int enable_debug(lua_State*)
+{
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->enable_debug = true;
+
+    return 0;
+}
+
+static int disable_debug(lua_State*)
+{
+    if(SnortConfig::get_conf()->mp_dbus)
+        SnortConfig::get_conf()->mp_dbus->enable_debug = false;
+
+    return 0;
+}
+
+static const Command mp_dbus_cmds[] =
+{
+    {"enable", enable_debug, nullptr, "enable multiprocess data bus debugging"},
+    {"disable", disable_debug, nullptr, "disable multiprocess data bus debugging"},
+    {nullptr, nullptr, nullptr, nullptr}
+};
+
+#define mp_data_bus_help \
+    "configure multiprocess data bus"
+
+class MPDataBusModule : public Module
+{
+public:
+    MPDataBusModule() :
+        Module("mp_data_bus", mp_data_bus_help, mp_data_bus_params) { }
+
+    bool set(const char*, Value&, SnortConfig*) override;
+    bool begin(const char*, int, SnortConfig*) override;
+    bool end(const char*, int, SnortConfig*) override;
+    const Command* get_commands() const override;
+
+    Usage get_usage() const override
+    { return GLOBAL; }
+};
+
+bool MPDataBusModule::begin(const char*, int, SnortConfig*)
+{
+    return true;
+}
+
+bool MPDataBusModule::end(const char*, int, SnortConfig*)
+{
+    return true;
+}
+
+bool MPDataBusModule::set(const char*, Value& v, SnortConfig*)
+{
+    if ( v.is("max_eventq_size") )
+    {
+        MPDataBus::mp_max_eventq_size = v.get_uint32();
+    }
+    else if ( v.is("transport") )
+    {
+        MPDataBus::transport = v.get_string();
+    }
+    else if ( v.is("debug") )
+    {
+        MPDataBus::enable_debug = v.get_bool();
+    }
+    else 
+    {
+        WarningMessage("MPDataBus: Unknown parameter '%s' in mp_data_bus module\n", v.get_name());
+        return false;
+    }
+    return true;
+}
+
+const Command* MPDataBusModule::get_commands() const
+{
+    return mp_dbus_cmds;
+}
+
 //-------------------------------------------------------------------------
 // reference module
 //-------------------------------------------------------------------------
@@ -1969,6 +2065,7 @@ void module_init()
     ModuleManager::add_module(new CodecModule);
     ModuleManager::add_module(new DetectionModule);
     ModuleManager::add_module(new MemoryModule);
+    ModuleManager::add_module(new MPDataBusModule);
     ModuleManager::add_module(new PacketTracerModule);
     ModuleManager::add_module(new PacketsModule);
     ModuleManager::add_module(new ProcessModule);
index 2f889ddc21a23154f11e300205c3df6be81d9504..b735a25b607e9bf15782646fe13cee4d967559a5 100644 (file)
@@ -171,6 +171,11 @@ void Snort::init(int argc, char** argv)
     // This call must be immediately after "SnortConfig::set_conf(sc)"
     // since the first trace call may happen somewhere after this point
     TraceApi::thread_init(sc->trace_config);
+    if (sc->max_procs > 1)
+    {
+        sc->mp_dbus = new MPDataBus();
+        sc->mp_dbus->init(sc->max_procs);
+    }
 
     PluginManager::load_so_plugins(sc);
 
@@ -330,6 +335,7 @@ void Snort::term()
 
     const SnortConfig* sc = SnortConfig::get_conf();
 
+    MPTransportManager::term();
     IpsManager::global_term(sc);
     HostAttributesManager::term();
 
@@ -374,7 +380,6 @@ void Snort::term()
     host_cache.term();
     PluginManager::release_plugins();
     ScriptManager::release_scripts();
-    MPTransportManager::term();
     memory::MemoryCap::term();
     detection_filter_term();
 
@@ -575,7 +580,8 @@ SnortConfig* Snort::get_updated_policy(
 
     SnortConfig* sc = new SnortConfig(other_conf, iname);
     sc->global_dbus->clone(*other_conf->global_dbus, iname);
-    if (sc->max_procs > 1)
+
+    if (other_conf->mp_dbus != nullptr)
         sc->mp_dbus->clone(*other_conf->mp_dbus, iname);
 
     if ( fname )
index 27a2d5ec839bdebf6e0bdbf2e3463ec689eb73f6..5db70d17bbef0fa4111dcfa301ec192b9591145d 100644 (file)
@@ -199,8 +199,6 @@ void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* p
         policy_map = new PolicyMap;
         thread_config = new ThreadConfig();
         global_dbus = new DataBus();
-        if (max_procs > 1)
-            mp_dbus = new MPDataBus();
 
         proto_ref = new ProtocolReference(protocol_reference);
         so_rules = new SoRules;
@@ -264,7 +262,7 @@ SnortConfig::~SnortConfig()
     if ( cloned )
     {
         delete global_dbus;
-        if (max_procs > 1)
+        if (mp_dbus)
             delete mp_dbus;
         policy_map->set_cloned(true);
         delete policy_map;
@@ -325,7 +323,7 @@ SnortConfig::~SnortConfig()
     delete overlay_trace_config;
     delete ha_config;
     delete global_dbus;
-    if (max_procs > 1)
+    if (mp_dbus)
         delete mp_dbus;
 
     delete profiler;
index 6b15acad1dad606d0029565a74f76043a5a290b4..8e40562f71b8a982faafd2ff444af524f515720b 100644 (file)
@@ -101,11 +101,10 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
 
         DataEvent* internal_event = nullptr;
         (deserialize_func)((const char*)(msg->content + sizeof(MPTransportMessageHeader)), transport_message_header->data_length, internal_event);
-        MPEventInfo event(internal_event, transport_message_header->event_id, transport_message_header->pub_id);
+        MPEventInfo event(std::shared_ptr<DataEvent> (internal_event), transport_message_header->event_id, transport_message_header->pub_id);
 
         (transport_receive_handler)(event);
 
-        delete internal_event;
     }
     delete msg;
 }
@@ -151,9 +150,9 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
     transport_message.header.type = EVENT_MESSAGE;
     transport_message.header.pub_id = event.pub_id;
     transport_message.header.event_id = event.type;
-
     
-    (serialize_func)(event.event, transport_message.data, &transport_message.header.data_length);
+
+    (serialize_func)(event.event.get(), transport_message.data, &transport_message.header.data_length);
     for (auto &&sc_handler : this->side_channels)
     {
         auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length);
index b3a9f5b9685b55c46591ef3d7aadd3b455e1adc7..84cf8497999dd4a73f2d43feff2af44280c58c82 100644 (file)
@@ -454,15 +454,16 @@ TEST(unix_transport_test_messaging, send_to_transport_biderectional)
     test_transport_message_1->register_event_helpers(0, 0, mp_helper_functions_mock);
     test_transport_message_2->register_event_helpers(0, 0, mp_helper_functions_mock);
 
-    MPEventInfo event(&test_event, 0, 0);
+    std::shared_ptr<TestDataEvent> event = std::make_shared<TestDataEvent>();
 
-    auto res = test_transport_message_1->send_to_transport(event);
+    MPEventInfo event_info(event, 0, 0);
+    auto res = test_transport_message_1->send_to_transport(event_info);
     
 
     CHECK(res == true);
     CHECK(test_serialize_calls == 1);
 
-    res = test_transport_message_2->send_to_transport(event);
+    res = test_transport_message_2->send_to_transport(event_info);
     
     CHECK(res == true);
     CHECK(test_serialize_calls == 2);
@@ -479,7 +480,9 @@ TEST(unix_transport_test_messaging, send_to_transport_no_helpers)
 {
     clear_test_calls();
 
-    MPEventInfo event(&test_event, 0, 0);
+    std::shared_ptr<TestDataEvent> event_in = std::make_shared<TestDataEvent>();
+
+    MPEventInfo event(event_in, 0, 0);
 
     auto res = test_transport_message_1->send_to_transport(event);
     CHECK(res == false);