From: Umang Sharma (umasharm) Date: Sat, 26 Apr 2025 00:34:17 +0000 (+0000) Subject: Pull request #4692: mp_data_bus: core logic for mp databus X-Git-Tag: 3.7.4.0~10 X-Git-Url: http://git.ipfire.org/gitweb/gitweb.cgi?a=commitdiff_plain;h=fb512b6e3410775d7cc33a6c6a322ff68729e10a;p=thirdparty%2Fsnort3.git Pull request #4692: mp_data_bus: core logic for mp databus Merge in SNORT/snort3 from ~UMASHARM/snort3:mp_dbus to master Squashed commit of the following: commit 7fc8f62dac71aea14203346fe12d2d3bc9605f9c Author: Umang Sharma Date: Thu Apr 24 15:29:53 2025 -0400 mp_data_bus: core logic for mp databus --- diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.cc b/src/connectors/unixdomain_connector/unixdomain_connector.cc index f3e9b382b..24a75779b 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.cc +++ b/src/connectors/unixdomain_connector/unixdomain_connector.cc @@ -110,7 +110,7 @@ static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_ if (attempt_connection(sfd, path)) { // Connection successful UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx); - LogMessage("UnixDomainC: Connected to %s", path); + LogMessage("UnixDomainC: Connected to %s\n", path); if(update_handler) { unixdomain_conn->set_update_handler(update_handler); @@ -398,7 +398,7 @@ UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorCo return nullptr; } } - LogMessage("UnixDomainC: Connected to %s", path); + LogMessage("UnixDomainC: Connected to %s\n", path); UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx); unixdomain_conn->set_update_handler(update_handler); if(update_handler) @@ -591,11 +591,15 @@ void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnecto { ErrorMessage("UnixDomainC: Too many errors, stopping accept thread\n"); close(sock_fd); + sock_fd = -1; + should_accept = false; return; } int peer_sfd = accept(sock_fd, nullptr, nullptr); if (peer_sfd == -1) { + if (!should_accept) + return; error_count++; ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno)); continue; @@ -611,13 +615,16 @@ void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnecto void UnixDomainConnectorListener::stop_accepting_connections() { - if(should_accept) + should_accept = false; + if (sock_fd) { - should_accept = false; + shutdown(sock_fd, SHUT_RDWR); close(sock_fd); - if (accept_thread && accept_thread->joinable()) { - accept_thread->join(); - } + sock_fd = -1; + } + + if (accept_thread && accept_thread->joinable()) { + accept_thread->join(); delete accept_thread; accept_thread = nullptr; } diff --git a/src/framework/mp_data_bus.cc b/src/framework/mp_data_bus.cc index 1263b2e86..8e6477050 100644 --- a/src/framework/mp_data_bus.cc +++ b/src/framework/mp_data_bus.cc @@ -31,99 +31,277 @@ #include "pub_sub/intrinsic_event_ids.h" #include "utils/stats.h" #include "main/snort_types.h" -#include "managers/mp_transport_manager.h" #include "log/messages.h" +#include "helpers/ring.h" +#include "managers/mp_transport_manager.h" using namespace snort; +std::condition_variable MPDataBus::queue_cv; +std::mutex MPDataBus::queue_mutex; +uint32_t MPDataBus::mp_max_eventq_size = DEFAULT_MAX_EVENTQ_SIZE; +std::string MPDataBus::transport = DEFAULT_TRANSPORT; +bool MPDataBus::enable_debug = false; +MPTransport* MPDataBus::transport_layer = nullptr; + static std::unordered_map mp_pub_ids; //-------------------------------------------------------------------------- // public methods //------------------------------------------------------------------------- -MPDataBus::MPDataBus() = default; +MPDataBus::MPDataBus() : run_thread(true) +{ + mp_event_queue = new Ring>(mp_max_eventq_size); + start_worker_thread(); +} MPDataBus::~MPDataBus() { - // Clean up mp_pub_sub - for (auto& sublist : mp_pub_sub) + stop_worker_thread(); + + for (auto& [key, sublist] : mp_pub_sub) { for (auto* handler : sublist) { if (handler->cloned) + { handler->cloned = false; + } else - delete handler; + { + delete handler; + } } sublist.clear(); } mp_pub_sub.clear(); + delete mp_event_queue; + mp_event_queue = nullptr; } unsigned MPDataBus::init(int max_procs) { - UNUSED(max_procs); + if (max_procs <= 1) + { + return 1; + } + + transport_layer = MPTransportManager::get_transport(transport); + if (transport_layer == nullptr) + { + ErrorMessage("MPDataBus: Failed to get transport layer\n"); + return 0; + } + + transport_layer->register_receive_handler(std::bind(&MPDataBus::receive_message, this, std::placeholders::_1)); + transport_layer->init_connection(); + return 0; } void MPDataBus::clone(MPDataBus& from, const char* exclude_name) -{ - UNUSED(from); - UNUSED(exclude_name); +{ + from.stop_worker_thread(); + for (const auto& [key, sublist] : from.mp_pub_sub) + { + unsigned pid = key.first; + unsigned eid = key.second; + + for (auto* h : sublist) + { + if (!exclude_name || strcmp(exclude_name, h->module_name) != 0) + { + h->cloned = true; + _subscribe(pid, eid, h); + } + } + } +} + +unsigned MPDataBus::get_id(const PubKey& key) +{ + // Generate a unique hash for the publisher's name, + std::hash hasher; + unsigned unique_id = (hasher(key.publisher) % 10000); + + auto it = mp_pub_ids.find(key.publisher); + + if (it == mp_pub_ids.end()) + { + // Map the unique hash to the publisher + mp_pub_ids[key.publisher] = unique_id; + } + // Return the unique ID for the publisher + return mp_pub_ids[key.publisher]; } -// module subscribes to an event from a peer snort process void MPDataBus::subscribe(const PubKey& key, unsigned eid, DataHandler* h) { - UNUSED(key); - UNUSED(eid); - UNUSED(h); + if(! SnortConfig::get_conf()->mp_dbus) + { + ErrorMessage("MPDataBus: MPDataBus not initialized\n"); + return; + } + + SnortConfig::get_conf()->mp_dbus->_subscribe(key, eid, h); + MP_DATABUS_LOG("MPDataBus: Subscribed to event ID %u\n", eid); } -// publish event to all peer snort processes subscribed to the event -bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, DataEvent& e, Flow* f) +bool MPDataBus::publish(unsigned pub_id, unsigned evt_id, std::shared_ptr e, Flow*) { - // Publish implementation - UNUSED(pub_id); - UNUSED(evt_id); - UNUSED(e); - UNUSED(f); + std::shared_ptr event_info = + std::make_shared(std::move(e), MPEventType(evt_id), pub_id); + + const SnortConfig *sc = SnortConfig::get_conf(); + + if (sc->mp_dbus == nullptr) + { + ErrorMessage("MPDataBus: MPDataBus not initialized\n"); + return false; + } + + if (sc->mp_dbus->mp_event_queue != nullptr and !sc->mp_dbus->mp_event_queue->full() and !sc->mp_dbus->mp_event_queue->put(event_info)) { + ErrorMessage("MPDataBus: Failed to enqueue event for publisher ID %u and event ID %u\n", pub_id, evt_id); + return false; + } + + { + std::lock_guard lock(queue_mutex); + queue_cv.notify_one(); + } + + MP_DATABUS_LOG("MPDataBus: Event published for publisher ID %u and event ID %u\n", pub_id, evt_id); + return true; } -// register event helpers for serialization and deserialization of msg events -void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc* mp_serializer_helper, MPDeserializeFunc* mp_deserializer_helper) +void MPDataBus::register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc& mp_serializer_helper, MPDeserializeFunc& mp_deserializer_helper) { - UNUSED(key); - UNUSED(evt_id); - UNUSED(mp_serializer_helper); - UNUSED(mp_deserializer_helper); + if (!SnortConfig::get_conf()->mp_dbus && !SnortConfig::get_conf()->mp_dbus->transport_layer) + { + ErrorMessage("MPDataBus: MPDataBus or transport layer not initialized\n"); + return; + } + + unsigned pub_id = get_id(key); + + MPHelperFunctions helpers(mp_serializer_helper, mp_deserializer_helper); + + SnortConfig::get_conf()->mp_dbus->transport_layer->register_event_helpers(pub_id, evt_id, helpers); + MP_DATABUS_LOG("MPDataBus: Registered event helpers for event ID %u\n", evt_id); } // API for receiving the DataEvent and Event type from transport layer void MPDataBus::receive_message(const MPEventInfo& event_info) { - UNUSED(event_info); + DataEvent *e = event_info.event.get(); + unsigned evt_id = event_info.type; + unsigned pub_id = event_info.pub_id; + + MP_DATABUS_LOG("MPDataBus: Received message for publisher ID %u and event ID %u\n", pub_id, evt_id); + + _publish(pub_id, evt_id, *e, nullptr); } //-------------------------------------------------------------------------- // private methods //-------------------------------------------------------------------------- +void MPDataBus::process_event_queue() +{ + if (!mp_event_queue) { + return; + } + + std::unique_lock lock(queue_mutex); + + queue_cv.wait_for(lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP), [this]() { + return mp_event_queue != nullptr && !mp_event_queue->empty(); + }); + + lock.unlock(); + + while (!mp_event_queue->empty()) { + std::shared_ptr event_info = mp_event_queue->get(nullptr); + if (event_info) { + MP_DATABUS_LOG("MPDataBus: Processing event for publisher ID %u \n", + event_info->pub_id); + + transport_layer->send_to_transport(*event_info); + } + } +} + +void MPDataBus::worker_thread_func() +{ + while (run_thread.load() ) { + process_event_queue(); + } +} + +void MPDataBus::start_worker_thread() +{ + run_thread.store(true); + worker_thread = std::make_unique(&MPDataBus::worker_thread_func, this); +} + +void MPDataBus::stop_worker_thread() +{ + run_thread.store(false); + queue_cv.notify_one(); + + if (worker_thread && worker_thread->joinable()) + { + worker_thread->join(); + } + + worker_thread.reset(); +} + +static bool compare(DataHandler* a, DataHandler* b) +{ + if ( a->order and b->order ) + return a->order < b->order; + + if ( a->order ) + return true; + + return false; +} void MPDataBus::_subscribe(unsigned pid, unsigned eid, DataHandler* h) { - UNUSED(pid); - UNUSED(eid); - UNUSED(h); + std::pair key = {pid, eid}; + + SubList& subs = mp_pub_sub[key]; + subs.emplace_back(h); + + std::sort(subs.begin(), subs.end(), compare); } -void MPDataBus::_publish(unsigned int pid, unsigned int eid, DataEvent& e, Flow* f) +void MPDataBus::_subscribe(const PubKey& key, unsigned eid, DataHandler* h) { - UNUSED(pid); - UNUSED(eid); - UNUSED(e); - UNUSED(f); + unsigned pid = get_id(key); + _subscribe(pid, eid, h); +} + + +void MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f) +{ + std::pair key = {pid, eid}; + + auto it = mp_pub_sub.find(key); + if (it == mp_pub_sub.end()) + { + MP_DATABUS_LOG("MPDataBus: No subscribers for publisher ID %u and event ID %u\n", pid, eid); + return; + } + const SubList& subs = it->second; + + for (auto* handler : subs) + { + handler->handle(e, f); + } } diff --git a/src/framework/mp_data_bus.h b/src/framework/mp_data_bus.h index d69918c94..ee1ce340e 100644 --- a/src/framework/mp_data_bus.h +++ b/src/framework/mp_data_bus.h @@ -35,12 +35,31 @@ #include #include #include +#include +#include #include +#include +#include #include "main/snort_types.h" #include "data_bus.h" #include "framework/mp_transport.h" #include +#include "framework/mp_transport.h" + +#define DEFAULT_TRANSPORT "unix_transport" +#define DEFAULT_MAX_EVENTQ_SIZE 1000 +#define WORKER_THREAD_SLEEP 100 + +#define MP_DATABUS_LOG(msg, ...) do { \ + if (!MPDataBus::enable_debug) \ + break; \ + LogMessage(msg, __VA_ARGS__); \ + } while (0) + + +template +class Ring; namespace snort { @@ -59,13 +78,13 @@ typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent // manner analogous to the approach used for intra-snort pub_sub. typedef unsigned MPEventType; -struct MPEventInfo +struct MPEventInfo { - MPEventType type; unsigned pub_id; - DataEvent* event; - MPEventInfo(DataEvent* e, MPEventType t, unsigned id = 0) - : type(t), pub_id(id), event(e) {} + MPEventType type; + std::shared_ptr event; + MPEventInfo(std::shared_ptr e, MPEventType t, unsigned id = 0) + : pub_id(id), type(t), event(std::move(e)) {} }; struct MPHelperFunctions { @@ -76,39 +95,78 @@ struct MPHelperFunctions { : serializer(s), deserializer(d) {} }; +struct pair_hash +{ + template + std::size_t operator()(const std::pair& pair) const + { + std::hash hash1; + std::hash hash2; + return hash1(pair.first) ^ (hash2(pair.second) << 1); + } +}; + class SO_PUBLIC MPDataBus { public: - MPDataBus(); + MPDataBus(); ~MPDataBus(); + + static uint32_t mp_max_eventq_size; + static std::string transport; + static bool enable_debug; - static unsigned init(int); + static MPTransport * transport_layer; + unsigned init(int); void clone(MPDataBus& from, const char* exclude_name = nullptr); - unsigned get_id(const PubKey& key) - { return DataBus::get_id(key); } + static unsigned get_id(const PubKey& key); - bool valid(unsigned pub_id) + static bool valid(unsigned pub_id) { return pub_id != 0; } - void subscribe(const PubKey& key, unsigned id, DataHandler* handler); + static void subscribe(const PubKey& key, unsigned id, DataHandler* handler); - bool publish(unsigned pub_id, unsigned evt_id, DataEvent& e, Flow* f = nullptr); + // API for publishing the DataEvent to the peer Snort processes + // The user needs to pass a shared_ptr to the DataEvent object as the third argument + // This is to ensure that the DataEvent object is not deleted before it is published + // or consumed by the worker thread + // and the shared_ptr will handle the memory management by reference counting + static bool publish(unsigned pub_id, unsigned evt_id, std::shared_ptr e, Flow* f = nullptr); - void register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc* mp_serializer_helper, MPDeserializeFunc* mp_deserializer_helper); + static void register_event_helpers(const PubKey& key, unsigned evt_id, MPSerializeFunc& mp_serializer_helper, MPDeserializeFunc& mp_deserializer_helper); // API for receiving the DataEvent and Event type from transport layer using EventInfo void receive_message(const MPEventInfo& event_info); + Ring>* get_event_queue() + { return mp_event_queue; } + private: void _subscribe(unsigned pid, unsigned eid, DataHandler* h); + void _subscribe(const PubKey& key, unsigned eid, DataHandler* h); + void _publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f); private: typedef std::vector SubList; - std::vector mp_pub_sub; + + std::unordered_map, SubList, pair_hash> mp_pub_sub; + + std::atomic run_thread; + std::unique_ptr worker_thread; + + Ring>* mp_event_queue; + + static std::condition_variable queue_cv; + static std::mutex queue_mutex; + + void start_worker_thread(); + void stop_worker_thread(); + void worker_thread_func(); + void process_event_queue(); +}; }; -} #endif diff --git a/src/framework/test/CMakeLists.txt b/src/framework/test/CMakeLists.txt index 0f554e0f6..5a16663a6 100644 --- a/src/framework/test/CMakeLists.txt +++ b/src/framework/test/CMakeLists.txt @@ -3,6 +3,10 @@ add_cpputest( data_bus_test SOURCES ../data_bus.cc ) +add_cpputest( mp_data_bus_test + SOURCES ../mp_data_bus.cc +) + # libapi_def.a is actually a text file with the preprocessed header source if ( ENABLE_UNIT_TESTS ) @@ -11,4 +15,4 @@ if ( ENABLE_UNIT_TESTS ) install(TARGETS api_def) endif () - +SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -pthread") diff --git a/src/framework/test/mp_data_bus_test.cc b/src/framework/test/mp_data_bus_test.cc new file mode 100644 index 000000000..658d0f894 --- /dev/null +++ b/src/framework/test/mp_data_bus_test.cc @@ -0,0 +1,445 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2019-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_data_bus_test.cc author Umang Sharma + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + + +#include "../mp_data_bus.h" +#include "../main/snort_config.h" +#include "utils/stats.h" +#include "helpers/ring.h" +#include + +#include +#include +#include + +#include +#include + +using namespace snort; + +namespace snort +{ +static SnortConfig s_conf; + +THREAD_LOCAL SnortConfig* snort_conf = &s_conf; + +void ErrorMessage(const char*, ...) { } +void LogMessage(const char*, ...) { } + +const SnortConfig* SnortConfig::get_conf() +{ return snort_conf; } + +SnortConfig::SnortConfig(const SnortConfig* const, const char*) +: daq_config(nullptr), thread_config(nullptr) +{ } + +SnortConfig::~SnortConfig() +{ } +} + +class MockMPTransport : public MPTransport +{ +public: + MockMPTransport() = default; + ~MockMPTransport() override = default; + + static int get_count() + { + return count; + } + + static int get_test_register_helpers_calls() + { + return test_register_helpers_calls; + } + + bool send_to_transport(MPEventInfo&) override + { + count++; + return true; + } + + void register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&) override + { + test_register_helpers_calls++; + return; + } + + void init_connection() override + { + return; + } + + void register_receive_handler(const TransportReceiveEventHandler&) override + { + return; + } + + void unregister_receive_handler() override + { + return; + } + + void thread_init() override + { + return; + } + + void thread_term() override + { + return; + } + + bool configure(const SnortConfig*) override + { + return true; + } + + void enable_logging() override + { + return; + } + + void disable_logging() override + { + return; + } + + bool is_logging_enabled() override + { + return true; + } +private: + inline static int count = 0; + inline static int test_register_helpers_calls = 0; +}; + +static MockMPTransport mp_transport_pointer; + +MPTransport* MPTransportManager::get_transport(const std::string&) +{ + return &mp_transport_pointer; +} + +class UTestEvent : public DataEvent +{ +public: + UTestEvent(int m) : msg(m) { } + + int get_message() const + { return msg; } + +private: + int msg; +}; + +bool serialize_mock(DataEvent*, char*& buffer, uint16_t* length) +{ + buffer = new char[9]; + *length = 9; + memcpy(buffer, "test_data", 9); + return true; +} + +bool deserialize_mock(const char*, uint16_t length, DataEvent*& event) +{ + event = new UTestEvent(length); + return true; +} + +class UTestHandler1 : public DataHandler +{ +public: + UTestHandler1(unsigned u = 0) : DataHandler("unit_test1") + { if (u) order = u; } + + void handle(DataEvent& event, Flow*) override; + + int evt_msg = 0; +}; + +class UTestHandler2 : public DataHandler +{ +public: + UTestHandler2(unsigned u = 0) : DataHandler("unit_test2") + { if (u) order = u; } + + void handle(DataEvent& event, Flow*) override; + + int evt_msg = 1; +}; + +void UTestHandler1::handle(DataEvent& event, Flow*) +{ + UTestEvent* evt = static_cast(&event); + if (evt) + { + evt_msg = evt->get_message(); + } +} + +void UTestHandler2::handle(DataEvent& event, Flow*) +{ + UTestEvent* evt = static_cast(&event); + if (evt) + { + evt_msg = evt->get_message(); + } +} +//-------------------------------------------------------------------------- + + +struct DbUtIds { enum : unsigned { EVENT1, EVENT2, num_ids }; }; + +const PubKey pub_key1 { "mp_ut1", DbUtIds::num_ids }; +const PubKey pub_key2 { "mp_ut2", DbUtIds::num_ids }; + +//-------------------------------------------------------------------------- +// Test Group +//-------------------------------------------------------------------------- + +TEST_GROUP(mp_data_bus_pub) +{ + unsigned pub_id1 = 0; // cppcheck-suppress variableScope + MPDataBus* mp_dbus = nullptr; + void setup() override + { + mp_dbus = new MPDataBus(); + mp_dbus->init(2); + pub_id1 = MPDataBus::get_id(pub_key1); + CHECK(MPDataBus::valid(pub_id1)); + + snort_conf->mp_dbus = mp_dbus; + } + + void teardown() override + { } +}; + +TEST(mp_data_bus_pub, publish) +{ + CHECK_TRUE(mp_dbus->get_event_queue()->empty()); + CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0); + + std::shared_ptr event = std::make_shared(100); + + mp_dbus->publish(pub_id1, DbUtIds::EVENT1, event); + + std::this_thread::sleep_for(std::chrono::milliseconds(150)); + + delete mp_dbus; + + CHECK(1 == MockMPTransport::get_count()); +} + +TEST_GROUP(mp_data_bus) +{ + unsigned pub_id1 = 0, pub_id2 = 0; // cppcheck-suppress variableScope + + void setup() override + { + unsigned max_procs_val = 2; + snort_conf->mp_dbus = new MPDataBus(); + snort_conf->mp_dbus->init(max_procs_val); + pub_id1 = MPDataBus::get_id(pub_key1); + pub_id2 = MPDataBus::get_id(pub_key2); + CHECK(MPDataBus::valid(pub_id1)); + CHECK(MPDataBus::valid(pub_id2)); + } + + void teardown() override + { + delete snort_conf->mp_dbus; + } +}; + +//-------------------------------------------------------------------------- +// Tests +//-------------------------------------------------------------------------- + +TEST(mp_data_bus, init) +{ + CHECK(SnortConfig::get_conf()->mp_dbus != nullptr); + CHECK(SnortConfig::get_conf()->mp_dbus->get_event_queue() != nullptr); +} + +TEST(mp_data_bus, no_subscribers_and_receive) +{ + UTestHandler1* h1 = new UTestHandler1(); + + std::shared_ptr event1 = std::make_shared(100); + + MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1); + SnortConfig::get_conf()->mp_dbus->receive_message(event_info1); + + CHECK_EQUAL(0, h1->evt_msg); + delete h1; + h1 = nullptr; +} + +TEST(mp_data_bus, register_event_helpers) +{ + MPSerializeFunc serialize_func = serialize_mock; + MPDeserializeFunc deserialize_func = deserialize_mock; + CHECK(0 == MockMPTransport::get_test_register_helpers_calls()); + + MPDataBus::register_event_helpers(pub_key1, DbUtIds::EVENT1, serialize_func, deserialize_func); + CHECK(1 == MockMPTransport::get_test_register_helpers_calls()); + + MPDataBus::register_event_helpers(pub_key1, DbUtIds::EVENT1, serialize_func, deserialize_func); + CHECK(2 == MockMPTransport::get_test_register_helpers_calls()); +} + +TEST(mp_data_bus, subscribe_and_receive) +{ + // one snort subscribes to it + UTestHandler1* h1 = new UTestHandler1(); + MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1); + + // publish event from other snort + // since we don't have a way to publish events, we will use receive_message to simulate the event + // from a different snort + std::shared_ptr event = std::make_shared(100); + + MPEventInfo event_info(event, MPEventType(DbUtIds::EVENT1), pub_id1); + SnortConfig::get_conf()->mp_dbus->receive_message(event_info); + + CHECK_EQUAL(100, h1->evt_msg); + + std::shared_ptr event1 = std::make_shared(200); + + MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1); + SnortConfig::get_conf()->mp_dbus->receive_message(event_info1); + + CHECK_EQUAL(200, h1->evt_msg); +} + +TEST(mp_data_bus, two_subscribers_diff_event_and_receive) +{ + UTestHandler1* h1 = new UTestHandler1(); + UTestHandler2* h2 = new UTestHandler2(); + + MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1); + MPDataBus::subscribe(pub_key2, DbUtIds::EVENT2, h2); + + std::shared_ptr event1 = std::make_shared(100); + + MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1); + SnortConfig::get_conf()->mp_dbus->receive_message(event_info1); + + CHECK_EQUAL(100, h1->evt_msg); + CHECK_EQUAL(1, h2->evt_msg); + + std::shared_ptr event2 = std::make_shared(200); + + MPEventInfo event_info2(event2, MPEventType(DbUtIds::EVENT2), pub_id2); + SnortConfig::get_conf()->mp_dbus->receive_message(event_info2); + + CHECK_EQUAL(100, h1->evt_msg); + CHECK_EQUAL(200, h2->evt_msg); +} + +TEST(mp_data_bus, two_subscribers_same_event_and_receive) +{ + UTestHandler1* h1 = new UTestHandler1(); + UTestHandler2* h2 = new UTestHandler2(); + + MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1); + MPDataBus::subscribe(pub_key2, DbUtIds::EVENT1, h2); + + std::shared_ptr event1 = std::make_shared(100); + + MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1); + SnortConfig::get_conf()->mp_dbus->receive_message(event_info1); + + CHECK_EQUAL(100, h1->evt_msg); + CHECK_EQUAL(1, h2->evt_msg); +} + +TEST_GROUP(mp_data_bus_clone) +{ + unsigned pub_id1 = 0, pub_id2 = 0; // cppcheck-suppress variableScope + void setup() override + { + unsigned max_procs_val = 2; + snort_conf->mp_dbus = new MPDataBus(); + snort_conf->mp_dbus->init(max_procs_val); + pub_id1 = MPDataBus::get_id(pub_key1); + pub_id2 = MPDataBus::get_id(pub_key2); + CHECK(MPDataBus::valid(pub_id1)); + CHECK(MPDataBus::valid(pub_id2)); + } + + void teardown() override + { + delete snort_conf->mp_dbus; + } +}; +//------------------------------------------------------------------------- + +TEST(mp_data_bus_clone, z_clone) +{ + unsigned pub_id1, pub_id2; + pub_id1 = MPDataBus::get_id(pub_key1); + pub_id2 = MPDataBus::get_id(pub_key2); + // subscribing to the events in the original mp_data_bus + // and then cloning the mp_data_bus + // and checking if the events are received in the cloned mp_data_bus + // and not in the original mp_data_bus + UTestHandler1* h1 = new UTestHandler1(); + MPDataBus::subscribe(pub_key1, DbUtIds::EVENT1, h1); + + UTestHandler2* h2 = new UTestHandler2(); + MPDataBus::subscribe(pub_key2, DbUtIds::EVENT2, h2); + + // original mp_data_bus should be deleted with previous SnortConfig + // deleted with exit handlers of Test framework + MPDataBus* mp_data_bus_cloned = new MPDataBus(); + mp_data_bus_cloned->clone(*SnortConfig::get_conf()->mp_dbus, nullptr); + + std::shared_ptr event1 = std::make_shared(100); + MPEventInfo event_info1(event1, MPEventType(DbUtIds::EVENT1), pub_id1); + + mp_data_bus_cloned->receive_message(event_info1); + + CHECK_EQUAL(100, h1->evt_msg); + CHECK_EQUAL(1, h2->evt_msg); + + std::shared_ptr event2 = std::make_shared(200); + + MPEventInfo event_info2(event2, MPEventType(DbUtIds::EVENT2), pub_id2); + mp_data_bus_cloned->receive_message(event_info2); + CHECK_EQUAL(100, h1->evt_msg); + CHECK_EQUAL(200, h2->evt_msg); +} + +//------------------------------------------------------------------------- +// main +//------------------------------------------------------------------------- + +int main(int argc, char** argv) +{ + // event_map is not released until after cpputest gives up + MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); + return CommandLineTestRunner::RunAllTests(argc, argv); +} \ No newline at end of file diff --git a/src/main/help.cc b/src/main/help.cc index 8d6f7eb42..02597c88a 100644 --- a/src/main/help.cc +++ b/src/main/help.cc @@ -209,10 +209,10 @@ enum HelpType PluginManager::list_plugins(); break; } + MPTransportManager::term(); ModuleManager::term(); PluginManager::release_plugins(); ScriptManager::release_scripts(); - MPTransportManager::term(); delete SnortConfig::get_conf(); exit(0); } diff --git a/src/main/modules.cc b/src/main/modules.cc index 9106c3238..dbaa053fd 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -392,6 +392,102 @@ bool ClassificationsModule::set(const char*, Value& v, SnortConfig*) return true; } +//------------------------------------------------------------------------- +// multiprocess data bus module +//------------------------------------------------------------------------- + +static const Parameter mp_data_bus_params[] = +{ + { "max_eventq_size", Parameter::PT_INT, "100:65535", "1000", + "maximum events to queue" }, + + { "transport", Parameter::PT_STRING, nullptr, nullptr, + "transport to use for inter-process communication" }, + + { "debug", Parameter::PT_BOOL, nullptr, "false", + "enable debugging" }, + + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } +}; + +static int enable_debug(lua_State*) +{ + if(SnortConfig::get_conf()->mp_dbus) + SnortConfig::get_conf()->mp_dbus->enable_debug = true; + + return 0; +} + +static int disable_debug(lua_State*) +{ + if(SnortConfig::get_conf()->mp_dbus) + SnortConfig::get_conf()->mp_dbus->enable_debug = false; + + return 0; +} + +static const Command mp_dbus_cmds[] = +{ + {"enable", enable_debug, nullptr, "enable multiprocess data bus debugging"}, + {"disable", disable_debug, nullptr, "disable multiprocess data bus debugging"}, + {nullptr, nullptr, nullptr, nullptr} +}; + +#define mp_data_bus_help \ + "configure multiprocess data bus" + +class MPDataBusModule : public Module +{ +public: + MPDataBusModule() : + Module("mp_data_bus", mp_data_bus_help, mp_data_bus_params) { } + + bool set(const char*, Value&, SnortConfig*) override; + bool begin(const char*, int, SnortConfig*) override; + bool end(const char*, int, SnortConfig*) override; + const Command* get_commands() const override; + + Usage get_usage() const override + { return GLOBAL; } +}; + +bool MPDataBusModule::begin(const char*, int, SnortConfig*) +{ + return true; +} + +bool MPDataBusModule::end(const char*, int, SnortConfig*) +{ + return true; +} + +bool MPDataBusModule::set(const char*, Value& v, SnortConfig*) +{ + if ( v.is("max_eventq_size") ) + { + MPDataBus::mp_max_eventq_size = v.get_uint32(); + } + else if ( v.is("transport") ) + { + MPDataBus::transport = v.get_string(); + } + else if ( v.is("debug") ) + { + MPDataBus::enable_debug = v.get_bool(); + } + else + { + WarningMessage("MPDataBus: Unknown parameter '%s' in mp_data_bus module\n", v.get_name()); + return false; + } + return true; +} + +const Command* MPDataBusModule::get_commands() const +{ + return mp_dbus_cmds; +} + //------------------------------------------------------------------------- // reference module //------------------------------------------------------------------------- @@ -1969,6 +2065,7 @@ void module_init() ModuleManager::add_module(new CodecModule); ModuleManager::add_module(new DetectionModule); ModuleManager::add_module(new MemoryModule); + ModuleManager::add_module(new MPDataBusModule); ModuleManager::add_module(new PacketTracerModule); ModuleManager::add_module(new PacketsModule); ModuleManager::add_module(new ProcessModule); diff --git a/src/main/snort.cc b/src/main/snort.cc index 2f889ddc2..b735a25b6 100644 --- a/src/main/snort.cc +++ b/src/main/snort.cc @@ -171,6 +171,11 @@ void Snort::init(int argc, char** argv) // This call must be immediately after "SnortConfig::set_conf(sc)" // since the first trace call may happen somewhere after this point TraceApi::thread_init(sc->trace_config); + if (sc->max_procs > 1) + { + sc->mp_dbus = new MPDataBus(); + sc->mp_dbus->init(sc->max_procs); + } PluginManager::load_so_plugins(sc); @@ -330,6 +335,7 @@ void Snort::term() const SnortConfig* sc = SnortConfig::get_conf(); + MPTransportManager::term(); IpsManager::global_term(sc); HostAttributesManager::term(); @@ -374,7 +380,6 @@ void Snort::term() host_cache.term(); PluginManager::release_plugins(); ScriptManager::release_scripts(); - MPTransportManager::term(); memory::MemoryCap::term(); detection_filter_term(); @@ -575,7 +580,8 @@ SnortConfig* Snort::get_updated_policy( SnortConfig* sc = new SnortConfig(other_conf, iname); sc->global_dbus->clone(*other_conf->global_dbus, iname); - if (sc->max_procs > 1) + + if (other_conf->mp_dbus != nullptr) sc->mp_dbus->clone(*other_conf->mp_dbus, iname); if ( fname ) diff --git a/src/main/snort_config.cc b/src/main/snort_config.cc index 27a2d5ec8..5db70d17b 100644 --- a/src/main/snort_config.cc +++ b/src/main/snort_config.cc @@ -199,8 +199,6 @@ void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* p policy_map = new PolicyMap; thread_config = new ThreadConfig(); global_dbus = new DataBus(); - if (max_procs > 1) - mp_dbus = new MPDataBus(); proto_ref = new ProtocolReference(protocol_reference); so_rules = new SoRules; @@ -264,7 +262,7 @@ SnortConfig::~SnortConfig() if ( cloned ) { delete global_dbus; - if (max_procs > 1) + if (mp_dbus) delete mp_dbus; policy_map->set_cloned(true); delete policy_map; @@ -325,7 +323,7 @@ SnortConfig::~SnortConfig() delete overlay_trace_config; delete ha_config; delete global_dbus; - if (max_procs > 1) + if (mp_dbus) delete mp_dbus; delete profiler; diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc index 6b15acad1..8e40562f7 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc @@ -101,11 +101,10 @@ void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg) DataEvent* internal_event = nullptr; (deserialize_func)((const char*)(msg->content + sizeof(MPTransportMessageHeader)), transport_message_header->data_length, internal_event); - MPEventInfo event(internal_event, transport_message_header->event_id, transport_message_header->pub_id); + MPEventInfo event(std::shared_ptr (internal_event), transport_message_header->event_id, transport_message_header->pub_id); (transport_receive_handler)(event); - delete internal_event; } delete msg; } @@ -151,9 +150,9 @@ bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event) transport_message.header.type = EVENT_MESSAGE; transport_message.header.pub_id = event.pub_id; transport_message.header.event_id = event.type; - - (serialize_func)(event.event, transport_message.data, &transport_message.header.data_length); + + (serialize_func)(event.event.get(), transport_message.data, &transport_message.header.data_length); for (auto &&sc_handler : this->side_channels) { auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length); diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc index b3a9f5b96..84cf84979 100644 --- a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc @@ -454,15 +454,16 @@ TEST(unix_transport_test_messaging, send_to_transport_biderectional) test_transport_message_1->register_event_helpers(0, 0, mp_helper_functions_mock); test_transport_message_2->register_event_helpers(0, 0, mp_helper_functions_mock); - MPEventInfo event(&test_event, 0, 0); + std::shared_ptr event = std::make_shared(); - auto res = test_transport_message_1->send_to_transport(event); + MPEventInfo event_info(event, 0, 0); + auto res = test_transport_message_1->send_to_transport(event_info); CHECK(res == true); CHECK(test_serialize_calls == 1); - res = test_transport_message_2->send_to_transport(event); + res = test_transport_message_2->send_to_transport(event_info); CHECK(res == true); CHECK(test_serialize_calls == 2); @@ -479,7 +480,9 @@ TEST(unix_transport_test_messaging, send_to_transport_no_helpers) { clear_test_calls(); - MPEventInfo event(&test_event, 0, 0); + std::shared_ptr event_in = std::make_shared(); + + MPEventInfo event(event_in, 0, 0); auto res = test_transport_message_1->send_to_transport(event); CHECK(res == false);