if (attempt_connection(sfd, path)) {
// Connection successful
UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
- LogMessage("UnixDomainC: Connected to %s", path);
+ LogMessage("UnixDomainC: Connected to %s\n", path);
if(update_handler)
{
unixdomain_conn->set_update_handler(update_handler);
return nullptr;
}
}
- LogMessage("UnixDomainC: Connected to %s", path);
+ LogMessage("UnixDomainC: Connected to %s\n", path);
UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
unixdomain_conn->set_update_handler(update_handler);
if(update_handler)
{
ErrorMessage("UnixDomainC: Too many errors, stopping accept thread\n");
close(sock_fd);
+ sock_fd = -1;
+ should_accept = false;
return;
}
int peer_sfd = accept(sock_fd, nullptr, nullptr);
if (peer_sfd == -1)
{
+ if (!should_accept)
+ return;
error_count++;
ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno));
continue;
void UnixDomainConnectorListener::stop_accepting_connections()
{
- if(should_accept)
+ should_accept = false;
+ if (sock_fd)
{
- should_accept = false;
+ shutdown(sock_fd, SHUT_RDWR);
close(sock_fd);
- if (accept_thread && accept_thread->joinable()) {
- accept_thread->join();
- }
+ sock_fd = -1;
+ }
+
+ if (accept_thread && accept_thread->joinable()) {
+ accept_thread->join();
delete accept_thread;
accept_thread = nullptr;
}
#include "pub_sub/intrinsic_event_ids.h"
#include "utils/stats.h"
#include "main/snort_types.h"
-#include "managers/mp_transport_manager.h"
#include "log/messages.h"
+#include "helpers/ring.h"
+#include "managers/mp_transport_manager.h"
using namespace snort;
+std::condition_variable MPDataBus::queue_cv;
+std::mutex MPDataBus::queue_mutex;
+uint32_t MPDataBus::mp_max_eventq_size = DEFAULT_MAX_EVENTQ_SIZE;
+std::string MPDataBus::transport = DEFAULT_TRANSPORT;
+bool MPDataBus::enable_debug = false;
+MPTransport* MPDataBus::transport_layer = nullptr;
+
static std::unordered_map<std::string, unsigned> mp_pub_ids;
//--------------------------------------------------------------------------
// public methods
//-------------------------------------------------------------------------
-MPDataBus::MPDataBus() = default;
+MPDataBus::MPDataBus() : run_thread(true)
+{
+ mp_event_queue = new Ring<std::shared_ptr<MPEventInfo>>(mp_max_eventq_size);
+ start_worker_thread();
+}
MPDataBus::~MPDataBus()
{
- // Clean up mp_pub_sub
- for (auto& sublist : mp_pub_sub)
+ stop_worker_thread();
+
+ for (auto& [key, sublist] : mp_pub_sub)
{
for (auto* handler : sublist)
{
if (handler->cloned)
+ {
handler->cloned = false;
+ }
else
- delete handler;
+ {
+ delete handler;
+ }
}
sublist.clear();
}
mp_pub_sub.clear();
+ delete mp_event_queue;
+ mp_event_queue = nullptr;
}
unsigned MPDataBus::init(int max_procs)
{
- UNUSED(max_procs);
+ if (max_procs <= 1)
+ {
+ return 1;
+ }
+
+ transport_layer = MPTransportManager::get_transport(transport);
+ if (transport_layer == nullptr)
+ {
+ ErrorMessage("MPDataBus: Failed to get transport layer\n");
+ return 0;
+ }
+
+ transport_layer->register_receive_handler(std::bind(&MPDataBus::receive_message, this, std::placeholders::_1));
+ transport_layer->init_connection();
+
return 0;
}
void MPDataBus::clone(MPDataBus& from, const char* exclude_name)
-{
- UNUSED(from);
- UNUSED(exclude_name);
+{
+ from.stop_worker_thread();
+ for (const auto& [key, sublist] : from.mp_pub_sub)
+ {
+ unsigned pid = key.first;
+ unsigned eid = key.second;
+
+ for (auto* h : sublist)
+ {
+ if (!exclude_name || strcmp(exclude_name, h->module_name) != 0)
+ {
+ h->cloned = true;
+ _subscribe(pid, eid, h);
+ }
+ }
+ }
+}
+
+unsigned MPDataBus::get_id(const PubKey& key)
+{
+ // Generate a unique hash for the publisher's name,
+ std::hash<std::string> hasher;
+ unsigned unique_id = (hasher(key.publisher) % 10000);
+
+ auto it = mp_pub_ids.find(key.publisher);
+
+ if (it == mp_pub_ids.end())
+ {
+ // Map the unique hash to the publisher
+ mp_pub_ids[key.publisher] = unique_id;
+ }
+ // Return the unique ID for the publisher
+ return mp_pub_ids[key.publisher];
}
-// module subscribes to an event from a peer snort process
void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h)
{
- UNUSED(key);
- UNUSED(eid);
- UNUSED(h);
+ if(! SnortConfig::get_conf()->mp_dbus)
+ {
+ ErrorMessage("MPDataBus: MPDataBus not initialized\n");
+ return;
+ }
+
+ SnortConfig::get_conf()->mp_dbus->_subscribe(key, eid, h);
+ MP_DATABUS_LOG("MPDataBus: Subscribed to event ID %u\n", eid);
}
-// publish event to all peer snort processes subscribed to the event
-bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, DataEvent& e, Flow* f)
+bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr<DataEvent> e, Flow*)
{
- // Publish implementation
- UNUSED(pub_id);
- UNUSED(evt_id);
- UNUSED(e);
- UNUSED(f);
+ std::shared_ptr<MPEventInfo> event_info =
+ std::make_shared<MPEventInfo>(std::move(e), MPEventType(evt_id), pub_id);
+
+ const SnortConfig *sc = SnortConfig::get_conf();
+
+ if (sc->mp_dbus == nullptr)
+ {
+ ErrorMessage("MPDataBus: MPDataBus not initialized\n");
+ return false;
+ }
+
+ if (sc->mp_dbus->mp_event_queue != nullptr and !sc->mp_dbus->mp_event_queue->full() and !sc->mp_dbus->mp_event_queue->put(event_info)) {
+ ErrorMessage("MPDataBus: Failed to enqueue event for publisher ID %u and event ID %u\n", pub_id, evt_id);
+ return false;
+ }
+
+ {
+ std::lock_guard<std::mutex> lock(queue_mutex);
+ queue_cv.notify_one();
+ }
+
+ MP_DATABUS_LOG("MPDataBus: Event published for publisher ID %u and event ID %u\n", pub_id, evt_id);
+
return true;
}
-// register event helpers for serialization and deserialization of msg events
-void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc* mp_serializer_helper, MPDeserializeFunc* mp_deserializer_helper)
+void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc& mp_serializer_helper, MPDeserializeFunc& mp_deserializer_helper)
{
- UNUSED(key);
- UNUSED(evt_id);
- UNUSED(mp_serializer_helper);
- UNUSED(mp_deserializer_helper);
+ if (!SnortConfig::get_conf()->mp_dbus && !SnortConfig::get_conf()->mp_dbus->transport_layer)
+ {
+ ErrorMessage("MPDataBus: MPDataBus or transport layer not initialized\n");
+ return;
+ }
+
+ unsigned pub_id = get_id(key);
+
+ MPHelperFunctions helpers(mp_serializer_helper, mp_deserializer_helper);
+
+ SnortConfig::get_conf()->mp_dbus->transport_layer->register_event_helpers(pub_id, evt_id, helpers);
+ MP_DATABUS_LOG("MPDataBus: Registered event helpers for event ID %u\n", evt_id);
}
// API for receiving the DataEvent and Event type from transport layer
void MPDataBus::receive_message(const MPEventInfo& event_info)
{
- UNUSED(event_info);
+ DataEvent *e = event_info.event.get();
+ unsigned evt_id = event_info.type;
+ unsigned pub_id = event_info.pub_id;
+
+ MP_DATABUS_LOG("MPDataBus: Received message for publisher ID %u and event ID %u\n", pub_id, evt_id);
+
+ _publish(pub_id, evt_id, *e, nullptr);
}
//--------------------------------------------------------------------------
// private methods
//--------------------------------------------------------------------------
+void MPDataBus::process_event_queue()
+{
+ if (!mp_event_queue) {
+ return;
+ }
+
+ std::unique_lock<std::mutex> lock(queue_mutex);
+
+ queue_cv.wait_for(lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() {
+ return mp_event_queue != nullptr && !mp_event_queue->empty();
+ });
+
+ lock.unlock();
+
+ while (!mp_event_queue->empty()) {
+ std::shared_ptr<MPEventInfo> event_info = mp_event_queue->get(nullptr);
+ if (event_info) {
+ MP_DATABUS_LOG("MPDataBus: Processing event for publisher ID %u \n",
+ event_info->pub_id);
+
+ transport_layer->send_to_transport(*event_info);
+ }
+ }
+}
+
+void MPDataBus::worker_thread_func()
+{
+ while (run_thread.load() ) {
+ process_event_queue();
+ }
+}
+
+void MPDataBus::start_worker_thread()
+{
+ run_thread.store(true);
+ worker_thread = std::make_unique<std::thread>(&MPDataBus::worker_thread_func, this);
+}
+
+void MPDataBus::stop_worker_thread()
+{
+ run_thread.store(false);
+ queue_cv.notify_one();
+
+ if (worker_thread && worker_thread->joinable())
+ {
+ worker_thread->join();
+ }
+
+ worker_thread.reset();
+}
+
+static bool compare(DataHandler* a, DataHandler* b)
+{
+ if ( a->order and b->order )
+ return a->order < b->order;
+
+ if ( a->order )
+ return true;
+
+ return false;
+}
void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler* h)
{
- UNUSED(pid);
- UNUSED(eid);
- UNUSED(h);
+ std::pair<unsigned, unsigned> key = {pid, eid};
+
+ SubList& subs = mp_pub_sub[key];
+ subs.emplace_back(h);
+
+ std::sort(subs.begin(), subs.end(), compare);
}
-void MPDataBus::_publish(unsigned int pid, unsigned int eid, DataEvent& e, Flow* f)
+void MPDataBus::_subscribe(const PubKey& key, unsigned eid, DataHandler* h)
{
- UNUSED(pid);
- UNUSED(eid);
- UNUSED(e);
- UNUSED(f);
+ unsigned pid = get_id(key);
+ _subscribe(pid, eid, h);
+}
+
+
+void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f)
+{
+ std::pair<unsigned, unsigned> key = {pid, eid};
+
+ auto it = mp_pub_sub.find(key);
+ if (it == mp_pub_sub.end())
+ {
+ MP_DATABUS_LOG("MPDataBus: No subscribers for publisher ID %u and event ID %u\n", pid, eid);
+ return;
+ }
+ const SubList& subs = it->second;
+
+ for (auto* handler : subs)
+ {
+ handler->handle(e, f);
+ }
}
#include <unordered_map>
#include <unordered_set>
#include <vector>
+#include <mutex>
+#include <condition_variable>
#include <queue>
+#include <atomic>
+#include <thread>
#include "main/snort_types.h"
#include "data_bus.h"
#include "framework/mp_transport.h"
#include <bitset>
+#include "framework/mp_transport.h"
+
+#define DEFAULT_TRANSPORT "unix_transport"
+#define DEFAULT_MAX_EVENTQ_SIZE 1000
+#define WORKER_THREAD_SLEEP 100
+
+#define MP_DATABUS_LOG(msg, ...) do { \
+ if (!MPDataBus::enable_debug) \
+ break; \
+ LogMessage(msg, __VA_ARGS__); \
+ } while (0)
+
+
+template <typename T>
+class Ring;
namespace snort
{
// manner analogous to the approach used for intra-snort pub_sub.
typedef unsigned MPEventType;
-struct MPEventInfo
+struct MPEventInfo
{
- MPEventType type;
unsigned pub_id;
- DataEvent* event;
- MPEventInfo(DataEvent* e, MPEventType t, unsigned id = 0)
- : type(t), pub_id(id), event(e) {}
+ MPEventType type;
+ std::shared_ptr<DataEvent> event;
+ MPEventInfo(std::shared_ptr<DataEvent> e, MPEventType t, unsigned id = 0)
+ : pub_id(id), type(t), event(std::move(e)) {}
};
struct MPHelperFunctions {
: serializer(s), deserializer(d) {}
};
+struct pair_hash
+{
+ template <class T1, class T2>
+ std::size_t operator()(const std::pair<T1, T2>& pair) const
+ {
+ std::hash<T1> hash1;
+ std::hash<T2> hash2;
+ return hash1(pair.first) ^ (hash2(pair.second) << 1);
+ }
+};
+
class SO_PUBLIC MPDataBus
{
public:
- MPDataBus();
+ MPDataBus();
~MPDataBus();
+
+ static uint32_t mp_max_eventq_size;
+ static std::string transport;
+ static bool enable_debug;
- static unsigned init(int);
+ static MPTransport * transport_layer;
+ unsigned init(int);
void clone(MPDataBus& from, const char* exclude_name = nullptr);
- unsigned get_id(const PubKey& key)
- { return DataBus::get_id(key); }
+ static unsigned get_id(const PubKey& key);
- bool valid(unsigned pub_id)
+ static bool valid(unsigned pub_id)
{ return pub_id != 0; }
- void subscribe(const PubKey& key, unsigned id, DataHandler* handler);
+ static void subscribe(const PubKey& key, unsigned id, DataHandler* handler);
- bool publish(unsigned pub_id, unsigned evt_id, DataEvent& e, Flow* f = nullptr);
+ // API for publishing the DataEvent to the peer Snort processes
+ // The user needs to pass a shared_ptr to the DataEvent object as the third argument
+ // This is to ensure that the DataEvent object is not deleted before it is published
+ // or consumed by the worker thread
+ // and the shared_ptr will handle the memory management by reference counting
+ static bool publish(unsigned pub_id, unsigned evt_id, std::shared_ptr<DataEvent> e, Flow* f = nullptr);
- void register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc* mp_serializer_helper, MPDeserializeFunc* mp_deserializer_helper);
+ static void register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc& mp_serializer_helper, MPDeserializeFunc& mp_deserializer_helper);
// API for receiving the DataEvent and Event type from transport layer using EventInfo
void receive_message(const MPEventInfo& event_info);
+ Ring<std::shared_ptr<MPEventInfo>>* get_event_queue()
+ { return mp_event_queue; }
+
private:
void _subscribe(unsigned pid, unsigned eid, DataHandler* h);
+ void _subscribe(const PubKey& key, unsigned eid, DataHandler* h);
+
void _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f);
private:
typedef std::vector<DataHandler*> SubList;
- std::vector<SubList> mp_pub_sub;
+
+ std::unordered_map<std::pair<unsigned, unsigned>, SubList, pair_hash> mp_pub_sub;
+
+ std::atomic<bool> run_thread;
+ std::unique_ptr<std::thread> worker_thread;
+
+ Ring<std::shared_ptr<MPEventInfo>>* mp_event_queue;
+
+ static std::condition_variable queue_cv;
+ static std::mutex queue_mutex;
+
+ void start_worker_thread();
+ void stop_worker_thread();
+ void worker_thread_func();
+ void process_event_queue();
+};
};
-}
#endif
SOURCES ../data_bus.cc
)
+add_cpputest( mp_data_bus_test
+ SOURCES ../mp_data_bus.cc
+)
+
# libapi_def.a is actually a text file with the preprocessed header source
if ( ENABLE_UNIT_TESTS )
install(TARGETS api_def)
endif ()
-
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread")
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2019-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_data_bus_test.cc author Umang Sharma <umasharm@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+
+#include "../mp_data_bus.h"
+#include "../main/snort_config.h"
+#include "utils/stats.h"
+#include "helpers/ring.h"
+#include <condition_variable>
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+#include <CppUTestExt/MockSupport.h>
+
+#include <managers/mp_transport_manager.h>
+#include <framework/mp_transport.h>
+
+using namespace snort;
+
+namespace snort
+{
+static SnortConfig s_conf;
+
+THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
+
+void ErrorMessage(const char*, ...) { }
+void LogMessage(const char*, ...) { }
+
+const SnortConfig* SnortConfig::get_conf()
+{ return snort_conf; }
+
+SnortConfig::SnortConfig(const SnortConfig* const, const char*)
+: daq_config(nullptr), thread_config(nullptr)
+{ }
+
+SnortConfig::~SnortConfig()
+{ }
+}
+
+class MockMPTransport : public MPTransport
+{
+public:
+ MockMPTransport() = default;
+ ~MockMPTransport() override = default;
+
+ static int get_count()
+ {
+ return count;
+ }
+
+ static int get_test_register_helpers_calls()
+ {
+ return test_register_helpers_calls;
+ }
+
+ bool send_to_transport(MPEventInfo&) override
+ {
+ count++;
+ return true;
+ }
+
+ void register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&) override
+ {
+ test_register_helpers_calls++;
+ return;
+ }
+
+ void init_connection() override
+ {
+ return;
+ }
+
+ void register_receive_handler(const TransportReceiveEventHandler&) override
+ {
+ return;
+ }
+
+ void unregister_receive_handler() override
+ {
+ return;
+ }
+
+ void thread_init() override
+ {
+ return;
+ }
+
+ void thread_term() override
+ {
+ return;
+ }
+
+ bool configure(const SnortConfig*) override
+ {
+ return true;
+ }
+
+ void enable_logging() override
+ {
+ return;
+ }
+
+ void disable_logging() override
+ {
+ return;
+ }
+
+ bool is_logging_enabled() override
+ {
+ return true;
+ }
+private:
+ inline static int count = 0;
+ inline static int test_register_helpers_calls = 0;
+};
+
+static MockMPTransport mp_transport_pointer;
+
+MPTransport* MPTransportManager::get_transport(const std::string&)
+{
+ return &mp_transport_pointer;
+}
+
+class UTestEvent : public DataEvent
+{
+public:
+ UTestEvent(int m) : msg(m) { }
+
+ int get_message() const
+ { return msg; }
+
+private:
+ int msg;
+};
+
+bool serialize_mock(DataEvent*, char*& buffer, uint16_t* length)
+{
+ buffer = new char[9];
+ *length = 9;
+ memcpy(buffer, "test_data", 9);
+ return true;
+}
+
+bool deserialize_mock(const char*, uint16_t length, DataEvent*& event)
+{
+ event = new UTestEvent(length);
+ return true;
+}
+
+class UTestHandler1 : public DataHandler
+{
+public:
+ UTestHandler1(unsigned u = 0) : DataHandler("unit_test1")
+ { if (u) order = u; }
+
+ void handle(DataEvent& event, Flow*) override;
+
+ int evt_msg = 0;
+};
+
+class UTestHandler2 : public DataHandler
+{
+public:
+ UTestHandler2(unsigned u = 0) : DataHandler("unit_test2")
+ { if (u) order = u; }
+
+ void handle(DataEvent& event, Flow*) override;
+
+ int evt_msg = 1;
+};
+
+void UTestHandler1::handle(DataEvent& event, Flow*)
+{
+ UTestEvent* evt = static_cast<UTestEvent*>(&event);
+ if (evt)
+ {
+ evt_msg = evt->get_message();
+ }
+}
+
+void UTestHandler2::handle(DataEvent& event, Flow*)
+{
+ UTestEvent* evt = static_cast<UTestEvent*>(&event);
+ if (evt)
+ {
+ evt_msg = evt->get_message();
+ }
+}
+//--------------------------------------------------------------------------
+
+
+struct DbUtIds { enum : unsigned { EVENT1, EVENT2, num_ids }; };
+
+const PubKey pub_key1 { "mp_ut1", DbUtIds::num_ids };
+const PubKey pub_key2 { "mp_ut2", DbUtIds::num_ids };
+
+//--------------------------------------------------------------------------
+// Test Group
+//--------------------------------------------------------------------------
+
+TEST_GROUP(mp_data_bus_pub)
+{
+ unsigned pub_id1 = 0; // cppcheck-suppress variableScope
+ MPDataBus* mp_dbus = nullptr;
+ void setup() override
+ {
+ mp_dbus = new MPDataBus();
+ mp_dbus->init(2);
+ pub_id1 = MPDataBus::get_id(pub_key1);
+ CHECK(MPDataBus::valid(pub_id1));
+
+ snort_conf->mp_dbus = mp_dbus;
+ }
+
+ void teardown() override
+ { }
+};
+
+TEST(mp_data_bus_pub, publish)
+{
+ CHECK_TRUE(mp_dbus->get_event_queue()->empty());
+ CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0);
+
+ std::shared_ptr<UTestEvent> event = std::make_shared<UTestEvent>(100);
+
+ mp_dbus->publish(pub_id1, DbUtIds::EVENT1, event);
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(150));
+
+ delete mp_dbus;
+
+ CHECK(1 == MockMPTransport::get_count());
+}
+
+TEST_GROUP(mp_data_bus)
+{
+ unsigned pub_id1 = 0, pub_id2 = 0; // cppcheck-suppress variableScope
+
+ void setup() override
+ {
+ unsigned max_procs_val = 2;
+ snort_conf->mp_dbus = new MPDataBus();
+ snort_conf->mp_dbus->init(max_procs_val);
+ pub_id1 = MPDataBus::get_id(pub_key1);
+ pub_id2 = MPDataBus::get_id(pub_key2);
+ CHECK(MPDataBus::valid(pub_id1));
+ CHECK(MPDataBus::valid(pub_id2));
+ }
+
+ void teardown() override
+ {
+ delete snort_conf->mp_dbus;
+ }
+};
+
+//--------------------------------------------------------------------------
+// Tests
+//--------------------------------------------------------------------------
+
+TEST(mp_data_bus, init)
+{
+ CHECK(SnortConfig::get_conf()->mp_dbus != nullptr);
+ CHECK(SnortConfig::get_conf()->mp_dbus->get_event_queue() != nullptr);
+}
+
+TEST(mp_data_bus, no_subscribers_and_receive)
+{
+ UTestHandler1* h1 = new UTestHandler1();
+
+ std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+
+ MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+ SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+ CHECK_EQUAL(0, h1->evt_msg);
+ delete h1;
+ h1 = nullptr;
+}
+
+TEST(mp_data_bus, register_event_helpers)
+{
+ MPSerializeFunc serialize_func = serialize_mock;
+ MPDeserializeFunc deserialize_func = deserialize_mock;
+ CHECK(0 == MockMPTransport::get_test_register_helpers_calls());
+
+ MPDataBus::register_event_helpers(pub_key1, DbUtIds::EVENT1, serialize_func, deserialize_func);
+ CHECK(1 == MockMPTransport::get_test_register_helpers_calls());
+
+ MPDataBus::register_event_helpers(pub_key1, DbUtIds::EVENT1, serialize_func, deserialize_func);
+ CHECK(2 == MockMPTransport::get_test_register_helpers_calls());
+}
+
+TEST(mp_data_bus, subscribe_and_receive)
+{
+ // one snort subscribes to it
+ UTestHandler1* h1 = new UTestHandler1();
+ MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+
+ // publish event from other snort
+ // since we don't have a way to publish events, we will use receive_message to simulate the event
+ // from a different snort
+ std::shared_ptr<UTestEvent> event = std::make_shared<UTestEvent>(100);
+
+ MPEventInfo event_info(event, MPEventType(DbUtIds::EVENT1), pub_id1);
+ SnortConfig::get_conf()->mp_dbus->receive_message(event_info);
+
+ CHECK_EQUAL(100, h1->evt_msg);
+
+ std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(200);
+
+ MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+ SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+ CHECK_EQUAL(200, h1->evt_msg);
+}
+
+TEST(mp_data_bus, two_subscribers_diff_event_and_receive)
+{
+ UTestHandler1* h1 = new UTestHandler1();
+ UTestHandler2* h2 = new UTestHandler2();
+
+ MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+ MPDataBus::subscribe(pub_key2, DbUtIds::EVENT2, h2);
+
+ std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+
+ MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+ SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+ CHECK_EQUAL(100, h1->evt_msg);
+ CHECK_EQUAL(1, h2->evt_msg);
+
+ std::shared_ptr<UTestEvent> event2 = std::make_shared<UTestEvent>(200);
+
+ MPEventInfo event_info2(event2, MPEventType(DbUtIds::EVENT2), pub_id2);
+ SnortConfig::get_conf()->mp_dbus->receive_message(event_info2);
+
+ CHECK_EQUAL(100, h1->evt_msg);
+ CHECK_EQUAL(200, h2->evt_msg);
+}
+
+TEST(mp_data_bus, two_subscribers_same_event_and_receive)
+{
+ UTestHandler1* h1 = new UTestHandler1();
+ UTestHandler2* h2 = new UTestHandler2();
+
+ MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+ MPDataBus::subscribe(pub_key2, DbUtIds::EVENT1, h2);
+
+ std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+
+ MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+ SnortConfig::get_conf()->mp_dbus->receive_message(event_info1);
+
+ CHECK_EQUAL(100, h1->evt_msg);
+ CHECK_EQUAL(1, h2->evt_msg);
+}
+
+TEST_GROUP(mp_data_bus_clone)
+{
+ unsigned pub_id1 = 0, pub_id2 = 0; // cppcheck-suppress variableScope
+ void setup() override
+ {
+ unsigned max_procs_val = 2;
+ snort_conf->mp_dbus = new MPDataBus();
+ snort_conf->mp_dbus->init(max_procs_val);
+ pub_id1 = MPDataBus::get_id(pub_key1);
+ pub_id2 = MPDataBus::get_id(pub_key2);
+ CHECK(MPDataBus::valid(pub_id1));
+ CHECK(MPDataBus::valid(pub_id2));
+ }
+
+ void teardown() override
+ {
+ delete snort_conf->mp_dbus;
+ }
+};
+//-------------------------------------------------------------------------
+
+TEST(mp_data_bus_clone, z_clone)
+{
+ unsigned pub_id1, pub_id2;
+ pub_id1 = MPDataBus::get_id(pub_key1);
+ pub_id2 = MPDataBus::get_id(pub_key2);
+ // subscribing to the events in the original mp_data_bus
+ // and then cloning the mp_data_bus
+ // and checking if the events are received in the cloned mp_data_bus
+ // and not in the original mp_data_bus
+ UTestHandler1* h1 = new UTestHandler1();
+ MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1);
+
+ UTestHandler2* h2 = new UTestHandler2();
+ MPDataBus::subscribe(pub_key2, DbUtIds::EVENT2, h2);
+
+ // original mp_data_bus should be deleted with previous SnortConfig
+ // deleted with exit handlers of Test framework
+ MPDataBus* mp_data_bus_cloned = new MPDataBus();
+ mp_data_bus_cloned->clone(*SnortConfig::get_conf()->mp_dbus, nullptr);
+
+ std::shared_ptr<UTestEvent> event1 = std::make_shared<UTestEvent>(100);
+ MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1);
+
+ mp_data_bus_cloned->receive_message(event_info1);
+
+ CHECK_EQUAL(100, h1->evt_msg);
+ CHECK_EQUAL(1, h2->evt_msg);
+
+ std::shared_ptr<UTestEvent> event2 = std::make_shared<UTestEvent>(200);
+
+ MPEventInfo event_info2(event2, MPEventType(DbUtIds::EVENT2), pub_id2);
+ mp_data_bus_cloned->receive_message(event_info2);
+ CHECK_EQUAL(100, h1->evt_msg);
+ CHECK_EQUAL(200, h2->evt_msg);
+}
+
+//-------------------------------------------------------------------------
+// main
+//-------------------------------------------------------------------------
+
+int main(int argc, char** argv)
+{
+ // event_map is not released until after cpputest gives up
+ MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
+ return CommandLineTestRunner::RunAllTests(argc, argv);
+}
\ No newline at end of file
PluginManager::list_plugins();
break;
}
+ MPTransportManager::term();
ModuleManager::term();
PluginManager::release_plugins();
ScriptManager::release_scripts();
- MPTransportManager::term();
delete SnortConfig::get_conf();
exit(0);
}
return true;
}
+//-------------------------------------------------------------------------
+// multiprocess data bus module
+//-------------------------------------------------------------------------
+
+static const Parameter mp_data_bus_params[] =
+{
+ { "max_eventq_size", Parameter::PT_INT, "100:65535", "1000",
+ "maximum events to queue" },
+
+ { "transport", Parameter::PT_STRING, nullptr, nullptr,
+ "transport to use for inter-process communication" },
+
+ { "debug", Parameter::PT_BOOL, nullptr, "false",
+ "enable debugging" },
+
+ { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+static int enable_debug(lua_State*)
+{
+ if(SnortConfig::get_conf()->mp_dbus)
+ SnortConfig::get_conf()->mp_dbus->enable_debug = true;
+
+ return 0;
+}
+
+static int disable_debug(lua_State*)
+{
+ if(SnortConfig::get_conf()->mp_dbus)
+ SnortConfig::get_conf()->mp_dbus->enable_debug = false;
+
+ return 0;
+}
+
+static const Command mp_dbus_cmds[] =
+{
+ {"enable", enable_debug, nullptr, "enable multiprocess data bus debugging"},
+ {"disable", disable_debug, nullptr, "disable multiprocess data bus debugging"},
+ {nullptr, nullptr, nullptr, nullptr}
+};
+
+#define mp_data_bus_help \
+ "configure multiprocess data bus"
+
+class MPDataBusModule : public Module
+{
+public:
+ MPDataBusModule() :
+ Module("mp_data_bus", mp_data_bus_help, mp_data_bus_params) { }
+
+ bool set(const char*, Value&, SnortConfig*) override;
+ bool begin(const char*, int, SnortConfig*) override;
+ bool end(const char*, int, SnortConfig*) override;
+ const Command* get_commands() const override;
+
+ Usage get_usage() const override
+ { return GLOBAL; }
+};
+
+bool MPDataBusModule::begin(const char*, int, SnortConfig*)
+{
+ return true;
+}
+
+bool MPDataBusModule::end(const char*, int, SnortConfig*)
+{
+ return true;
+}
+
+bool MPDataBusModule::set(const char*, Value& v, SnortConfig*)
+{
+ if ( v.is("max_eventq_size") )
+ {
+ MPDataBus::mp_max_eventq_size = v.get_uint32();
+ }
+ else if ( v.is("transport") )
+ {
+ MPDataBus::transport = v.get_string();
+ }
+ else if ( v.is("debug") )
+ {
+ MPDataBus::enable_debug = v.get_bool();
+ }
+ else
+ {
+ WarningMessage("MPDataBus: Unknown parameter '%s' in mp_data_bus module\n", v.get_name());
+ return false;
+ }
+ return true;
+}
+
+const Command* MPDataBusModule::get_commands() const
+{
+ return mp_dbus_cmds;
+}
+
//-------------------------------------------------------------------------
// reference module
//-------------------------------------------------------------------------
ModuleManager::add_module(new CodecModule);
ModuleManager::add_module(new DetectionModule);
ModuleManager::add_module(new MemoryModule);
+ ModuleManager::add_module(new MPDataBusModule);
ModuleManager::add_module(new PacketTracerModule);
ModuleManager::add_module(new PacketsModule);
ModuleManager::add_module(new ProcessModule);
// This call must be immediately after "SnortConfig::set_conf(sc)"
// since the first trace call may happen somewhere after this point
TraceApi::thread_init(sc->trace_config);
+ if (sc->max_procs > 1)
+ {
+ sc->mp_dbus = new MPDataBus();
+ sc->mp_dbus->init(sc->max_procs);
+ }
PluginManager::load_so_plugins(sc);
const SnortConfig* sc = SnortConfig::get_conf();
+ MPTransportManager::term();
IpsManager::global_term(sc);
HostAttributesManager::term();
host_cache.term();
PluginManager::release_plugins();
ScriptManager::release_scripts();
- MPTransportManager::term();
memory::MemoryCap::term();
detection_filter_term();
SnortConfig* sc = new SnortConfig(other_conf, iname);
sc->global_dbus->clone(*other_conf->global_dbus, iname);
- if (sc->max_procs > 1)
+
+ if (other_conf->mp_dbus != nullptr)
sc->mp_dbus->clone(*other_conf->mp_dbus, iname);
if ( fname )
policy_map = new PolicyMap;
thread_config = new ThreadConfig();
global_dbus = new DataBus();
- if (max_procs > 1)
- mp_dbus = new MPDataBus();
proto_ref = new ProtocolReference(protocol_reference);
so_rules = new SoRules;
if ( cloned )
{
delete global_dbus;
- if (max_procs > 1)
+ if (mp_dbus)
delete mp_dbus;
policy_map->set_cloned(true);
delete policy_map;
delete overlay_trace_config;
delete ha_config;
delete global_dbus;
- if (max_procs > 1)
+ if (mp_dbus)
delete mp_dbus;
delete profiler;
DataEvent* internal_event = nullptr;
(deserialize_func)((const char*)(msg->content + sizeof(MPTransportMessageHeader)), transport_message_header->data_length, internal_event);
- MPEventInfo event(internal_event, transport_message_header->event_id, transport_message_header->pub_id);
+ MPEventInfo event(std::shared_ptr<DataEvent> (internal_event), transport_message_header->event_id, transport_message_header->pub_id);
(transport_receive_handler)(event);
- delete internal_event;
}
delete msg;
}
transport_message.header.type = EVENT_MESSAGE;
transport_message.header.pub_id = event.pub_id;
transport_message.header.event_id = event.type;
-
- (serialize_func)(event.event, transport_message.data, &transport_message.header.data_length);
+
+ (serialize_func)(event.event.get(), transport_message.data, &transport_message.header.data_length);
for (auto &&sc_handler : this->side_channels)
{
auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length);
test_transport_message_1->register_event_helpers(0, 0, mp_helper_functions_mock);
test_transport_message_2->register_event_helpers(0, 0, mp_helper_functions_mock);
- MPEventInfo event(&test_event, 0, 0);
+ std::shared_ptr<TestDataEvent> event = std::make_shared<TestDataEvent>();
- auto res = test_transport_message_1->send_to_transport(event);
+ MPEventInfo event_info(event, 0, 0);
+ auto res = test_transport_message_1->send_to_transport(event_info);
CHECK(res == true);
CHECK(test_serialize_calls == 1);
- res = test_transport_message_2->send_to_transport(event);
+ res = test_transport_message_2->send_to_transport(event_info);
CHECK(res == true);
CHECK(test_serialize_calls == 2);
{
clear_test_calls();
- MPEventInfo event(&test_event, 0, 0);
+ std::shared_ptr<TestDataEvent> event_in = std::make_shared<TestDataEvent>();
+
+ MPEventInfo event(event_in, 0, 0);
auto res = test_transport_message_1->send_to_transport(event);
CHECK(res == false);