From 8981622107bc6089c6b558a3afa1920b0f8daff6 Mon Sep 17 00:00:00 2001 From: "Oleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco)" Date: Thu, 24 Apr 2025 18:16:27 +0000 Subject: [PATCH] Pull request #4695: mp_unix_transport: mp_transport plugin type, implementation of unix domain name based mp transport Merge in SNORT/snort3 from ~OSTEPANO/snort3:mp_transport_layer to master Squashed commit of the following: commit edb3158929808ca911049623f5e676554134eab7 Author: Oleksandr Stepanov Date: Thu Mar 27 16:06:10 2025 -0400 mp_unix_transport: mp_transport plugin type, implementation of unix domain name based mp transport --- src/CMakeLists.txt | 3 + .../test/unixdomain_connector_test.cc | 74 ++- .../unixdomain_connector.cc | 175 +++++- .../unixdomain_connector.h | 33 ++ .../unixdomain_connector_config.h | 12 + src/framework/CMakeLists.txt | 1 + src/framework/base_api.h | 1 + src/framework/mp_data_bus.cc | 3 + src/framework/mp_data_bus.h | 15 +- src/framework/mp_transport.h | 81 +++ src/main/help.cc | 2 + src/main/snort.cc | 4 + src/managers/CMakeLists.txt | 2 + src/managers/mp_transport_manager.cc | 105 ++++ src/managers/mp_transport_manager.h | 51 ++ src/managers/plugin_manager.cc | 14 +- src/managers/test/CMakeLists.txt | 7 + .../test/mp_transport_manager_test.cc | 197 +++++++ src/mp_transport/CMakeLists.txt | 6 + src/mp_transport/dev_notes.txt | 5 + src/mp_transport/mp_transports.cc | 36 ++ src/mp_transport/mp_transports.h | 25 + .../mp_unix_transport/CMakeLists.txt | 24 + .../mp_unix_transport/dev_notes.txt | 13 + .../mp_unix_transport/mp_unix_transport.cc | 447 ++++++++++++++++ .../mp_unix_transport/mp_unix_transport.h | 132 +++++ .../mp_unix_transport_module.cc | 130 +++++ .../mp_unix_transport_module.h | 74 +++ .../mp_unix_transport/test/CMakeLists.txt | 24 + .../test/unix_transport_module_test.cc | 209 ++++++++ .../test/unix_transport_test.cc | 505 ++++++++++++++++++ src/side_channel/side_channel.cc | 3 + 32 files changed, 2381 insertions(+), 32 deletions(-) create mode 100644 src/framework/mp_transport.h create mode 100644 src/managers/mp_transport_manager.cc create mode 100644 src/managers/mp_transport_manager.h create mode 100644 src/managers/test/mp_transport_manager_test.cc create mode 100644 src/mp_transport/CMakeLists.txt create mode 100644 src/mp_transport/dev_notes.txt create mode 100644 src/mp_transport/mp_transports.cc create mode 100644 src/mp_transport/mp_transports.h create mode 100644 src/mp_transport/mp_unix_transport/CMakeLists.txt create mode 100644 src/mp_transport/mp_unix_transport/dev_notes.txt create mode 100644 src/mp_transport/mp_unix_transport/mp_unix_transport.cc create mode 100644 src/mp_transport/mp_unix_transport/mp_unix_transport.h create mode 100644 src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc create mode 100644 src/mp_transport/mp_unix_transport/mp_unix_transport_module.h create mode 100644 src/mp_transport/mp_unix_transport/test/CMakeLists.txt create mode 100644 src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc create mode 100644 src/mp_transport/mp_unix_transport/test/unix_transport_test.cc diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index bd0803cf5..de2df2eb4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -137,6 +137,7 @@ add_subdirectory(policy_selectors) add_subdirectory(search_engines) add_subdirectory(side_channel) add_subdirectory(connectors) +add_subdirectory(mp_transport) # FIXIT-L Delegate building out the target objects list to subdirectories @@ -198,8 +199,10 @@ add_executable( snort $ $ $ + $ $ $ + $ $ ${STATIC_CODEC_PLUGINS} ${STATIC_NETWORK_INSPECTOR_PLUGINS} diff --git a/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc b/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc index eba383570..269e54782 100644 --- a/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc +++ b/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc @@ -148,8 +148,19 @@ int socket (int, int, int) { return s_socket_return; } int bind (int, const struct sockaddr*, socklen_t) { return s_bind_return; } int listen (int, int) { return s_listen_return; } #endif - -int accept (int, struct sockaddr*, socklen_t*) { return s_accept_return; } +static bool use_test_accept_counter = false; +static uint test_accept_counter = 0; +int accept (int, struct sockaddr*, socklen_t*) +{ + if ( use_test_accept_counter ) + { + if ( test_accept_counter == 0 ) + return -1; + else + test_accept_counter--; + } + return s_accept_return; +} int close (int) { return 0; } static void set_normal_status() @@ -753,6 +764,65 @@ TEST(unixdomain_connector_no_tinit_tterm_call, receive_recv_body_closed) delete[] message; } +static const char* test_listener_path = "/tmp/test_path"; +UnixDomainConnectorListener* test_listener = nullptr; + +TEST_GROUP(unixdomain_connector_listener) +{ + void setup() override + { + test_listener = new UnixDomainConnectorListener(test_listener_path); + } + + void teardown() override + { + if (test_listener) + { + delete test_listener; + test_listener = nullptr; + } + } +}; + +UnixDomainConnector* test_listener_connector = nullptr; +UnixDomainConnectorConfig* test_listener_config = nullptr; + +void connection_callback(UnixDomainConnector* c, UnixDomainConnectorConfig* conf) +{ + assert(c != nullptr); + assert(conf != nullptr); + + test_listener_connector = c; + test_listener_config = conf; +} + +TEST(unixdomain_connector_listener, listener_accept_stop) +{ + UnixDomainConnectorConfig cfg; + cfg.direction = Connector::CONN_DUPLEX; + cfg.connector_name = "unixdomain"; + cfg.paths.push_back(test_listener_path); + cfg.setup = UnixDomainConnectorConfig::Setup::ANSWER; + cfg.async_receive = true; + + use_test_accept_counter = true; + test_accept_counter = 1; + + test_listener->start_accepting_connections(&connection_callback, &cfg); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + test_listener->stop_accepting_connections(); + + CHECK(test_listener_connector != nullptr); + CHECK(test_listener_config != nullptr); + CHECK(test_listener_connector->get_connector_direction() == Connector::CONN_DUPLEX); + CHECK(test_listener_config->async_receive == true); + + delete test_listener_connector; + test_listener_connector = nullptr; + delete test_listener_config; + test_listener_config = nullptr; +} + int main(int argc, char** argv) { int return_value = CommandLineTestRunner::RunAllTests(argc, argv); diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.cc b/src/connectors/unixdomain_connector/unixdomain_connector.cc index 19f7ae8fa..f3e9b382b 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.cc +++ b/src/connectors/unixdomain_connector/unixdomain_connector.cc @@ -83,37 +83,61 @@ static bool attempt_connection(int& sfd, const char* path) { } // Function to handle connection retries -static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx) { - ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr); +static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) { + if(update_handler) + update_handler(nullptr, (cfg.conn_retries > 0)); + else + ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr); - if ( cfg.setup == UnixDomainConnectorConfig::Setup::CALL and cfg.conn_retries) { + if (cfg.conn_retries) + { const auto& paths = cfg.paths; if (idx >= paths.size()) return; - - uint32_t retry_count = 0; + const char* path = paths[idx].c_str(); - while (retry_count < cfg.max_retries) { - int sfd; - if (attempt_connection(sfd, path)) { - // Connection successful - UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx); - LogMessage("UnixDomainC: Connected to %s", path); - ConnectorManager::update_thread_connector(cfg.connector_name, idx, unixdomain_conn); - break; + if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL) + { + LogMessage("UnixDomainC: Attempting to reconnect to %s\n", cfg.paths[idx].c_str()); + + uint32_t retry_count = 0; + + while (retry_count < cfg.max_retries) { + int sfd; + if (attempt_connection(sfd, path)) { + // Connection successful + UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx); + LogMessage("UnixDomainC: Connected to %s", path); + if(update_handler) + { + unixdomain_conn->set_update_handler(update_handler); + update_handler(unixdomain_conn, false); + } + else + ConnectorManager::update_thread_connector(cfg.connector_name, idx, unixdomain_conn); + break; + } + + std::this_thread::sleep_for(std::chrono::seconds(cfg.retry_interval)); + retry_count++; } - - std::this_thread::sleep_for(std::chrono::seconds(cfg.retry_interval)); - retry_count++; + } + else if (cfg.setup == UnixDomainConnectorConfig::Setup::ANSWER) + { + return; + } + else + { + LogMessage("UnixDomainC: Unexpected setup type at retry connection\n"); } } } -static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx) { - std::thread retry_thread(connection_retry_handler, cfg, idx); +static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) { + std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler); retry_thread.detach(); } @@ -254,7 +278,7 @@ void UnixDomainConnector::process_receive() { sock_fd = -1; } - start_retry_thread(cfg, instance_id); + start_retry_thread(cfg, instance_id, update_handler); return; } else if (rval > 0 && pfds[0].revents & POLLIN) { @@ -263,9 +287,23 @@ void UnixDomainConnector::process_receive() { ErrorMessage("UnixDomainC: Input Thread: overrun\n"); delete connector_msg; } + if(message_received_handler) + { + message_received_handler(); + } } } +void UnixDomainConnector::set_update_handler(UnixDomainConnectorUpdateHandler handler) +{ + update_handler = std::move(handler); +} + +void UnixDomainConnector::set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler) +{ + message_received_handler = std::move(handler); +} + void UnixDomainConnector::receive_processing_thread() { while (run_thread.load(std::memory_order_relaxed)) { process_receive(); @@ -347,12 +385,12 @@ static void mod_dtor(Module* m) { delete m; } -static UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx) { +UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler) { int sfd; if (!attempt_connection(sfd, path)) { if (cfg.conn_retries) { // Spawn a new thread to handle connection retries - start_retry_thread(cfg, idx); + start_retry_thread(cfg, idx, update_handler); return nullptr; // Return nullptr as the connection is not yet established } else { @@ -362,6 +400,10 @@ static UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConn } LogMessage("UnixDomainC: Connected to %s", path); UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx); + unixdomain_conn->set_update_handler(update_handler); + if(update_handler) + update_handler(unixdomain_conn, false); + return unixdomain_conn; } @@ -400,7 +442,7 @@ static UnixDomainConnector* unixdomain_connector_tinit_answer(const UnixDomainCo LogMessage("UnixDomainC: Accepted connection from %s \n", path); return new UnixDomainConnector(cfg, peer_sfd, idx); -} +} static bool is_valid_path(const std::string& path) { if (path.empty()) { @@ -491,3 +533,92 @@ const BaseApi* unixdomain_connector[] = &unixdomain_connector_api.base, nullptr }; + +UnixDomainConnectorListener::UnixDomainConnectorListener(const char *path) +{ + assert(path); + + sock_path = strdup(path); + sock_fd = 0; + accept_thread = nullptr; + should_accept = false; +} + +UnixDomainConnectorListener::~UnixDomainConnectorListener() +{ + stop_accepting_connections(); + free(sock_path); + sock_path = nullptr; +} + +void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnectorAcceptHandler handler, UnixDomainConnectorConfig* config) +{ + assert(accept_thread == nullptr); + assert(sock_path); + + should_accept = true; + accept_thread = new std::thread([this, handler, config]() + { + sock_fd = socket(AF_UNIX, SOCK_STREAM, 0); + if (sock_fd == -1) { + ErrorMessage("UnixDomainC: socket error: %s \n", strerror(errno)); + return; + } + + struct sockaddr_un addr; + memset(&addr, 0, sizeof(struct sockaddr_un)); + addr.sun_family = AF_UNIX; + strncpy(addr.sun_path, sock_path, sizeof(addr.sun_path) - 1); + + unlink(sock_path); + + if (bind(sock_fd, (struct sockaddr*)&addr, sizeof(struct sockaddr_un)) == -1) { + ErrorMessage("UnixDomainC: bind error: %s \n", strerror(errno)); + close(sock_fd); + return; + } + + if (listen(sock_fd, 10) == -1) { + ErrorMessage("UnixDomainC: listen error: %s \n", strerror(errno)); + close(sock_fd); + return; + } + + ushort error_count = 0; + + while (should_accept) { + if(error_count > 10) + { + ErrorMessage("UnixDomainC: Too many errors, stopping accept thread\n"); + close(sock_fd); + return; + } + int peer_sfd = accept(sock_fd, nullptr, nullptr); + if (peer_sfd == -1) + { + error_count++; + ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno)); + continue; + } + error_count = 0; + auto config_copy = new UnixDomainConnectorConfig(*config); + auto unix_conn = new UnixDomainConnector(*config_copy, peer_sfd, 0); + handler(unix_conn, config_copy); + } + }); + +} + +void UnixDomainConnectorListener::stop_accepting_connections() +{ + if(should_accept) + { + should_accept = false; + close(sock_fd); + if (accept_thread && accept_thread->joinable()) { + accept_thread->join(); + } + delete accept_thread; + accept_thread = nullptr; + } +} diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.h b/src/connectors/unixdomain_connector/unixdomain_connector.h index 28abd1b65..08339e30f 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.h +++ b/src/connectors/unixdomain_connector/unixdomain_connector.h @@ -23,6 +23,7 @@ #include #include +#include #include "framework/connector.h" #include "managers/connector_manager.h" @@ -48,6 +49,11 @@ public: uint16_t connector_msg_length; }; +class UnixDomainConnector; + +typedef std::function UnixDomainConnectorUpdateHandler; +typedef std::function UnixDomainConnectorMessageReceivedHandler; + class UnixDomainConnector : public snort::Connector { public: @@ -60,6 +66,9 @@ public: snort::ConnectorMsg receive_message(bool) override; void process_receive(); + void set_update_handler(UnixDomainConnectorUpdateHandler handler); + void set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler); + int sock_fd; private: @@ -77,7 +86,31 @@ private: ReceiveRing* receive_ring; size_t instance_id; UnixDomainConnectorConfig cfg; + + UnixDomainConnectorUpdateHandler update_handler; + UnixDomainConnectorMessageReceivedHandler message_received_handler; +}; + +typedef std::function UnixDomainConnectorAcceptHandler; + +class UnixDomainConnectorListener +{ +public: + UnixDomainConnectorListener(const char* path); + ~UnixDomainConnectorListener(); + + void start_accepting_connections(UnixDomainConnectorAcceptHandler handler, UnixDomainConnectorConfig* config); + void stop_accepting_connections(); + + private: + char* sock_path; + int sock_fd; + std::thread* accept_thread; + std::atomic should_accept; + }; +extern SO_PUBLIC UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler = nullptr); + #endif // UNIXDOMAIN_CONNECTOR_H diff --git a/src/connectors/unixdomain_connector/unixdomain_connector_config.h b/src/connectors/unixdomain_connector/unixdomain_connector_config.h index 931395f3e..57d56abed 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector_config.h +++ b/src/connectors/unixdomain_connector/unixdomain_connector_config.h @@ -27,11 +27,23 @@ public: UnixDomainConnectorConfig() { direction = snort::Connector::CONN_DUPLEX; async_receive = true; } + UnixDomainConnectorConfig(const UnixDomainConnectorConfig& other) : + paths(other.paths), setup(other.setup), + conn_retries(other.conn_retries), + retry_interval(other.retry_interval), + max_retries(other.max_retries), + connect_timeout_seconds(other.connect_timeout_seconds), + async_receive(other.async_receive) + { + direction = other.direction; + } + std::vector paths; Setup setup = {}; bool conn_retries = false; uint32_t retry_interval = 4; uint32_t max_retries = 5; + uint32_t connect_timeout_seconds = 30; bool async_receive; }; diff --git a/src/framework/CMakeLists.txt b/src/framework/CMakeLists.txt index 109b20d35..d3eb4a7da 100644 --- a/src/framework/CMakeLists.txt +++ b/src/framework/CMakeLists.txt @@ -25,6 +25,7 @@ set (FRAMEWORK_INCLUDES range.h so_rule.h value.h + mp_transport.h ) add_library ( framework OBJECT diff --git a/src/framework/base_api.h b/src/framework/base_api.h index c806b6a61..4306687da 100644 --- a/src/framework/base_api.h +++ b/src/framework/base_api.h @@ -54,6 +54,7 @@ enum PlugType PT_LOGGER, PT_CONNECTOR, PT_POLICY_SELECTOR, + PT_MP_TRANSPORT, PT_MAX }; diff --git a/src/framework/mp_data_bus.cc b/src/framework/mp_data_bus.cc index ef8606471..1263b2e86 100644 --- a/src/framework/mp_data_bus.cc +++ b/src/framework/mp_data_bus.cc @@ -31,6 +31,8 @@ #include "pub_sub/intrinsic_event_ids.h" #include "utils/stats.h" #include "main/snort_types.h" +#include "managers/mp_transport_manager.h" +#include "log/messages.h" using namespace snort; @@ -105,6 +107,7 @@ void MPDataBus::receive_message(const MPEventInfo& event_info) UNUSED(event_info); } + //-------------------------------------------------------------------------- // private methods //-------------------------------------------------------------------------- diff --git a/src/framework/mp_data_bus.h b/src/framework/mp_data_bus.h index 0fb93aee7..d69918c94 100644 --- a/src/framework/mp_data_bus.h +++ b/src/framework/mp_data_bus.h @@ -39,6 +39,7 @@ #include "main/snort_types.h" #include "data_bus.h" +#include "framework/mp_transport.h" #include namespace snort @@ -47,8 +48,8 @@ class Flow; struct Packet; struct SnortConfig; -typedef bool (*MPSerializeFunc)(const DataEvent& event, char** buffer, size_t* length); -typedef bool (*MPDeserializeFunc)(const char* buffer, size_t length, DataEvent* event); +typedef bool (*MPSerializeFunc)(DataEvent* event, char*& buffer, uint16_t* length); +typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent*& event); // Similar to the DataBus class, the MPDataBus class uses uses a combination of PubKey and event ID // for event subscriptions and publishing. New MP-specific event type enums should be added to the @@ -62,16 +63,16 @@ struct MPEventInfo { MPEventType type; unsigned pub_id; - DataEvent event; - MPEventInfo(const DataEvent& e, MPEventType t, unsigned id = 0) + DataEvent* event; + MPEventInfo(DataEvent* e, MPEventType t, unsigned id = 0) : type(t), pub_id(id), event(e) {} }; struct MPHelperFunctions { - MPSerializeFunc* serializer; - MPDeserializeFunc* deserializer; + MPSerializeFunc serializer; + MPDeserializeFunc deserializer; - MPHelperFunctions(MPSerializeFunc* s, MPDeserializeFunc* d) + MPHelperFunctions(MPSerializeFunc s, MPDeserializeFunc d) : serializer(s), deserializer(d) {} }; diff --git a/src/framework/mp_transport.h b/src/framework/mp_transport.h new file mode 100644 index 000000000..4b2decb74 --- /dev/null +++ b/src/framework/mp_transport.h @@ -0,0 +1,81 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_transport.h author Oleksandr Stepanov + +#ifndef MP_TRANSPORT_H +#define MP_TRANSPORT_H + +#include "main/snort_types.h" +#include "framework/base_api.h" + +#include + +namespace snort +{ + +#define MP_TRANSPORT_API_VERSION ((BASE_API_VERSION << 16) | 1) + +struct SnortConfig; +struct MPEventInfo; +struct MPHelperFunctions; + +typedef std::function TransportReceiveEventHandler; + +class MPTransport +{ + public: + + MPTransport() = default; + virtual ~MPTransport() = default; + + virtual bool configure(const SnortConfig*) = 0; + virtual void thread_init() = 0; + virtual void thread_term() = 0; + virtual void init_connection() = 0; + virtual bool send_to_transport(MPEventInfo& event) = 0; + virtual void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) = 0; + virtual void register_receive_handler(const TransportReceiveEventHandler& handler) = 0; + virtual void unregister_receive_handler() = 0; + virtual void enable_logging() = 0; + virtual void disable_logging() = 0; + virtual bool is_logging_enabled() = 0; +}; + + +typedef MPTransport* (* MPTransportNewFunc)(Module*); +typedef void (* MPTransportDelFunc)(MPTransport*); +typedef void (* MPTransportThreadInitFunc)(MPTransport*); +typedef void (* MPTransportThreadTermFunc)(MPTransport*); +typedef void (* MPTransportFunc)(); + +struct MPTransportApi +{ + BaseApi base; + unsigned flags; + + MPTransportFunc pinit; // plugin init + MPTransportFunc pterm; // cleanup pinit() + MPTransportThreadInitFunc tinit; // thread local init + MPTransportThreadTermFunc tterm; // cleanup tinit() + + MPTransportNewFunc ctor; + MPTransportDelFunc dtor; +}; + +} +#endif // MP_TRANSPORT_H diff --git a/src/main/help.cc b/src/main/help.cc index 7f3fafbfb..8d6f7eb42 100644 --- a/src/main/help.cc +++ b/src/main/help.cc @@ -34,6 +34,7 @@ #include "managers/plugin_manager.h" #include "managers/script_manager.h" #include "managers/so_manager.h" +#include "managers/mp_transport_manager.h" #include "packet_io/sfdaq.h" #include "utils/util.h" @@ -211,6 +212,7 @@ enum HelpType ModuleManager::term(); PluginManager::release_plugins(); ScriptManager::release_scripts(); + MPTransportManager::term(); delete SnortConfig::get_conf(); exit(0); } diff --git a/src/main/snort.cc b/src/main/snort.cc index e9e5d41c5..2f889ddc2 100644 --- a/src/main/snort.cc +++ b/src/main/snort.cc @@ -60,6 +60,7 @@ #include "managers/plugin_manager.h" #include "managers/policy_selector_manager.h" #include "managers/script_manager.h" +#include "managers/mp_transport_manager.h" #include "memory/memory_cap.h" #include "network_inspectors/network_inspectors.h" #include "packet_io/active.h" @@ -79,6 +80,7 @@ #include "trace/trace_api.h" #include "trace/trace_config.h" #include "trace/trace_logger.h" +#include "mp_transport/mp_transports.h" #include "utils/stats.h" #include "utils/util.h" @@ -132,6 +134,7 @@ void Snort::init(int argc, char** argv) load_stream_inspectors(); load_network_inspectors(); load_service_inspectors(); + load_mp_transports(); snort_cmd_line_conf = parse_cmd_line(argc, argv); SnortConfig::set_conf(snort_cmd_line_conf); @@ -371,6 +374,7 @@ void Snort::term() host_cache.term(); PluginManager::release_plugins(); ScriptManager::release_scripts(); + MPTransportManager::term(); memory::MemoryCap::term(); detection_filter_term(); diff --git a/src/managers/CMakeLists.txt b/src/managers/CMakeLists.txt index d748f99e4..f72ddfb7e 100644 --- a/src/managers/CMakeLists.txt +++ b/src/managers/CMakeLists.txt @@ -38,6 +38,8 @@ add_library( managers OBJECT so_manager.h connector_manager.cc connector_manager.h + mp_transport_manager.cc + mp_transport_manager.h ) add_custom_command ( diff --git a/src/managers/mp_transport_manager.cc b/src/managers/mp_transport_manager.cc new file mode 100644 index 000000000..9dc73af73 --- /dev/null +++ b/src/managers/mp_transport_manager.cc @@ -0,0 +1,105 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_transport_manager.cc author Oleksandr Stepanov + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "mp_transport_manager.h" + +#include + +using namespace snort; + +struct MPTransportHandler +{ + MPTransportHandler(MPTransport* transport, const MPTransportApi* api) + : transport(transport), api(api) {} + MPTransport* transport; + const MPTransportApi* api; +}; + +static std::unordered_map transports_map; + +void MPTransportManager::instantiate(const MPTransportApi *api, Module *mod, SnortConfig*) +{ + if(transports_map.find(api->base.name) != transports_map.end()) + { + return; + } + + transports_map.insert(std::make_pair(api->base.name, new MPTransportHandler(api->ctor(mod), api))); +} + +MPTransport *MPTransportManager::get_transport(const std::string &name) +{ + auto it = transports_map.find(name); + if (it != transports_map.end()) + { + return it->second->transport; + } + return nullptr; +} + +void MPTransportManager::add_plugin(const MPTransportApi *api) +{ + if (api->pinit) + { + api->pinit(); + } +} + +void MPTransportManager::thread_init() +{ + for (auto &transport : transports_map) + { + if (transport.second->api->tinit) + { + transport.second->api->tinit(transport.second->transport); + } + } +} + +void MPTransportManager::thread_term() +{ + for (auto &transport : transports_map) + { + if (transport.second->api->tterm) + { + transport.second->api->tterm(transport.second->transport); + } + } +} + +void MPTransportManager::term() +{ + for (auto &transport : transports_map) + { + if (transport.second->api->dtor) + { + transport.second->api->dtor(transport.second->transport); + } + if (transport.second->api->pterm) + { + transport.second->api->pterm(); + } + delete transport.second; + } + transports_map.clear(); +} diff --git a/src/managers/mp_transport_manager.h b/src/managers/mp_transport_manager.h new file mode 100644 index 000000000..c1e6503bb --- /dev/null +++ b/src/managers/mp_transport_manager.h @@ -0,0 +1,51 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_transport_manager.h author Oleksandr Stepanov + +#ifndef MP_TRANSPORT_MANAGER_H +#define MP_TRANSPORT_MANAGER_H + +// Manager for multiprocess layer objects + +#include +#include "framework/mp_transport.h" + +namespace snort +{ +class Module; +struct SnortConfig; + + +//------------------------------------------------------------------------- + +class MPTransportManager +{ +public: + static void instantiate(const MPTransportApi* api, Module* mod, SnortConfig*); + static MPTransport* get_transport(const std::string& name); + + static void add_plugin(const MPTransportApi* api); + + static void thread_init(); + static void thread_term(); + + static void term(); +}; +} +#endif + diff --git a/src/managers/plugin_manager.cc b/src/managers/plugin_manager.cc index 6f0f4ec09..e87ab9b3d 100644 --- a/src/managers/plugin_manager.cc +++ b/src/managers/plugin_manager.cc @@ -42,6 +42,7 @@ #include "ips_manager.h" #include "module_manager.h" #include "mpse_manager.h" +#include "mp_transport_manager.h" #include "policy_selector_manager.h" #include "script_manager.h" #include "so_manager.h" @@ -72,6 +73,7 @@ static Symbol symbols[PT_MAX] = { "logger", LOGAPI_VERSION, sizeof(LogApi) }, { "connector", CONNECTOR_API_VERSION, sizeof(ConnectorApi) }, { "policy_selector", POLICY_SELECTOR_API_VERSION, sizeof(PolicySelectorApi) }, + { "mp_transport", MP_TRANSPORT_API_VERSION, sizeof(MPTransportApi) }, }; #else // this gets around the sequence issue with some compilers @@ -88,7 +90,9 @@ static Symbol symbols[PT_MAX] = [PT_LOGGER] = { stringify(PT_LOGGER), LOGAPI_VERSION, sizeof(LogApi) }, [PT_CONNECTOR] = { stringify(PT_CONNECTOR), CONNECTOR_API_VERSION, sizeof(ConnectorApi) }, [PT_POLICY_SELECTOR] = { stringify(PT_POLICY_SELECTOR), POLICY_SELECTOR_API_VERSION, - sizeof(PolicySelectorApi) } + sizeof(PolicySelectorApi) }, + [PT_MP_TRANSPORT] = { stringify(PT_MP_TRANSPORT), MP_TRANSPORT_API_VERSION, + sizeof(MPTransportApi) } }; #endif @@ -321,6 +325,10 @@ static void add_plugin(Plugin& p) case PT_POLICY_SELECTOR: PolicySelectorManager::add_plugin((const PolicySelectorApi*)p.api); break; + + case PT_MP_TRANSPORT: + MPTransportManager::add_plugin((const MPTransportApi*)p.api); + break; default: assert(false); @@ -546,6 +554,10 @@ void PluginManager::instantiate( //IpsManager::instantiate((SoApi*)api, mod, sc); break; + case PT_MP_TRANSPORT: + MPTransportManager::instantiate((const MPTransportApi*)api, mod, sc); + break; + case PT_LOGGER: EventManager::instantiate((const LogApi*)api, mod, sc); break; diff --git a/src/managers/test/CMakeLists.txt b/src/managers/test/CMakeLists.txt index cecd450f1..059dbf51b 100644 --- a/src/managers/test/CMakeLists.txt +++ b/src/managers/test/CMakeLists.txt @@ -3,3 +3,10 @@ add_cpputest(get_inspector_test get_inspector_stubs.h ../inspector_manager.cc ) + +add_cpputest(mp_transport_manager_test + SOURCES + mp_transport_manager_test.cc + ../mp_transport_manager.cc + ../../framework/module.cc +) diff --git a/src/managers/test/mp_transport_manager_test.cc b/src/managers/test/mp_transport_manager_test.cc new file mode 100644 index 000000000..5438ac529 --- /dev/null +++ b/src/managers/test/mp_transport_manager_test.cc @@ -0,0 +1,197 @@ +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "framework/mp_data_bus.h" +#include "framework/mp_transport.h" +#include "framework/module.h" +#include "main/thread_config.h" + +#include "../mp_transport_manager.h" + +#include +#include + +#include + +#define MODULE_NAME "mock_transport" +#define MODULE_HELP "mock transport for testing" + +static int test_transport_ctor_calls = 0; +static int test_transport_dtor_calls = 0; +static int test_transport_pinit_calls = 0; +static int test_transport_pterm_calls = 0; +static int test_transport_tinit_calls = 0; +static int test_transport_tterm_calls = 0; + +namespace snort +{ + +class MockTransport : public MPTransport +{ + public: + MockTransport() : MPTransport() + { } + virtual ~MockTransport() override + { } + virtual bool send_to_transport(MPEventInfo& event) override + { return true; } + virtual void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) override + { } + virtual void init_connection() override + { } + virtual void register_receive_handler(const TransportReceiveEventHandler& handler) override + { } + virtual void unregister_receive_handler() override + { } + virtual void thread_init() override + { } + virtual void thread_term() override + { } + virtual bool configure(const SnortConfig*) override + { return true; } + virtual void enable_logging() override + { } + virtual void disable_logging() override + { } + virtual bool is_logging_enabled() override + { return false; } +}; + +unsigned get_instance_id() { return 0; } +unsigned ThreadConfig::get_instance_max() { return 1; } +}; + + +using namespace snort; + +void show_stats(unsigned long*, PegInfo const*, std::vector > const&, char const*, _IO_FILE*) +{} +void show_stats(unsigned long*, PegInfo const*, unsigned int, char const*) +{} + +static void mock_transport_tinit(MPTransport* t) +{ + test_transport_tinit_calls++; +} +static void mock_transport_tterm(MPTransport* t) +{ + test_transport_tterm_calls++; +} +static MPTransport* mock_transport_ctor(Module* m) +{ + test_transport_ctor_calls++; + return new MockTransport(); +} +static void mock_transport_dtor(MPTransport* t) +{ + test_transport_dtor_calls++; + delete t; +} + +static void mock_transport_pinit() +{ + // Mock plugin init + test_transport_pinit_calls++; +} +static void mock_transport_pterm() +{ + // Mock plugin term + test_transport_pterm_calls++; +} + +static void clear_test_calls() +{ + test_transport_ctor_calls = 0; + test_transport_dtor_calls = 0; + test_transport_pinit_calls = 0; + test_transport_pterm_calls = 0; + test_transport_tinit_calls = 0; + test_transport_tterm_calls = 0; +} + +static struct MPTransportApi mock_transport_api = +{ + { + PT_MP_TRANSPORT, + sizeof(MPTransportApi), + MP_TRANSPORT_API_VERSION, + 2, + API_RESERVED, + API_OPTIONS, + MODULE_NAME, + MODULE_HELP, + nullptr, + nullptr + }, + 0, + mock_transport_pinit, + mock_transport_pterm, + mock_transport_tinit, + mock_transport_tterm, + mock_transport_ctor, + mock_transport_dtor +}; + +TEST_GROUP(mp_transport_manager_test_group) +{ + void setup() override + { + clear_test_calls(); + } + + void teardown() override + { + MPTransportManager::term(); + } +}; + +TEST(mp_transport_manager_test_group, instantiate_transport_object) +{ + MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr); + CHECK(test_transport_ctor_calls == 1); +} + +TEST(mp_transport_manager_test_group, get_transport_object) +{ + MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr); + MPTransport* transport = MPTransportManager::get_transport(MODULE_NAME); + CHECK(transport != nullptr); + + transport = MPTransportManager::get_transport("non_existent_transport"); + CHECK(transport == nullptr); + + MPTransportManager::term(); + + transport = MPTransportManager::get_transport(MODULE_NAME); + CHECK(transport == nullptr); +} + +TEST(mp_transport_manager_test_group, add_plugin) +{ + MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr); + MPTransportManager::add_plugin(&mock_transport_api); + CHECK(test_transport_pinit_calls == 1); + + MPTransportManager::term(); + CHECK(test_transport_pterm_calls == 1); +} + +TEST(mp_transport_manager_test_group, thread_init_term) +{ + MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr); + MPTransportManager::thread_init(); + CHECK(test_transport_tinit_calls == 1); + + MPTransportManager::thread_term(); + CHECK(test_transport_tterm_calls == 1); +} + +int main(int argc, char** argv) +{ + // Allocate object in internal MPTransportManager unordered_map to prevent false-positive memory leak detection ( buckets allocation at first insert ) + MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr); + MPTransportManager::term(); + int return_value = CommandLineTestRunner::RunAllTests(argc, argv); + return return_value; +} \ No newline at end of file diff --git a/src/mp_transport/CMakeLists.txt b/src/mp_transport/CMakeLists.txt new file mode 100644 index 000000000..5c4254574 --- /dev/null +++ b/src/mp_transport/CMakeLists.txt @@ -0,0 +1,6 @@ +add_subdirectory(mp_unix_transport) + +add_library( mp_transports OBJECT + mp_transports.cc + mp_transports.h +) \ No newline at end of file diff --git a/src/mp_transport/dev_notes.txt b/src/mp_transport/dev_notes.txt new file mode 100644 index 000000000..583d8b393 --- /dev/null +++ b/src/mp_transport/dev_notes.txt @@ -0,0 +1,5 @@ +The `mp_transport` files provide the foundation for message-passing transport mechanisms, enabling communication +between distributed components or processes in a modular and extensible manner. These files abstract the complexities +of underlying communication protocols, offering a unified API for sending, receiving, and managing messages. This +abstraction ensures that developers can focus on higher-level application logic without worrying about protocol-specific +details. These transports are utilized by `mp_data_bus.h` to transfer events between distributed Snort instances. diff --git a/src/mp_transport/mp_transports.cc b/src/mp_transport/mp_transports.cc new file mode 100644 index 000000000..c65610915 --- /dev/null +++ b/src/mp_transport/mp_transports.cc @@ -0,0 +1,36 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_transports.cc author Oleksandr Stepanov + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "mp_transports.h" + +#include "framework/mp_transport.h" +#include "managers/plugin_manager.h" + +using namespace snort; + +extern const BaseApi* mp_unix_transport[]; + +void load_mp_transports() +{ + PluginManager::load_plugins(mp_unix_transport); +} diff --git a/src/mp_transport/mp_transports.h b/src/mp_transport/mp_transports.h new file mode 100644 index 000000000..4416d8980 --- /dev/null +++ b/src/mp_transport/mp_transports.h @@ -0,0 +1,25 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_transports.h author Oleksandr Stepanov + +#ifndef MP_TRANSPORTS_H +#define MP_TRANSPORTS_H + +void load_mp_transports(); + +#endif // MP_TRANSPORTS_H diff --git a/src/mp_transport/mp_unix_transport/CMakeLists.txt b/src/mp_transport/mp_unix_transport/CMakeLists.txt new file mode 100644 index 000000000..11b80260b --- /dev/null +++ b/src/mp_transport/mp_unix_transport/CMakeLists.txt @@ -0,0 +1,24 @@ + +set( MP_TRANSPORT_INCLUDES + mp_unix_transport.h + mp_unix_transport_module.h +) + +add_library( mp_unix_transport OBJECT + ${MP_TRANSPORT_INCLUDES} + mp_unix_transport.cc + mp_unix_transport_module.cc +) + +install (FILES ${MIME_INCLUDES} + DESTINATION "${INCLUDE_INSTALL_PATH}/mp_unix_transport" +) + +add_dependencies( mp_unix_transport framework ) +add_dependencies( mp_unix_transport unixdomain_connector ) + +find_package(Threads REQUIRED) +target_link_libraries(mp_unix_transport PRIVATE Threads::Threads) +target_link_libraries(mp_unix_transport PRIVATE $) + +add_subdirectory( test ) \ No newline at end of file diff --git a/src/mp_transport/mp_unix_transport/dev_notes.txt b/src/mp_transport/mp_unix_transport/dev_notes.txt new file mode 100644 index 000000000..9f4af4d82 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/dev_notes.txt @@ -0,0 +1,13 @@ +The MP Unix Domain Transport provides an implementation of the Multi-Process (MP) Transport +interface using existing `UnixDomainConnector` infrastructure. This transport enables +inter-process communication (IPC) between multiple Snort instances running on the same host, +allowing them to exchange events and data. + + * MPUnixDomainTransport - Main class implementing the MPTransport interface with Unix domain socket functionality + * MPUnixDomainTransportModule - Module class that handles configuration parameters and provides an API for MPUnixDomainTransport creation + * MPUnixDomainTransportConfig - Configuration structure for socket paths and connection parameters + +Connention between snort establishes in next sequence: +* First Snort instance acts as a server socket accepting connections from other instances +* Additional instances connect to existing socket paths +* Dynamic re-connection handling with configurable retry parameters diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc new file mode 100644 index 000000000..6b15acad1 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc @@ -0,0 +1,447 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_unix_transport.cc author Oleksandr Stepanov + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "mp_unix_transport.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "framework/mp_data_bus.h" +#include "log/messages.h" +#include "main/snort.h" +#include "main/snort_config.h" + +static std::mutex _receive_mutex; +static std::mutex _update_connectors_mutex; + +#define UNIX_SOCKET_NAME_PREFIX "/snort_unix_connector_" + +#define MP_TRANSPORT_LOG_LABEL "MPUnixTransport" + +#define MP_TRANSPORT_LOG(msg, ...) do { \ + if (!this->is_logging_enabled_flag) \ + break; \ + LogMessage(msg, __VA_ARGS__); \ + } while (0) + +namespace snort +{ + +#pragma pack(push, 1) + enum MPTransportMessageType + { + EVENT_MESSAGE = 0, + MAX_TYPE + }; + +struct MPTransportMessageHeader +{ + MPTransportMessageType type;// Type of the message + int32_t pub_id; // Identifier for the module sending or receiving the message + int32_t event_id; // Identifier for the specific event + uint16_t data_length; // Length of the data payload +}; + +struct MPTransportMessage +{ + MPTransportMessageHeader header; // Header containing metadata about the message + char* data; // Placeholder for the actual data payload +}; +#pragma pack(pop) + +void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg) +{ + if (transport_receive_handler and msg) + { + if (msg->content_length < sizeof(MPTransportMessage)) + { + MP_TRANSPORT_LOG("%s: Incomplete message received\n", MP_TRANSPORT_LOG_LABEL); + return; + } + + MPTransportMessageHeader* transport_message_header = (MPTransportMessageHeader*)msg->content; + + if (transport_message_header->type >= MAX_TYPE) + { + MP_TRANSPORT_LOG("%s: Invalid message type received\n", MP_TRANSPORT_LOG_LABEL); + return; + } + + auto deserialize_func = get_event_deserialization_function(transport_message_header->pub_id, transport_message_header->event_id); + if (!deserialize_func) + { + MP_TRANSPORT_LOG("%s: No deserialization function found for event: type %d, id %d\n", MP_TRANSPORT_LOG_LABEL, transport_message_header->type, transport_message_header->event_id); + return; + } + + DataEvent* internal_event = nullptr; + (deserialize_func)((const char*)(msg->content + sizeof(MPTransportMessageHeader)), transport_message_header->data_length, internal_event); + MPEventInfo event(internal_event, transport_message_header->event_id, transport_message_header->pub_id); + + (transport_receive_handler)(event); + + delete internal_event; + } + delete msg; +} + +void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg) +{ + assert(connector); + assert(cfg); + + std::lock_guard guard(_update_connectors_mutex); + + auto side_channel = new SideChannel(ScMsgFormat::BINARY); + side_channel->connector_receive = connector; + side_channel->connector_transmit = side_channel->connector_receive; + side_channel->register_receive_handler(std::bind(&MPUnixDomainTransport::side_channel_receive_handler, this, std::placeholders::_1)); + connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); + this->side_channels.push_back(new SideChannelHandle(side_channel, cfg)); + connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); +} + +MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c) : MPTransport(), + config(c) +{ + this->is_logging_enabled_flag = c->enable_logging; +} + +MPUnixDomainTransport::~MPUnixDomainTransport() +{ + cleanup(); +} + +bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event) +{ + auto serialize_func = get_event_serialization_function(event.pub_id, event.type); + + if (!serialize_func) + { + MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event.type); + return false; + } + + MPTransportMessage transport_message; + transport_message.header.type = EVENT_MESSAGE; + transport_message.header.pub_id = event.pub_id; + transport_message.header.event_id = event.type; + + + (serialize_func)(event.event, transport_message.data, &transport_message.header.data_length); + for (auto &&sc_handler : this->side_channels) + { + auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length); + memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader)); + memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length); + auto send_result = sc_handler->side_channel->transmit_message(msg); + if (!send_result) + { + MP_TRANSPORT_LOG("%s: Failed to send message to side channel\n", MP_TRANSPORT_LOG_LABEL); + } + } + + delete[] transport_message.data; + + return true; +} + +void MPUnixDomainTransport::register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) +{ + assert(helper.deserializer); + assert(helper.serializer); + + this->event_helpers[pub_id] = SerializeFunctionHandle(); + this->event_helpers[pub_id].serialize_functions.insert({event_id, helper}); +} + +void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler& handler) +{ + this->transport_receive_handler = handler; +} + +void MPUnixDomainTransport::unregister_receive_handler() +{ + this->transport_receive_handler = nullptr; +} + +void MPUnixDomainTransport::process_messages_from_side_channels() +{ + std::unique_lock lock(_receive_mutex); + do + { + if ( (std::cv_status::timeout == this->consume_thread_cv.wait_for(lock, std::chrono::milliseconds(config->consume_message_timeout_milliseconds)) ) + and this->consume_message_received == false ) + { + continue; + } + + { + std::lock_guard guard(_update_connectors_mutex); + bool messages_left; + + do + { + messages_left = false; + for (auto &&sc_handler : this->side_channels) + { + messages_left |= sc_handler->side_channel->process(config->consume_message_batch_size); + } + } while (messages_left); + } + + this->consume_message_received = false; + + } while (this->is_running); +} + +void MPUnixDomainTransport::notify_process_thread() +{ + this->consume_thread_cv.notify_all(); + this->consume_message_received = true; +} + +void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel) +{ + std::lock_guard guard(_update_connectors_mutex); + if (side_channel->connector_receive) + { + delete side_channel->connector_receive; + side_channel->connector_receive = side_channel->connector_transmit = nullptr; + } + + if (connector) + { + side_channel->connector_receive = side_channel->connector_transmit = connector; + } + else + { + if (is_recconecting == false) + { + MP_TRANSPORT_LOG("%s: Accepted connection interrupted, removing handle\n", MP_TRANSPORT_LOG_LABEL); + for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it) + { + if ((*it)->side_channel == side_channel) + { + delete *it; + this->side_channels.erase(it); + break; + } + } + } + } +} + +MPSerializeFunc MPUnixDomainTransport::get_event_serialization_function(unsigned pub_id, unsigned event_id) +{ + auto helper_it = this->event_helpers.find(pub_id); + if (helper_it == this->event_helpers.end()) + { + MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id); + return nullptr; + } + auto helper_functions = helper_it->second.get_function_set(event_id); + if (!helper_functions) + { + MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id); + return nullptr; + } + return helper_functions->serializer; +} + +MPDeserializeFunc MPUnixDomainTransport::get_event_deserialization_function(unsigned pub_id, unsigned event_id) +{ + auto helper_it = this->event_helpers.find(pub_id); + if (helper_it == this->event_helpers.end()) + { + MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id); + return nullptr; + } + auto helper_functions = helper_it->second.get_function_set(event_id); + if (!helper_functions) + { + MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id); + return nullptr; + } + return helper_functions->deserializer; +} + +void MPUnixDomainTransport::init_connection() +{ + init_side_channels(); +} + +void MPUnixDomainTransport::thread_init() +{ +} + +void MPUnixDomainTransport::thread_term() +{ +} + +bool MPUnixDomainTransport::configure(const SnortConfig *c) +{ + config->max_processes = c->max_procs; + return true; +} + +void MPUnixDomainTransport::cleanup() +{ + this->is_running = false; + MPUnixDomainTransport::unregister_receive_handler(); + if (this->consume_thread) + { + this->consume_thread_cv.notify_all(); + this->consume_thread->join(); + delete this->consume_thread; + this->consume_thread = nullptr; + } + cleanup_side_channels(); + for (auto &&ac_handler : this->accept_handlers) + { + ac_handler->listener->stop_accepting_connections(); + delete ac_handler->listener; + delete ac_handler->connector_config; + delete ac_handler; + } + this->accept_handlers.clear(); +} + +void MPUnixDomainTransport::init_side_channels() +{ + assert(config); + if (config->max_processes < 2) + return; + + auto instance_id = Snort::get_process_id();//Snort instance id + auto max_processes = config->max_processes; + + this->is_running = true; + + for (ushort i = instance_id; i < max_processes; i++) + { + auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i); + auto unix_listener = new UnixDomainConnectorListener(listen_path.c_str()); + + UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig(); + unix_config->setup = UnixDomainConnectorConfig::Setup::ANSWER; + unix_config->async_receive = true; + if (config->conn_retries) + { + unix_config->conn_retries = config->conn_retries; + unix_config->retry_interval = config->retry_interval_seconds; + unix_config->max_retries = config->max_retries; + unix_config->connect_timeout_seconds = config->connect_timeout_seconds; + } + else + { + unix_config->conn_retries = false; + unix_config->retry_interval = 0; + unix_config->max_retries = 0; + unix_config->connect_timeout_seconds = 0; + } + unix_config->paths.push_back(listen_path); + + unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2), unix_config); + + auto unix_listener_handle = new UnixAcceptorHandle(); + unix_listener_handle->connector_config = unix_config; + unix_listener_handle->listener = unix_listener; + this->accept_handlers.push_back(unix_listener_handle); + } + + for (ushort i = 1; i < instance_id; i++) + { + auto side_channel = new SideChannel(ScMsgFormat::BINARY); + side_channel->register_receive_handler([this](SCMessage* msg) { this->side_channel_receive_handler(msg); }); + + auto send_path = config->unix_domain_socket_path + "/" + "snort_unix_connector_" + std::to_string(i); + + UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig(); + connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL; + connector_conf->async_receive = true; + connector_conf->conn_retries = config->conn_retries; + connector_conf->retry_interval = config->retry_interval_seconds; + connector_conf->max_retries = config->max_retries; + connector_conf->connect_timeout_seconds = config->connect_timeout_seconds; + connector_conf->paths.push_back(send_path); + + auto connector = unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel)); + + if (connector) + connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this)); + + side_channel->connector_receive = connector; + side_channel->connector_transmit = side_channel->connector_receive; + this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf)); + } + + this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this); +} +void MPUnixDomainTransport::cleanup_side_channels() +{ + std::lock_guard guard(_update_connectors_mutex); + + for (uint i = 0; i < this->side_channels.size(); i++) + { + auto side_channel = this->side_channels[i]; + delete side_channel; + } + + this->side_channels.clear(); +} + +SideChannelHandle::~SideChannelHandle() +{ + if (side_channel) + { + if (side_channel->connector_receive) + delete side_channel->connector_receive; + + delete side_channel; + } + + if (connector_config) + delete connector_config; +} +void MPUnixDomainTransport::enable_logging() +{ + this->is_logging_enabled_flag = true; +} + +void MPUnixDomainTransport::disable_logging() +{ + this->is_logging_enabled_flag = false; +} + +bool MPUnixDomainTransport::is_logging_enabled() +{ + return this->is_logging_enabled_flag; +} + +}; diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.h b/src/mp_transport/mp_unix_transport/mp_unix_transport.h new file mode 100644 index 000000000..8208e030e --- /dev/null +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport.h @@ -0,0 +1,132 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_unix_transport.h author Oleksandr Stepanov + +#ifndef UNIX_TRANSPORT_H +#define UNIX_TRANSPORT_H + +#include "connectors/unixdomain_connector/unixdomain_connector.h" +#include "framework/mp_data_bus.h" +#include "main/snort_types.h" +#include "side_channel/side_channel.h" + +#include +#include +#include + +namespace snort +{ + +struct MPUnixDomainTransportConfig +{ + std::string unix_domain_socket_path; + uint16_t max_processes = 0; + bool conn_retries = true; + bool enable_logging = false; + uint32_t retry_interval_seconds = 5; + uint32_t max_retries = 5; + uint32_t connect_timeout_seconds = 30; + uint32_t consume_message_timeout_milliseconds = 100; + uint32_t consume_message_batch_size = 5; +}; + +struct SerializeFunctionHandle +{ + std::unordered_map serialize_functions; + + MPHelperFunctions* get_function_set(unsigned event_id) + { + auto it = serialize_functions.find(event_id); + if(it == serialize_functions.end()) + return nullptr; + return &it->second; + } +}; + +struct SideChannelHandle +{ + SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc) : + side_channel(sc), connector_config(cc) + { } + + ~SideChannelHandle(); + + SideChannel* side_channel; + UnixDomainConnectorConfig* connector_config; +}; + +struct UnixAcceptorHandle +{ + UnixDomainConnectorConfig* connector_config = nullptr; + UnixDomainConnectorListener* listener = nullptr; +}; + +class MPUnixDomainTransport : public MPTransport +{ + public: + + MPUnixDomainTransport(MPUnixDomainTransportConfig* c); + ~MPUnixDomainTransport() override; + + bool configure(const SnortConfig*) override; + void thread_init() override; + void thread_term() override; + void init_connection() override; + bool send_to_transport(MPEventInfo& event) override; + void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) override; + void register_receive_handler(const TransportReceiveEventHandler& handler) override; + void unregister_receive_handler() override; + void enable_logging() override; + void disable_logging() override; + bool is_logging_enabled() override; + void cleanup(); + + MPUnixDomainTransportConfig* get_config() + { return config; } + + + private: + + void init_side_channels(); + void cleanup_side_channels(); + void side_channel_receive_handler(SCMessage* msg); + void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg); + void process_messages_from_side_channels(); + void notify_process_thread(); + void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel); + + MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id); + MPDeserializeFunc get_event_deserialization_function(unsigned pub_id, unsigned event_id); + + TransportReceiveEventHandler transport_receive_handler = nullptr; + MPUnixDomainTransportConfig* config = nullptr; + + std::vector side_channels; + std::vector accept_handlers; + std::unordered_map event_helpers; + + std::atomic is_running = false; + std::atomic is_logging_enabled_flag; + std::atomic consume_message_received = false; + + std::thread* consume_thread = nullptr; + std::condition_variable consume_thread_cv; +}; + +} +#endif diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc new file mode 100644 index 000000000..b0c8c5f05 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc @@ -0,0 +1,130 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_unix_transport_module.cc author Oleksandr Stepanov + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "mp_unix_transport_module.h" + +#include "main/snort_config.h" +#include "log/messages.h" + +#define DEFAULT_UNIX_DOMAIN_SOCKET_PATH "/tmp/snort_unix_connectors" + +using namespace snort; + +static const Parameter unix_transport_params[] = +{ + { "unix_domain_socket_path" , Parameter::PT_STRING, nullptr, DEFAULT_UNIX_DOMAIN_SOCKET_PATH, "unix socket folder" }, + { "max_connect_retries", Parameter::PT_INT, nullptr, "5", "max connection retries" }, + { "retry_interval_seconds", Parameter::PT_INT, nullptr, "30", "retry interval in seconds" }, + { "connect_timeout_seconds", Parameter::PT_INT, nullptr, "30", "connect timeout in seconds" }, + { "consume_message_timeout_milliseconds", Parameter::PT_INT, nullptr, "100", "consume message timeout in milliseconds" }, + { "consume_message_batch_size", Parameter::PT_INT, nullptr, "5", "consume message batch size" }, + { "enable_logging", Parameter::PT_BOOL, nullptr, "false", "enable logging" }, + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } +}; + +MPUnixDomainTransportModule::MPUnixDomainTransportModule(): Module(MODULE_NAME, MODULE_HELP, unix_transport_params) +{ + config = nullptr; +} + +bool MPUnixDomainTransportModule::begin(const char *, int, SnortConfig *sc) +{ + assert(sc); + assert(!config); + config = new MPUnixDomainTransportConfig; + config->max_processes = sc->max_procs; + return true; +} + +bool MPUnixDomainTransportModule::set(const char *, Value & v, SnortConfig *) +{ + if (v.is("unix_domain_socket_path")) + { + config->unix_domain_socket_path = v.get_string(); + } + else if (v.is("max_connect_retries")) + { + config->conn_retries = true; + config->max_retries = v.get_int32(); + } + else if (v.is("retry_interval_seconds")) + { + config->retry_interval_seconds = v.get_int32(); + } + else if (v.is("connect_timeout_seconds")) + { + config->connect_timeout_seconds = v.get_int32(); + } + else if (v.is("consume_message_timeout_milliseconds")) + { + config->consume_message_timeout_milliseconds = v.get_int32(); + } + else if (v.is("consume_message_batch_size")) + { + config->consume_message_batch_size = v.get_int32(); + } + else if (v.is("enable_logging")) + { + config->enable_logging = v.get_bool(); + } + else + { + WarningMessage("MPUnixDomainTransportModule: received unrecognized parameter %s\n", v.get_as_string().c_str()); + return false; + } + + return true; +} + +static struct MPTransportApi mp_unixdomain_transport_api = +{ + { + PT_MP_TRANSPORT, + sizeof(MPTransportApi), + MP_TRANSPORT_API_VERSION, + 2, + API_RESERVED, + API_OPTIONS, + MODULE_NAME, + MODULE_HELP, + mod_ctor, + mod_dtor + }, + 0, + nullptr, + nullptr, + nullptr, + nullptr, + mp_unixdomain_transport_ctor, + mp_unixdomain_transport_dtor +}; + +#ifdef BUILDING_SO +SO_PUBLIC const BaseApi* snort_plugins[] = +#else +const BaseApi* mp_unix_transport[] = +#endif +{ + &mp_unixdomain_transport_api.base, + nullptr +}; diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h new file mode 100644 index 000000000..464b4f3e8 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h @@ -0,0 +1,74 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// mp_unix_transport_module.h author Oleksandr Stepanov + +#ifndef MP_UNIX_TRANSPORT_MODULE_H +#define MP_UNIX_TRANSPORT_MODULE_H + +#define MODULE_NAME "unix_transport" +#define MODULE_HELP "manage the unix transport layer" + +#include "framework/module.h" +#include "framework/mp_transport.h" +#include "mp_unix_transport.h" + +namespace snort +{ + +class MPUnixDomainTransportModule : public Module +{ + public: + + MPUnixDomainTransportModule(); + + ~MPUnixDomainTransportModule() override + { delete config; } + + bool begin(const char*, int, SnortConfig*) override; + bool set(const char*, Value&, SnortConfig*) override; + + Usage get_usage() const override + { return GLOBAL; } + + MPUnixDomainTransportConfig* config; +}; + +static Module* mod_ctor() +{ + return new MPUnixDomainTransportModule; +} + +static void mod_dtor(Module* m) +{ + delete m; +} + +static MPTransport* mp_unixdomain_transport_ctor(Module* m) +{ + auto unix_tr_mod = (MPUnixDomainTransportModule*)m; + return new MPUnixDomainTransport(unix_tr_mod->config); +} + +static void mp_unixdomain_transport_dtor(MPTransport* t) +{ + delete t; +} + +} + +#endif diff --git a/src/mp_transport/mp_unix_transport/test/CMakeLists.txt b/src/mp_transport/mp_unix_transport/test/CMakeLists.txt new file mode 100644 index 000000000..4d58bb8c5 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/test/CMakeLists.txt @@ -0,0 +1,24 @@ + +add_cpputest( unix_transport_test + SOURCES + ../mp_unix_transport.cc + ../../../side_channel/side_channel.cc + ../../../side_channel/side_channel_format.cc + ../../../framework/module.cc + ../../../managers/connector_manager.cc + $ + LIBS + ${CMAKE_THREAD_LIBS_INIT} +) + + add_cpputest( unix_transport_module_test + SOURCES + ../mp_unix_transport_module.cc + ../../../framework/value.cc + ../../../sfip/sf_ip.cc + $ + ../../../framework/module.cc + LIBS + ${CMAKE_THREAD_LIBS_INIT} + ) + diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc new file mode 100644 index 000000000..c00ff41a6 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc @@ -0,0 +1,209 @@ +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "../mp_unix_transport_module.h" + +#include "framework/value.h" +#include "main/snort_config.h" +#include "main/snort.h" +#include "main/thread_config.h" + +#include +#include + + +static int warning_cnt = 0; +static int destroy_cnt = 0; + +namespace snort +{ + void WarningMessage(const char*,...) { warning_cnt++; } + + SnortConfig::SnortConfig(snort::SnortConfig const*, char const*) + { + max_procs = 2; + } + SnortConfig::~SnortConfig() + { + + } + + unsigned ThreadConfig::get_instance_max() + { + return 1; + } + unsigned Snort::get_process_id() + { + return 1; + } + unsigned get_instance_id() + { + return 1; + } + + MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config) : + MPTransport() + { + this->config = config; + } + MPUnixDomainTransport::~MPUnixDomainTransport() + { destroy_cnt++; } + void MPUnixDomainTransport::thread_init() + {} + void MPUnixDomainTransport::thread_term() + {} + void MPUnixDomainTransport::init_connection() + {} + void MPUnixDomainTransport::cleanup() + {} + void MPUnixDomainTransport::register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&) + {} + bool MPUnixDomainTransport::send_to_transport(MPEventInfo&) + { return true; } + void MPUnixDomainTransport::unregister_receive_handler() + { } + void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler&) + {} + bool MPUnixDomainTransport::configure(const SnortConfig*) + { return true; } + bool MPUnixDomainTransport::is_logging_enabled() + { return false; } + void MPUnixDomainTransport::enable_logging() + {} + void MPUnixDomainTransport::disable_logging() + {} + + char* snort_strdup(const char*) + { + return nullptr; + } +}; + +void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { } +void show_stats(PegCount*, const PegInfo*, const std::vector&, const char*, FILE*) { } + +using namespace snort; + +MPUnixDomainTransportModule* mod = nullptr; + +TEST_GROUP(MPUnixDomainTransportModuleTests) +{ + void setup() override + { + mod = (MPUnixDomainTransportModule*)mod_ctor(); + } + + void teardown() override + { + mod_dtor(mod); + } +}; + +TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigBegin) +{ + SnortConfig sc; + auto res = mod->begin("test", 0, &sc); + CHECK(res == true); + CHECK(mod->config != nullptr); +} + +TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigEnd) +{ + auto res = mod->end("test", 0, nullptr); + CHECK(res == true); +} + +TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigSet) +{ + Parameter p{"unix_domain_socket_path", Parameter::PT_STRING, nullptr, "test_value", nullptr}; + Parameter p2{"max_connect_retries", Parameter::PT_INT, nullptr, "15", nullptr}; + Parameter p3{"retry_interval_seconds", Parameter::PT_INT, nullptr, "33", nullptr}; + Parameter p4{"connect_timeout_seconds", Parameter::PT_INT, nullptr, "32", nullptr}; + Parameter p5{"consume_message_timeout_milliseconds", Parameter::PT_INT, nullptr, "200", nullptr}; + Parameter p6{"consume_message_batch_size", Parameter::PT_INT, nullptr, "20", nullptr}; + Parameter p7{"enable_logging", Parameter::PT_BOOL, nullptr, "true", nullptr}; + Value v("test_value"); + v.set(&p); + + SnortConfig sc; + mod->begin("test", 0, &sc); + auto res = mod->set(nullptr, v, nullptr); + + CHECK(res == true); + CHECK(strcmp("test_value", mod->config->unix_domain_socket_path.c_str()) == 0); + + v.set((double)15); + v.set(&p2); + res = mod->set(nullptr, v, nullptr); + CHECK(res == true); + CHECK(mod->config->max_retries == 15); + + v.set((double)33); + v.set(&p3); + res = mod->set(nullptr, v, nullptr); + CHECK(res == true); + CHECK(mod->config->retry_interval_seconds == 33); + + v.set((double)32); + v.set(&p4); + res = mod->set(nullptr, v, nullptr); + CHECK(res == true); + CHECK(mod->config->connect_timeout_seconds == 32); + + v.set((double)200); + v.set(&p5); + res = mod->set(nullptr, v, nullptr); + CHECK(res == true); + CHECK(mod->config->consume_message_timeout_milliseconds == 200); + + v.set((double)20); + v.set(&p6); + res = mod->set(nullptr, v, nullptr); + CHECK(res == true); + CHECK(mod->config->consume_message_batch_size == 20); + + v.set(true); + v.set(&p7); + res = mod->set(nullptr, v, nullptr); + CHECK(res == true); + CHECK(mod->config->enable_logging == true); +} + +TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigUnknownSet) +{ + warning_cnt = 0; + + Parameter p{"unknown_value", Parameter::PT_STRING, nullptr, "/tmp/unx_dmn_sck", nullptr}; + Value v("/tmp/unx_dmn_sck"); + v.set(&p); + + SnortConfig sc; + mod->begin("test", 0, &sc); + auto res = mod->set(nullptr, v, nullptr); + + CHECK(res == false); + CHECK(1 == warning_cnt); +} + +TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleGetUsage) +{ + auto res = mod->get_usage(); + CHECK(res == Module::Usage::GLOBAL); +} + +TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleCreateDestroyTransport) +{ + destroy_cnt = 0; + auto transport = mp_unixdomain_transport_ctor(mod); + CHECK(transport != nullptr); + + mp_unixdomain_transport_dtor(transport); + CHECK(destroy_cnt == 1); +} + +int main(int argc, char** argv) +{ + int return_value = CommandLineTestRunner::RunAllTests(argc, argv); + return return_value; +} \ No newline at end of file diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc new file mode 100644 index 000000000..d611e15b3 --- /dev/null +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc @@ -0,0 +1,505 @@ +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include "framework/mp_transport.h" +#include "../mp_unix_transport.h" +#include "framework/counts.h" +#include "framework/mp_data_bus.h" +#include "main/snort.h" +#include "main/thread_config.h" +#include "main/snort_config.h" +#include "connectors/unixdomain_connector/unixdomain_connector.h" + +#include +#include + +#include +#include +#include +#include + +static int snort_instance_id = 0; + +static int accept_cnt = 0; + +static int test_socket_calls = 0; +static int test_bind_calls = 0; +static int test_listen_calls = 0; +static int test_accept_calls = 0; +static int test_close_calls = 0; +static int test_connect_calls = 0; +static int test_call_sock_created = 0; +static int test_serialize_calls = 0; +static int test_deserialize_calls = 0; + +int accept (int, struct sockaddr*, socklen_t*) +{ + test_accept_calls++; + return accept_cnt--; +} + +int close (int) +{ + test_close_calls++; + return 0; +} + +void clear_test_calls() +{ + test_socket_calls = 0; + test_bind_calls = 0; + test_listen_calls = 0; + test_accept_calls = 0; + test_close_calls = 0; + test_connect_calls = 0; + test_call_sock_created = 0; + test_serialize_calls = 0; + test_deserialize_calls = 0; +} + +namespace snort +{ + void ErrorMessage(const char*,...) { } + void WarningMessage(const char*,...) { } + void LogMessage(const char* s, ...) { } + void LogText(const char*, FILE*) {} + void ParseError(const char*, ...) { } + + unsigned ThreadConfig::get_instance_max() + { + return 2; // Mock value for testing + } + + SnortConfig::SnortConfig(snort::SnortConfig const*, char const*) + { + max_procs = 2; + } + SnortConfig::~SnortConfig() + { + + } + + unsigned Snort::get_process_id() + { + return snort_instance_id; + } + + unsigned get_instance_id() + { + return snort_instance_id; + } +}; +static int test_send_calls = 0; +UnixDomainConnector* listen_connector = nullptr; +UnixDomainConnector* call_connector = nullptr; + +void UnixDomainConnector::set_message_received_handler(std::function h) +{ + message_received_handler = h; +} + +static bool expect_update_change = false; +std::function test_update_handler = nullptr; + +void UnixDomainConnector::set_update_handler(std::function h) +{ + if(expect_update_change) + test_update_handler = h; +} + +static snort::ConnectorMsg* test_msg_answer = nullptr; +static snort::ConnectorMsg* test_msg_call = nullptr; +static uint8_t* test_msg_call_data = nullptr; +static uint8_t* test_msg_answer_data = nullptr; +UnixDomainConnectorListener::UnixDomainConnectorListener(char const*) // cppcheck-suppress uninitMemberVar +{} +UnixDomainConnectorListener::~UnixDomainConnectorListener() +{ +} +void UnixDomainConnectorListener::stop_accepting_connections() +{ + close(0); +} +void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnectorAcceptHandler h, UnixDomainConnectorConfig* cfg) +{ + socket(0,0,0); + while(accept_cnt > 0) + { + accept(0, nullptr, nullptr); + auto cfg_copy = new UnixDomainConnectorConfig(*cfg); + h(new UnixDomainConnector(*cfg_copy, 0, 0), cfg_copy); + } +} + +bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg& m, const ID&) +{ + test_send_calls++; + + if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) + { + test_msg_call_data = new uint8_t[m.get_length()]; + memcpy(test_msg_call_data, m.get_data(), m.get_length()); + test_msg_call = new snort::ConnectorMsg(test_msg_call_data, m.get_length()); + if(!call_connector) + call_connector = this; + listen_connector->process_receive(); + } + else + { + test_msg_answer_data = new uint8_t[m.get_length()]; + memcpy(test_msg_answer_data, m.get_data(), m.get_length()); + test_msg_answer = new snort::ConnectorMsg(test_msg_answer_data, m.get_length()); + if(!listen_connector) + listen_connector = this; + call_connector->process_receive(); + } + + return true; +} +void UnixDomainConnector::process_receive() +{ + if (message_received_handler) + { + message_received_handler(); + } +} +bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg&&, const ID&) +{ return true; } +snort::ConnectorMsg UnixDomainConnector::receive_message(bool) +{ + if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL) + { + if (test_msg_answer) + { + snort::ConnectorMsg msg(test_msg_answer_data, test_msg_answer->get_length()); + delete test_msg_answer; + test_msg_answer = nullptr; + return std::move(msg); // cppcheck-suppress returnStdMoveLocal + } + } + else + { + if (test_msg_call) + { + snort::ConnectorMsg msg(test_msg_call_data, test_msg_call->get_length()); + delete test_msg_call; + test_msg_call = nullptr; + return std::move(msg); // cppcheck-suppress returnStdMoveLocal + } + } + return snort::ConnectorMsg(); +} +UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx) : Connector(config) // cppcheck-suppress uninitMemberVar +{ cfg = config; } // cppcheck-suppress useInitializationList +UnixDomainConnector::~UnixDomainConnector() +{ + close(0); +} + +UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler) +{ + if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL) + { + test_call_sock_created++; + auto new_conn = new UnixDomainConnector(cfg, 0, idx); + call_connector = new_conn; + return new_conn; + } + assert(false); + return nullptr; +} + +void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { } +void show_stats(PegCount*, const PegInfo*, const std::vector&, const char*, FILE*) { } + +using namespace snort; + + +static int s_socket_return = 1; +static int s_bind_return = 0; +static int s_listen_return = 0; +static int s_connect_return = 1; + +#ifdef __GLIBC__ +int socket (int, int, int) __THROW { test_socket_calls++; return s_socket_return; } +int bind (int, const struct sockaddr*, socklen_t) __THROW { test_bind_calls++; return s_bind_return; } +int listen (int, int) __THROW { test_listen_calls++; return s_listen_return; } +int connect (int, const struct sockaddr*, socklen_t) __THROW { test_connect_calls++; return s_connect_return; } +int unlink (const char *__name) __THROW { return 0;}; +#else +int socket (int, int, int) { test_socket_calls++; return s_socket_return; } +int bind (int, const struct sockaddr*, socklen_t) { test_bind_calls++; return s_bind_return; } +int listen (int, int) { test_listen_calls++; return s_listen_return; } +int connect (int, const struct sockaddr*, socklen_t) { test_connect_calls++; return s_connect_return; } +int unlink (const char *__name) { return 0;}; +#endif + +int fcntl (int __fd, int __cmd, ...) { return 0;} +ssize_t send (int, const void*, size_t n, int) { return n; } + + +std::mutex accept_mutex; +std::condition_variable accept_cond; + + +class TestDataEvent : public DataEvent +{ +public: + TestDataEvent() {} + ~TestDataEvent() override {} +}; + +bool serialize_mock(DataEvent* event, char*& buffer, uint16_t* length) +{ + test_serialize_calls++; + buffer = new char[9]; + *length = 9; + memcpy(buffer, "test_data", 9); + return true; +} + +bool deserialize_mock(const char* buffer, uint16_t length, DataEvent*& event) +{ + test_deserialize_calls++; + event = new TestDataEvent(); + return true; +} + +MPHelperFunctions mp_helper_functions_mock(serialize_mock, deserialize_mock); + +static MPUnixDomainTransportConfig test_config; +static MPUnixDomainTransport* test_transport = nullptr; + +static SnortConfig test_snort_config(nullptr, nullptr); + +TEST_GROUP(unix_transport_test_connectivity_group) +{ + void setup() override + { + test_snort_config.max_procs = 2; + test_transport = new MPUnixDomainTransport(&test_config); + test_transport->configure(&test_snort_config); + } + + void teardown() override + { + delete test_transport; + test_transport = nullptr; + } +}; + +static MPUnixDomainTransportConfig test_config_message; + +static MPTransport* test_transport_message_1 = nullptr; +static MPTransport* test_transport_message_2 = nullptr; + +static int reciveved_1_msg_cnt = 0; +static int reciveved_2_msg_cnt = 0; + +TEST_GROUP(unix_transport_test_messaging) +{ + void setup() override + { + test_snort_config.max_procs = 2; + + accept_cnt = 1; + + test_config_message.unix_domain_socket_path = "/tmp"; + test_config_message.max_processes = 2; + test_config_message.conn_retries = false; + test_config_message.retry_interval_seconds = 0; + test_config_message.max_retries = 0; + test_config_message.connect_timeout_seconds = 30; + + test_transport_message_1 = new MPUnixDomainTransport(&test_config_message); + snort_instance_id = 1; + test_transport_message_1->configure(&test_snort_config); + test_transport_message_1->init_connection(); + test_transport_message_1->register_receive_handler([reciveved_1_msg_cnt](const snort::MPEventInfo& e) + { + reciveved_1_msg_cnt++; + }); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + test_transport_message_2 = new MPUnixDomainTransport(&test_config_message); + snort_instance_id = 2; + test_transport_message_2->configure(&test_snort_config); + test_transport_message_2->init_connection(); + test_transport_message_2->register_receive_handler([reciveved_2_msg_cnt](const snort::MPEventInfo& e) + { + reciveved_2_msg_cnt++; + }); + } + + void teardown() override + { + delete test_transport_message_1; + test_transport_message_1 = nullptr; + delete test_transport_message_2; + test_transport_message_2 = nullptr; + delete[] test_msg_call_data; + test_msg_call_data = nullptr; + delete[] test_msg_answer_data; + test_msg_answer_data = nullptr; + } +}; + +TEST(unix_transport_test_connectivity_group, get_config) +{ + auto unix_transport = (MPUnixDomainTransport*)test_transport; + CHECK(unix_transport->get_config() == &test_config); +}; + +TEST(unix_transport_test_connectivity_group, set_logging_enabled_disabled) +{ + auto logging_status = test_transport->is_logging_enabled(); + CHECK(logging_status == false); + + test_transport->enable_logging(); + logging_status = test_transport->is_logging_enabled(); + CHECK(logging_status == true); + + test_transport->disable_logging(); + logging_status = test_transport->is_logging_enabled(); + CHECK(logging_status == false); +}; + +TEST(unix_transport_test_connectivity_group, init_connection_single_snort_instance) +{ + clear_test_calls(); + test_config.unix_domain_socket_path = "/tmp"; + test_config.max_processes = 1; + + test_transport->init_connection(); + + CHECK(test_socket_calls == 0); + CHECK(test_bind_calls == 0); + CHECK(test_listen_calls == 0); + CHECK(test_accept_calls == 0); + CHECK(test_close_calls == 0); + + test_transport->cleanup(); + CHECK(test_close_calls == 0); +}; + +TEST(unix_transport_test_connectivity_group, init_connection_first_snort_instance) +{ + clear_test_calls(); + snort_instance_id = 1; + + test_config.unix_domain_socket_path = "/tmp"; + test_config.max_processes = 2; + + accept_cnt = 1; + + test_transport->init_connection(); + + CHECK(test_accept_calls == 1); + + test_transport->cleanup(); + CHECK(test_close_calls == 2); +}; + +TEST(unix_transport_test_connectivity_group, init_connection_second_snort_instance) +{ + clear_test_calls(); + snort_instance_id = 2; + test_config.unix_domain_socket_path = "/tmp"; + test_config.max_processes = 2; + + test_transport->init_connection(); + + CHECK(test_bind_calls == 0); + CHECK(test_listen_calls == 0); + CHECK(test_accept_calls == 0); + CHECK(test_close_calls == 0); + CHECK(test_call_sock_created == 1); + + test_transport->cleanup(); + CHECK(test_close_calls == 1); +}; + +TEST(unix_transport_test_connectivity_group, connector_update_handler_call) +{ + clear_test_calls(); + + test_config.unix_domain_socket_path = "/tmp"; + test_config.max_processes = 2; + + accept_cnt = 1; + snort_instance_id = 1; + + test_update_handler = nullptr; + expect_update_change = true; + + test_transport->init_connection(); + + CHECK(test_update_handler != nullptr); + + test_update_handler(nullptr, false); + + CHECK(test_close_calls == 1); + expect_update_change = false; + test_update_handler = nullptr; +}; + +static TestDataEvent test_event; + +TEST(unix_transport_test_messaging, send_to_transport_biderectional) +{ + clear_test_calls(); + + test_transport_message_1->register_event_helpers(0, 0, mp_helper_functions_mock); + test_transport_message_2->register_event_helpers(0, 0, mp_helper_functions_mock); + + MPEventInfo event(&test_event, 0, 0); + + auto res = test_transport_message_1->send_to_transport(event); + + + CHECK(res == true); + CHECK(test_serialize_calls == 1); + + res = test_transport_message_2->send_to_transport(event); + + CHECK(res == true); + CHECK(test_serialize_calls == 2); + + std::this_thread::sleep_for(std::chrono::milliseconds(500)); + + CHECK(test_deserialize_calls == 2); + CHECK(reciveved_1_msg_cnt == 1); + CHECK(reciveved_2_msg_cnt == 1); + CHECK(test_send_calls == 2); +}; + +TEST(unix_transport_test_messaging, send_to_transport_no_helpers) +{ + clear_test_calls(); + + MPEventInfo event(&test_event, 0, 0); + + auto res = test_transport_message_1->send_to_transport(event); + CHECK(res == false); + CHECK(test_serialize_calls == 0); + CHECK(test_deserialize_calls == 0); + CHECK(reciveved_1_msg_cnt == 0); + CHECK(reciveved_2_msg_cnt == 0); + CHECK(test_send_calls == 0); + + res = test_transport_message_2->send_to_transport(event); + CHECK(res == false); + CHECK(test_serialize_calls == 0); + CHECK(test_deserialize_calls == 0); + CHECK(reciveved_1_msg_cnt == 0); + CHECK(reciveved_2_msg_cnt == 0); + CHECK(test_send_calls == 0); +} + +int main(int argc, char** argv) +{ + int return_value = CommandLineTestRunner::RunAllTests(argc, argv); + return return_value; +} \ No newline at end of file diff --git a/src/side_channel/side_channel.cc b/src/side_channel/side_channel.cc index 941c47859..3fe5a0f80 100644 --- a/src/side_channel/side_channel.cc +++ b/src/side_channel/side_channel.cc @@ -172,6 +172,9 @@ SideChannel::SideChannel(ScMsgFormat fmt) : msg_format(fmt) // return true iff we received any messages. bool SideChannel::process(int max_messages) { + if(!connector_receive) + return false; + bool received_message = false; while (true) -- 2.47.2