add_subdirectory(search_engines)
add_subdirectory(side_channel)
add_subdirectory(connectors)
+add_subdirectory(mp_transport)
# FIXIT-L Delegate building out the target objects list to subdirectories
$<TARGET_OBJECTS:target_based>
$<TARGET_OBJECTS:tcp_connector>
$<TARGET_OBJECTS:unixdomain_connector>
+ $<TARGET_OBJECTS:mp_unix_transport>
$<TARGET_OBJECTS:time>
$<TARGET_OBJECTS:trace>
+ $<TARGET_OBJECTS:mp_transports>
$<TARGET_OBJECTS:utils>
${STATIC_CODEC_PLUGINS}
${STATIC_NETWORK_INSPECTOR_PLUGINS}
int bind (int, const struct sockaddr*, socklen_t) { return s_bind_return; }
int listen (int, int) { return s_listen_return; }
#endif
-
-int accept (int, struct sockaddr*, socklen_t*) { return s_accept_return; }
+static bool use_test_accept_counter = false;
+static uint test_accept_counter = 0;
+int accept (int, struct sockaddr*, socklen_t*)
+{
+ if ( use_test_accept_counter )
+ {
+ if ( test_accept_counter == 0 )
+ return -1;
+ else
+ test_accept_counter--;
+ }
+ return s_accept_return;
+}
int close (int) { return 0; }
static void set_normal_status()
delete[] message;
}
+static const char* test_listener_path = "/tmp/test_path";
+UnixDomainConnectorListener* test_listener = nullptr;
+
+TEST_GROUP(unixdomain_connector_listener)
+{
+ void setup() override
+ {
+ test_listener = new UnixDomainConnectorListener(test_listener_path);
+ }
+
+ void teardown() override
+ {
+ if (test_listener)
+ {
+ delete test_listener;
+ test_listener = nullptr;
+ }
+ }
+};
+
+UnixDomainConnector* test_listener_connector = nullptr;
+UnixDomainConnectorConfig* test_listener_config = nullptr;
+
+void connection_callback(UnixDomainConnector* c, UnixDomainConnectorConfig* conf)
+{
+ assert(c != nullptr);
+ assert(conf != nullptr);
+
+ test_listener_connector = c;
+ test_listener_config = conf;
+}
+
+TEST(unixdomain_connector_listener, listener_accept_stop)
+{
+ UnixDomainConnectorConfig cfg;
+ cfg.direction = Connector::CONN_DUPLEX;
+ cfg.connector_name = "unixdomain";
+ cfg.paths.push_back(test_listener_path);
+ cfg.setup = UnixDomainConnectorConfig::Setup::ANSWER;
+ cfg.async_receive = true;
+
+ use_test_accept_counter = true;
+ test_accept_counter = 1;
+
+ test_listener->start_accepting_connections(&connection_callback, &cfg);
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+ test_listener->stop_accepting_connections();
+
+ CHECK(test_listener_connector != nullptr);
+ CHECK(test_listener_config != nullptr);
+ CHECK(test_listener_connector->get_connector_direction() == Connector::CONN_DUPLEX);
+ CHECK(test_listener_config->async_receive == true);
+
+ delete test_listener_connector;
+ test_listener_connector = nullptr;
+ delete test_listener_config;
+ test_listener_config = nullptr;
+}
+
int main(int argc, char** argv)
{
int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
}
// Function to handle connection retries
-static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx) {
- ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr);
+static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
+ if(update_handler)
+ update_handler(nullptr, (cfg.conn_retries > 0));
+ else
+ ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr);
- if ( cfg.setup == UnixDomainConnectorConfig::Setup::CALL and cfg.conn_retries) {
+ if (cfg.conn_retries)
+ {
const auto& paths = cfg.paths;
if (idx >= paths.size())
return;
-
- uint32_t retry_count = 0;
+
const char* path = paths[idx].c_str();
- while (retry_count < cfg.max_retries) {
- int sfd;
- if (attempt_connection(sfd, path)) {
- // Connection successful
- UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
- LogMessage("UnixDomainC: Connected to %s", path);
- ConnectorManager::update_thread_connector(cfg.connector_name, idx, unixdomain_conn);
- break;
+ if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+ {
+ LogMessage("UnixDomainC: Attempting to reconnect to %s\n", cfg.paths[idx].c_str());
+
+ uint32_t retry_count = 0;
+
+ while (retry_count < cfg.max_retries) {
+ int sfd;
+ if (attempt_connection(sfd, path)) {
+ // Connection successful
+ UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
+ LogMessage("UnixDomainC: Connected to %s", path);
+ if(update_handler)
+ {
+ unixdomain_conn->set_update_handler(update_handler);
+ update_handler(unixdomain_conn, false);
+ }
+ else
+ ConnectorManager::update_thread_connector(cfg.connector_name, idx, unixdomain_conn);
+ break;
+ }
+
+ std::this_thread::sleep_for(std::chrono::seconds(cfg.retry_interval));
+ retry_count++;
}
-
- std::this_thread::sleep_for(std::chrono::seconds(cfg.retry_interval));
- retry_count++;
+ }
+ else if (cfg.setup == UnixDomainConnectorConfig::Setup::ANSWER)
+ {
+ return;
+ }
+ else
+ {
+ LogMessage("UnixDomainC: Unexpected setup type at retry connection\n");
}
}
}
-static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx) {
- std::thread retry_thread(connection_retry_handler, cfg, idx);
+static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
+ std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler);
retry_thread.detach();
}
sock_fd = -1;
}
- start_retry_thread(cfg, instance_id);
+ start_retry_thread(cfg, instance_id, update_handler);
return;
}
else if (rval > 0 && pfds[0].revents & POLLIN) {
ErrorMessage("UnixDomainC: Input Thread: overrun\n");
delete connector_msg;
}
+ if(message_received_handler)
+ {
+ message_received_handler();
+ }
}
}
+void UnixDomainConnector::set_update_handler(UnixDomainConnectorUpdateHandler handler)
+{
+ update_handler = std::move(handler);
+}
+
+void UnixDomainConnector::set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler)
+{
+ message_received_handler = std::move(handler);
+}
+
void UnixDomainConnector::receive_processing_thread() {
while (run_thread.load(std::memory_order_relaxed)) {
process_receive();
delete m;
}
-static UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx) {
+UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler) {
int sfd;
if (!attempt_connection(sfd, path)) {
if (cfg.conn_retries) {
// Spawn a new thread to handle connection retries
- start_retry_thread(cfg, idx);
+ start_retry_thread(cfg, idx, update_handler);
return nullptr; // Return nullptr as the connection is not yet established
} else {
}
LogMessage("UnixDomainC: Connected to %s", path);
UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
+ unixdomain_conn->set_update_handler(update_handler);
+ if(update_handler)
+ update_handler(unixdomain_conn, false);
+
return unixdomain_conn;
}
LogMessage("UnixDomainC: Accepted connection from %s \n", path);
return new UnixDomainConnector(cfg, peer_sfd, idx);
-}
+}
static bool is_valid_path(const std::string& path) {
if (path.empty()) {
&unixdomain_connector_api.base,
nullptr
};
+
+UnixDomainConnectorListener::UnixDomainConnectorListener(const char *path)
+{
+ assert(path);
+
+ sock_path = strdup(path);
+ sock_fd = 0;
+ accept_thread = nullptr;
+ should_accept = false;
+}
+
+UnixDomainConnectorListener::~UnixDomainConnectorListener()
+{
+ stop_accepting_connections();
+ free(sock_path);
+ sock_path = nullptr;
+}
+
+void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnectorAcceptHandler handler, UnixDomainConnectorConfig* config)
+{
+ assert(accept_thread == nullptr);
+ assert(sock_path);
+
+ should_accept = true;
+ accept_thread = new std::thread([this, handler, config]()
+ {
+ sock_fd = socket(AF_UNIX, SOCK_STREAM, 0);
+ if (sock_fd == -1) {
+ ErrorMessage("UnixDomainC: socket error: %s \n", strerror(errno));
+ return;
+ }
+
+ struct sockaddr_un addr;
+ memset(&addr, 0, sizeof(struct sockaddr_un));
+ addr.sun_family = AF_UNIX;
+ strncpy(addr.sun_path, sock_path, sizeof(addr.sun_path) - 1);
+
+ unlink(sock_path);
+
+ if (bind(sock_fd, (struct sockaddr*)&addr, sizeof(struct sockaddr_un)) == -1) {
+ ErrorMessage("UnixDomainC: bind error: %s \n", strerror(errno));
+ close(sock_fd);
+ return;
+ }
+
+ if (listen(sock_fd, 10) == -1) {
+ ErrorMessage("UnixDomainC: listen error: %s \n", strerror(errno));
+ close(sock_fd);
+ return;
+ }
+
+ ushort error_count = 0;
+
+ while (should_accept) {
+ if(error_count > 10)
+ {
+ ErrorMessage("UnixDomainC: Too many errors, stopping accept thread\n");
+ close(sock_fd);
+ return;
+ }
+ int peer_sfd = accept(sock_fd, nullptr, nullptr);
+ if (peer_sfd == -1)
+ {
+ error_count++;
+ ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno));
+ continue;
+ }
+ error_count = 0;
+ auto config_copy = new UnixDomainConnectorConfig(*config);
+ auto unix_conn = new UnixDomainConnector(*config_copy, peer_sfd, 0);
+ handler(unix_conn, config_copy);
+ }
+ });
+
+}
+
+void UnixDomainConnectorListener::stop_accepting_connections()
+{
+ if(should_accept)
+ {
+ should_accept = false;
+ close(sock_fd);
+ if (accept_thread && accept_thread->joinable()) {
+ accept_thread->join();
+ }
+ delete accept_thread;
+ accept_thread = nullptr;
+ }
+}
#include <atomic>
#include <thread>
+#include <functional>
#include "framework/connector.h"
#include "managers/connector_manager.h"
uint16_t connector_msg_length;
};
+class UnixDomainConnector;
+
+typedef std::function<void (UnixDomainConnector*,bool)> UnixDomainConnectorUpdateHandler;
+typedef std::function<void ()> UnixDomainConnectorMessageReceivedHandler;
+
class UnixDomainConnector : public snort::Connector
{
public:
snort::ConnectorMsg receive_message(bool) override;
void process_receive();
+ void set_update_handler(UnixDomainConnectorUpdateHandler handler);
+ void set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler);
+
int sock_fd;
private:
ReceiveRing* receive_ring;
size_t instance_id;
UnixDomainConnectorConfig cfg;
+
+ UnixDomainConnectorUpdateHandler update_handler;
+ UnixDomainConnectorMessageReceivedHandler message_received_handler;
+};
+
+typedef std::function<void (UnixDomainConnector*, UnixDomainConnectorConfig*)> UnixDomainConnectorAcceptHandler;
+
+class UnixDomainConnectorListener
+{
+public:
+ UnixDomainConnectorListener(const char* path);
+ ~UnixDomainConnectorListener();
+
+ void start_accepting_connections(UnixDomainConnectorAcceptHandler handler, UnixDomainConnectorConfig* config);
+ void stop_accepting_connections();
+
+ private:
+ char* sock_path;
+ int sock_fd;
+ std::thread* accept_thread;
+ std::atomic<bool> should_accept;
+
};
+extern SO_PUBLIC UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler = nullptr);
+
#endif // UNIXDOMAIN_CONNECTOR_H
UnixDomainConnectorConfig()
{ direction = snort::Connector::CONN_DUPLEX; async_receive = true; }
+ UnixDomainConnectorConfig(const UnixDomainConnectorConfig& other) :
+ paths(other.paths), setup(other.setup),
+ conn_retries(other.conn_retries),
+ retry_interval(other.retry_interval),
+ max_retries(other.max_retries),
+ connect_timeout_seconds(other.connect_timeout_seconds),
+ async_receive(other.async_receive)
+ {
+ direction = other.direction;
+ }
+
std::vector<std::string> paths;
Setup setup = {};
bool conn_retries = false;
uint32_t retry_interval = 4;
uint32_t max_retries = 5;
+ uint32_t connect_timeout_seconds = 30;
bool async_receive;
};
range.h
so_rule.h
value.h
+ mp_transport.h
)
add_library ( framework OBJECT
PT_LOGGER,
PT_CONNECTOR,
PT_POLICY_SELECTOR,
+ PT_MP_TRANSPORT,
PT_MAX
};
#include "pub_sub/intrinsic_event_ids.h"
#include "utils/stats.h"
#include "main/snort_types.h"
+#include "managers/mp_transport_manager.h"
+#include "log/messages.h"
using namespace snort;
UNUSED(event_info);
}
+
//--------------------------------------------------------------------------
// private methods
//--------------------------------------------------------------------------
#include "main/snort_types.h"
#include "data_bus.h"
+#include "framework/mp_transport.h"
#include <bitset>
namespace snort
struct Packet;
struct SnortConfig;
-typedef bool (*MPSerializeFunc)(const DataEvent& event, char** buffer, size_t* length);
-typedef bool (*MPDeserializeFunc)(const char* buffer, size_t length, DataEvent* event);
+typedef bool (*MPSerializeFunc)(DataEvent* event, char*& buffer, uint16_t* length);
+typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent*& event);
// Similar to the DataBus class, the MPDataBus class uses uses a combination of PubKey and event ID
// for event subscriptions and publishing. New MP-specific event type enums should be added to the
{
MPEventType type;
unsigned pub_id;
- DataEvent event;
- MPEventInfo(const DataEvent& e, MPEventType t, unsigned id = 0)
+ DataEvent* event;
+ MPEventInfo(DataEvent* e, MPEventType t, unsigned id = 0)
: type(t), pub_id(id), event(e) {}
};
struct MPHelperFunctions {
- MPSerializeFunc* serializer;
- MPDeserializeFunc* deserializer;
+ MPSerializeFunc serializer;
+ MPDeserializeFunc deserializer;
- MPHelperFunctions(MPSerializeFunc* s, MPDeserializeFunc* d)
+ MPHelperFunctions(MPSerializeFunc s, MPDeserializeFunc d)
: serializer(s), deserializer(d) {}
};
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transport.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_TRANSPORT_H
+#define MP_TRANSPORT_H
+
+#include "main/snort_types.h"
+#include "framework/base_api.h"
+
+#include <functional>
+
+namespace snort
+{
+
+#define MP_TRANSPORT_API_VERSION ((BASE_API_VERSION << 16) | 1)
+
+struct SnortConfig;
+struct MPEventInfo;
+struct MPHelperFunctions;
+
+typedef std::function<void (const MPEventInfo& event_info)> TransportReceiveEventHandler;
+
+class MPTransport
+{
+ public:
+
+ MPTransport() = default;
+ virtual ~MPTransport() = default;
+
+ virtual bool configure(const SnortConfig*) = 0;
+ virtual void thread_init() = 0;
+ virtual void thread_term() = 0;
+ virtual void init_connection() = 0;
+ virtual bool send_to_transport(MPEventInfo& event) = 0;
+ virtual void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) = 0;
+ virtual void register_receive_handler(const TransportReceiveEventHandler& handler) = 0;
+ virtual void unregister_receive_handler() = 0;
+ virtual void enable_logging() = 0;
+ virtual void disable_logging() = 0;
+ virtual bool is_logging_enabled() = 0;
+};
+
+
+typedef MPTransport* (* MPTransportNewFunc)(Module*);
+typedef void (* MPTransportDelFunc)(MPTransport*);
+typedef void (* MPTransportThreadInitFunc)(MPTransport*);
+typedef void (* MPTransportThreadTermFunc)(MPTransport*);
+typedef void (* MPTransportFunc)();
+
+struct MPTransportApi
+{
+ BaseApi base;
+ unsigned flags;
+
+ MPTransportFunc pinit; // plugin init
+ MPTransportFunc pterm; // cleanup pinit()
+ MPTransportThreadInitFunc tinit; // thread local init
+ MPTransportThreadTermFunc tterm; // cleanup tinit()
+
+ MPTransportNewFunc ctor;
+ MPTransportDelFunc dtor;
+};
+
+}
+#endif // MP_TRANSPORT_H
#include "managers/plugin_manager.h"
#include "managers/script_manager.h"
#include "managers/so_manager.h"
+#include "managers/mp_transport_manager.h"
#include "packet_io/sfdaq.h"
#include "utils/util.h"
ModuleManager::term();
PluginManager::release_plugins();
ScriptManager::release_scripts();
+ MPTransportManager::term();
delete SnortConfig::get_conf();
exit(0);
}
#include "managers/plugin_manager.h"
#include "managers/policy_selector_manager.h"
#include "managers/script_manager.h"
+#include "managers/mp_transport_manager.h"
#include "memory/memory_cap.h"
#include "network_inspectors/network_inspectors.h"
#include "packet_io/active.h"
#include "trace/trace_api.h"
#include "trace/trace_config.h"
#include "trace/trace_logger.h"
+#include "mp_transport/mp_transports.h"
#include "utils/stats.h"
#include "utils/util.h"
load_stream_inspectors();
load_network_inspectors();
load_service_inspectors();
+ load_mp_transports();
snort_cmd_line_conf = parse_cmd_line(argc, argv);
SnortConfig::set_conf(snort_cmd_line_conf);
host_cache.term();
PluginManager::release_plugins();
ScriptManager::release_scripts();
+ MPTransportManager::term();
memory::MemoryCap::term();
detection_filter_term();
so_manager.h
connector_manager.cc
connector_manager.h
+ mp_transport_manager.cc
+ mp_transport_manager.h
)
add_custom_command (
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transport_manager.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_transport_manager.h"
+
+#include <unordered_map>
+
+using namespace snort;
+
+struct MPTransportHandler
+{
+ MPTransportHandler(MPTransport* transport, const MPTransportApi* api)
+ : transport(transport), api(api) {}
+ MPTransport* transport;
+ const MPTransportApi* api;
+};
+
+static std::unordered_map<std::string, MPTransportHandler*> transports_map;
+
+void MPTransportManager::instantiate(const MPTransportApi *api, Module *mod, SnortConfig*)
+{
+ if(transports_map.find(api->base.name) != transports_map.end())
+ {
+ return;
+ }
+
+ transports_map.insert(std::make_pair(api->base.name, new MPTransportHandler(api->ctor(mod), api)));
+}
+
+MPTransport *MPTransportManager::get_transport(const std::string &name)
+{
+ auto it = transports_map.find(name);
+ if (it != transports_map.end())
+ {
+ return it->second->transport;
+ }
+ return nullptr;
+}
+
+void MPTransportManager::add_plugin(const MPTransportApi *api)
+{
+ if (api->pinit)
+ {
+ api->pinit();
+ }
+}
+
+void MPTransportManager::thread_init()
+{
+ for (auto &transport : transports_map)
+ {
+ if (transport.second->api->tinit)
+ {
+ transport.second->api->tinit(transport.second->transport);
+ }
+ }
+}
+
+void MPTransportManager::thread_term()
+{
+ for (auto &transport : transports_map)
+ {
+ if (transport.second->api->tterm)
+ {
+ transport.second->api->tterm(transport.second->transport);
+ }
+ }
+}
+
+void MPTransportManager::term()
+{
+ for (auto &transport : transports_map)
+ {
+ if (transport.second->api->dtor)
+ {
+ transport.second->api->dtor(transport.second->transport);
+ }
+ if (transport.second->api->pterm)
+ {
+ transport.second->api->pterm();
+ }
+ delete transport.second;
+ }
+ transports_map.clear();
+}
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transport_manager.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_TRANSPORT_MANAGER_H
+#define MP_TRANSPORT_MANAGER_H
+
+// Manager for multiprocess layer objects
+
+#include <string>
+#include "framework/mp_transport.h"
+
+namespace snort
+{
+class Module;
+struct SnortConfig;
+
+
+//-------------------------------------------------------------------------
+
+class MPTransportManager
+{
+public:
+ static void instantiate(const MPTransportApi* api, Module* mod, SnortConfig*);
+ static MPTransport* get_transport(const std::string& name);
+
+ static void add_plugin(const MPTransportApi* api);
+
+ static void thread_init();
+ static void thread_term();
+
+ static void term();
+};
+}
+#endif
+
#include "ips_manager.h"
#include "module_manager.h"
#include "mpse_manager.h"
+#include "mp_transport_manager.h"
#include "policy_selector_manager.h"
#include "script_manager.h"
#include "so_manager.h"
{ "logger", LOGAPI_VERSION, sizeof(LogApi) },
{ "connector", CONNECTOR_API_VERSION, sizeof(ConnectorApi) },
{ "policy_selector", POLICY_SELECTOR_API_VERSION, sizeof(PolicySelectorApi) },
+ { "mp_transport", MP_TRANSPORT_API_VERSION, sizeof(MPTransportApi) },
};
#else
// this gets around the sequence issue with some compilers
[PT_LOGGER] = { stringify(PT_LOGGER), LOGAPI_VERSION, sizeof(LogApi) },
[PT_CONNECTOR] = { stringify(PT_CONNECTOR), CONNECTOR_API_VERSION, sizeof(ConnectorApi) },
[PT_POLICY_SELECTOR] = { stringify(PT_POLICY_SELECTOR), POLICY_SELECTOR_API_VERSION,
- sizeof(PolicySelectorApi) }
+ sizeof(PolicySelectorApi) },
+ [PT_MP_TRANSPORT] = { stringify(PT_MP_TRANSPORT), MP_TRANSPORT_API_VERSION,
+ sizeof(MPTransportApi) }
};
#endif
case PT_POLICY_SELECTOR:
PolicySelectorManager::add_plugin((const PolicySelectorApi*)p.api);
break;
+
+ case PT_MP_TRANSPORT:
+ MPTransportManager::add_plugin((const MPTransportApi*)p.api);
+ break;
default:
assert(false);
//IpsManager::instantiate((SoApi*)api, mod, sc);
break;
+ case PT_MP_TRANSPORT:
+ MPTransportManager::instantiate((const MPTransportApi*)api, mod, sc);
+ break;
+
case PT_LOGGER:
EventManager::instantiate((const LogApi*)api, mod, sc);
break;
get_inspector_stubs.h
../inspector_manager.cc
)
+
+add_cpputest(mp_transport_manager_test
+ SOURCES
+ mp_transport_manager_test.cc
+ ../mp_transport_manager.cc
+ ../../framework/module.cc
+)
--- /dev/null
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "framework/mp_data_bus.h"
+#include "framework/mp_transport.h"
+#include "framework/module.h"
+#include "main/thread_config.h"
+
+#include "../mp_transport_manager.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+#include <unordered_map>
+
+#define MODULE_NAME "mock_transport"
+#define MODULE_HELP "mock transport for testing"
+
+static int test_transport_ctor_calls = 0;
+static int test_transport_dtor_calls = 0;
+static int test_transport_pinit_calls = 0;
+static int test_transport_pterm_calls = 0;
+static int test_transport_tinit_calls = 0;
+static int test_transport_tterm_calls = 0;
+
+namespace snort
+{
+
+class MockTransport : public MPTransport
+{
+ public:
+ MockTransport() : MPTransport()
+ { }
+ virtual ~MockTransport() override
+ { }
+ virtual bool send_to_transport(MPEventInfo& event) override
+ { return true; }
+ virtual void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) override
+ { }
+ virtual void init_connection() override
+ { }
+ virtual void register_receive_handler(const TransportReceiveEventHandler& handler) override
+ { }
+ virtual void unregister_receive_handler() override
+ { }
+ virtual void thread_init() override
+ { }
+ virtual void thread_term() override
+ { }
+ virtual bool configure(const SnortConfig*) override
+ { return true; }
+ virtual void enable_logging() override
+ { }
+ virtual void disable_logging() override
+ { }
+ virtual bool is_logging_enabled() override
+ { return false; }
+};
+
+unsigned get_instance_id() { return 0; }
+unsigned ThreadConfig::get_instance_max() { return 1; }
+};
+
+
+using namespace snort;
+
+void show_stats(unsigned long*, PegInfo const*, std::vector<unsigned int, std::allocator<unsigned int> > const&, char const*, _IO_FILE*)
+{}
+void show_stats(unsigned long*, PegInfo const*, unsigned int, char const*)
+{}
+
+static void mock_transport_tinit(MPTransport* t)
+{
+ test_transport_tinit_calls++;
+}
+static void mock_transport_tterm(MPTransport* t)
+{
+ test_transport_tterm_calls++;
+}
+static MPTransport* mock_transport_ctor(Module* m)
+{
+ test_transport_ctor_calls++;
+ return new MockTransport();
+}
+static void mock_transport_dtor(MPTransport* t)
+{
+ test_transport_dtor_calls++;
+ delete t;
+}
+
+static void mock_transport_pinit()
+{
+ // Mock plugin init
+ test_transport_pinit_calls++;
+}
+static void mock_transport_pterm()
+{
+ // Mock plugin term
+ test_transport_pterm_calls++;
+}
+
+static void clear_test_calls()
+{
+ test_transport_ctor_calls = 0;
+ test_transport_dtor_calls = 0;
+ test_transport_pinit_calls = 0;
+ test_transport_pterm_calls = 0;
+ test_transport_tinit_calls = 0;
+ test_transport_tterm_calls = 0;
+}
+
+static struct MPTransportApi mock_transport_api =
+{
+ {
+ PT_MP_TRANSPORT,
+ sizeof(MPTransportApi),
+ MP_TRANSPORT_API_VERSION,
+ 2,
+ API_RESERVED,
+ API_OPTIONS,
+ MODULE_NAME,
+ MODULE_HELP,
+ nullptr,
+ nullptr
+ },
+ 0,
+ mock_transport_pinit,
+ mock_transport_pterm,
+ mock_transport_tinit,
+ mock_transport_tterm,
+ mock_transport_ctor,
+ mock_transport_dtor
+};
+
+TEST_GROUP(mp_transport_manager_test_group)
+{
+ void setup() override
+ {
+ clear_test_calls();
+ }
+
+ void teardown() override
+ {
+ MPTransportManager::term();
+ }
+};
+
+TEST(mp_transport_manager_test_group, instantiate_transport_object)
+{
+ MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+ CHECK(test_transport_ctor_calls == 1);
+}
+
+TEST(mp_transport_manager_test_group, get_transport_object)
+{
+ MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+ MPTransport* transport = MPTransportManager::get_transport(MODULE_NAME);
+ CHECK(transport != nullptr);
+
+ transport = MPTransportManager::get_transport("non_existent_transport");
+ CHECK(transport == nullptr);
+
+ MPTransportManager::term();
+
+ transport = MPTransportManager::get_transport(MODULE_NAME);
+ CHECK(transport == nullptr);
+}
+
+TEST(mp_transport_manager_test_group, add_plugin)
+{
+ MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+ MPTransportManager::add_plugin(&mock_transport_api);
+ CHECK(test_transport_pinit_calls == 1);
+
+ MPTransportManager::term();
+ CHECK(test_transport_pterm_calls == 1);
+}
+
+TEST(mp_transport_manager_test_group, thread_init_term)
+{
+ MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+ MPTransportManager::thread_init();
+ CHECK(test_transport_tinit_calls == 1);
+
+ MPTransportManager::thread_term();
+ CHECK(test_transport_tterm_calls == 1);
+}
+
+int main(int argc, char** argv)
+{
+ // Allocate object in internal MPTransportManager unordered_map to prevent false-positive memory leak detection ( buckets allocation at first insert )
+ MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+ MPTransportManager::term();
+ int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
+ return return_value;
+}
\ No newline at end of file
--- /dev/null
+add_subdirectory(mp_unix_transport)
+
+add_library( mp_transports OBJECT
+ mp_transports.cc
+ mp_transports.h
+)
\ No newline at end of file
--- /dev/null
+The `mp_transport` files provide the foundation for message-passing transport mechanisms, enabling communication
+between distributed components or processes in a modular and extensible manner. These files abstract the complexities
+of underlying communication protocols, offering a unified API for sending, receiving, and managing messages. This
+abstraction ensures that developers can focus on higher-level application logic without worrying about protocol-specific
+details. These transports are utilized by `mp_data_bus.h` to transfer events between distributed Snort instances.
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transports.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_transports.h"
+
+#include "framework/mp_transport.h"
+#include "managers/plugin_manager.h"
+
+using namespace snort;
+
+extern const BaseApi* mp_unix_transport[];
+
+void load_mp_transports()
+{
+ PluginManager::load_plugins(mp_unix_transport);
+}
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transports.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_TRANSPORTS_H
+#define MP_TRANSPORTS_H
+
+void load_mp_transports();
+
+#endif // MP_TRANSPORTS_H
--- /dev/null
+
+set( MP_TRANSPORT_INCLUDES
+ mp_unix_transport.h
+ mp_unix_transport_module.h
+)
+
+add_library( mp_unix_transport OBJECT
+ ${MP_TRANSPORT_INCLUDES}
+ mp_unix_transport.cc
+ mp_unix_transport_module.cc
+)
+
+install (FILES ${MIME_INCLUDES}
+ DESTINATION "${INCLUDE_INSTALL_PATH}/mp_unix_transport"
+)
+
+add_dependencies( mp_unix_transport framework )
+add_dependencies( mp_unix_transport unixdomain_connector )
+
+find_package(Threads REQUIRED)
+target_link_libraries(mp_unix_transport PRIVATE Threads::Threads)
+target_link_libraries(mp_unix_transport PRIVATE $<TARGET_OBJECTS:unixdomain_connector>)
+
+add_subdirectory( test )
\ No newline at end of file
--- /dev/null
+The MP Unix Domain Transport provides an implementation of the Multi-Process (MP) Transport
+interface using existing `UnixDomainConnector` infrastructure. This transport enables
+inter-process communication (IPC) between multiple Snort instances running on the same host,
+allowing them to exchange events and data.
+
+ * MPUnixDomainTransport - Main class implementing the MPTransport interface with Unix domain socket functionality
+ * MPUnixDomainTransportModule - Module class that handles configuration parameters and provides an API for MPUnixDomainTransport creation
+ * MPUnixDomainTransportConfig - Configuration structure for socket paths and connection parameters
+
+Connention between snort establishes in next sequence:
+* First Snort instance acts as a server socket accepting connections from other instances
+* Additional instances connect to existing socket paths
+* Dynamic re-connection handling with configurable retry parameters
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_unix_transport.h"
+
+#include <cstring>
+#include <fcntl.h>
+#include <iostream>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include "framework/mp_data_bus.h"
+#include "log/messages.h"
+#include "main/snort.h"
+#include "main/snort_config.h"
+
+static std::mutex _receive_mutex;
+static std::mutex _update_connectors_mutex;
+
+#define UNIX_SOCKET_NAME_PREFIX "/snort_unix_connector_"
+
+#define MP_TRANSPORT_LOG_LABEL "MPUnixTransport"
+
+#define MP_TRANSPORT_LOG(msg, ...) do { \
+ if (!this->is_logging_enabled_flag) \
+ break; \
+ LogMessage(msg, __VA_ARGS__); \
+ } while (0)
+
+namespace snort
+{
+
+#pragma pack(push, 1)
+ enum MPTransportMessageType
+ {
+ EVENT_MESSAGE = 0,
+ MAX_TYPE
+ };
+
+struct MPTransportMessageHeader
+{
+ MPTransportMessageType type;// Type of the message
+ int32_t pub_id; // Identifier for the module sending or receiving the message
+ int32_t event_id; // Identifier for the specific event
+ uint16_t data_length; // Length of the data payload
+};
+
+struct MPTransportMessage
+{
+ MPTransportMessageHeader header; // Header containing metadata about the message
+ char* data; // Placeholder for the actual data payload
+};
+#pragma pack(pop)
+
+void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
+{
+ if (transport_receive_handler and msg)
+ {
+ if (msg->content_length < sizeof(MPTransportMessage))
+ {
+ MP_TRANSPORT_LOG("%s: Incomplete message received\n", MP_TRANSPORT_LOG_LABEL);
+ return;
+ }
+
+ MPTransportMessageHeader* transport_message_header = (MPTransportMessageHeader*)msg->content;
+
+ if (transport_message_header->type >= MAX_TYPE)
+ {
+ MP_TRANSPORT_LOG("%s: Invalid message type received\n", MP_TRANSPORT_LOG_LABEL);
+ return;
+ }
+
+ auto deserialize_func = get_event_deserialization_function(transport_message_header->pub_id, transport_message_header->event_id);
+ if (!deserialize_func)
+ {
+ MP_TRANSPORT_LOG("%s: No deserialization function found for event: type %d, id %d\n", MP_TRANSPORT_LOG_LABEL, transport_message_header->type, transport_message_header->event_id);
+ return;
+ }
+
+ DataEvent* internal_event = nullptr;
+ (deserialize_func)((const char*)(msg->content + sizeof(MPTransportMessageHeader)), transport_message_header->data_length, internal_event);
+ MPEventInfo event(internal_event, transport_message_header->event_id, transport_message_header->pub_id);
+
+ (transport_receive_handler)(event);
+
+ delete internal_event;
+ }
+ delete msg;
+}
+
+void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg)
+{
+ assert(connector);
+ assert(cfg);
+
+ std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+
+ auto side_channel = new SideChannel(ScMsgFormat::BINARY);
+ side_channel->connector_receive = connector;
+ side_channel->connector_transmit = side_channel->connector_receive;
+ side_channel->register_receive_handler(std::bind(&MPUnixDomainTransport::side_channel_receive_handler, this, std::placeholders::_1));
+ connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
+ this->side_channels.push_back(new SideChannelHandle(side_channel, cfg));
+ connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+}
+
+MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c) : MPTransport(),
+ config(c)
+{
+ this->is_logging_enabled_flag = c->enable_logging;
+}
+
+MPUnixDomainTransport::~MPUnixDomainTransport()
+{
+ cleanup();
+}
+
+bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
+{
+ auto serialize_func = get_event_serialization_function(event.pub_id, event.type);
+
+ if (!serialize_func)
+ {
+ MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event.type);
+ return false;
+ }
+
+ MPTransportMessage transport_message;
+ transport_message.header.type = EVENT_MESSAGE;
+ transport_message.header.pub_id = event.pub_id;
+ transport_message.header.event_id = event.type;
+
+
+ (serialize_func)(event.event, transport_message.data, &transport_message.header.data_length);
+ for (auto &&sc_handler : this->side_channels)
+ {
+ auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length);
+ memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader));
+ memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length);
+ auto send_result = sc_handler->side_channel->transmit_message(msg);
+ if (!send_result)
+ {
+ MP_TRANSPORT_LOG("%s: Failed to send message to side channel\n", MP_TRANSPORT_LOG_LABEL);
+ }
+ }
+
+ delete[] transport_message.data;
+
+ return true;
+}
+
+void MPUnixDomainTransport::register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper)
+{
+ assert(helper.deserializer);
+ assert(helper.serializer);
+
+ this->event_helpers[pub_id] = SerializeFunctionHandle();
+ this->event_helpers[pub_id].serialize_functions.insert({event_id, helper});
+}
+
+void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler& handler)
+{
+ this->transport_receive_handler = handler;
+}
+
+void MPUnixDomainTransport::unregister_receive_handler()
+{
+ this->transport_receive_handler = nullptr;
+}
+
+void MPUnixDomainTransport::process_messages_from_side_channels()
+{
+ std::unique_lock<std::mutex> lock(_receive_mutex);
+ do
+ {
+ if ( (std::cv_status::timeout == this->consume_thread_cv.wait_for(lock, std::chrono::milliseconds(config->consume_message_timeout_milliseconds)) )
+ and this->consume_message_received == false )
+ {
+ continue;
+ }
+
+ {
+ std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+ bool messages_left;
+
+ do
+ {
+ messages_left = false;
+ for (auto &&sc_handler : this->side_channels)
+ {
+ messages_left |= sc_handler->side_channel->process(config->consume_message_batch_size);
+ }
+ } while (messages_left);
+ }
+
+ this->consume_message_received = false;
+
+ } while (this->is_running);
+}
+
+void MPUnixDomainTransport::notify_process_thread()
+{
+ this->consume_thread_cv.notify_all();
+ this->consume_message_received = true;
+}
+
+void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel)
+{
+ std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+ if (side_channel->connector_receive)
+ {
+ delete side_channel->connector_receive;
+ side_channel->connector_receive = side_channel->connector_transmit = nullptr;
+ }
+
+ if (connector)
+ {
+ side_channel->connector_receive = side_channel->connector_transmit = connector;
+ }
+ else
+ {
+ if (is_recconecting == false)
+ {
+ MP_TRANSPORT_LOG("%s: Accepted connection interrupted, removing handle\n", MP_TRANSPORT_LOG_LABEL);
+ for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it)
+ {
+ if ((*it)->side_channel == side_channel)
+ {
+ delete *it;
+ this->side_channels.erase(it);
+ break;
+ }
+ }
+ }
+ }
+}
+
+MPSerializeFunc MPUnixDomainTransport::get_event_serialization_function(unsigned pub_id, unsigned event_id)
+{
+ auto helper_it = this->event_helpers.find(pub_id);
+ if (helper_it == this->event_helpers.end())
+ {
+ MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id);
+ return nullptr;
+ }
+ auto helper_functions = helper_it->second.get_function_set(event_id);
+ if (!helper_functions)
+ {
+ MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id);
+ return nullptr;
+ }
+ return helper_functions->serializer;
+}
+
+MPDeserializeFunc MPUnixDomainTransport::get_event_deserialization_function(unsigned pub_id, unsigned event_id)
+{
+ auto helper_it = this->event_helpers.find(pub_id);
+ if (helper_it == this->event_helpers.end())
+ {
+ MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id);
+ return nullptr;
+ }
+ auto helper_functions = helper_it->second.get_function_set(event_id);
+ if (!helper_functions)
+ {
+ MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id);
+ return nullptr;
+ }
+ return helper_functions->deserializer;
+}
+
+void MPUnixDomainTransport::init_connection()
+{
+ init_side_channels();
+}
+
+void MPUnixDomainTransport::thread_init()
+{
+}
+
+void MPUnixDomainTransport::thread_term()
+{
+}
+
+bool MPUnixDomainTransport::configure(const SnortConfig *c)
+{
+ config->max_processes = c->max_procs;
+ return true;
+}
+
+void MPUnixDomainTransport::cleanup()
+{
+ this->is_running = false;
+ MPUnixDomainTransport::unregister_receive_handler();
+ if (this->consume_thread)
+ {
+ this->consume_thread_cv.notify_all();
+ this->consume_thread->join();
+ delete this->consume_thread;
+ this->consume_thread = nullptr;
+ }
+ cleanup_side_channels();
+ for (auto &&ac_handler : this->accept_handlers)
+ {
+ ac_handler->listener->stop_accepting_connections();
+ delete ac_handler->listener;
+ delete ac_handler->connector_config;
+ delete ac_handler;
+ }
+ this->accept_handlers.clear();
+}
+
+void MPUnixDomainTransport::init_side_channels()
+{
+ assert(config);
+ if (config->max_processes < 2)
+ return;
+
+ auto instance_id = Snort::get_process_id();//Snort instance id
+ auto max_processes = config->max_processes;
+
+ this->is_running = true;
+
+ for (ushort i = instance_id; i < max_processes; i++)
+ {
+ auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i);
+ auto unix_listener = new UnixDomainConnectorListener(listen_path.c_str());
+
+ UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig();
+ unix_config->setup = UnixDomainConnectorConfig::Setup::ANSWER;
+ unix_config->async_receive = true;
+ if (config->conn_retries)
+ {
+ unix_config->conn_retries = config->conn_retries;
+ unix_config->retry_interval = config->retry_interval_seconds;
+ unix_config->max_retries = config->max_retries;
+ unix_config->connect_timeout_seconds = config->connect_timeout_seconds;
+ }
+ else
+ {
+ unix_config->conn_retries = false;
+ unix_config->retry_interval = 0;
+ unix_config->max_retries = 0;
+ unix_config->connect_timeout_seconds = 0;
+ }
+ unix_config->paths.push_back(listen_path);
+
+ unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2), unix_config);
+
+ auto unix_listener_handle = new UnixAcceptorHandle();
+ unix_listener_handle->connector_config = unix_config;
+ unix_listener_handle->listener = unix_listener;
+ this->accept_handlers.push_back(unix_listener_handle);
+ }
+
+ for (ushort i = 1; i < instance_id; i++)
+ {
+ auto side_channel = new SideChannel(ScMsgFormat::BINARY);
+ side_channel->register_receive_handler([this](SCMessage* msg) { this->side_channel_receive_handler(msg); });
+
+ auto send_path = config->unix_domain_socket_path + "/" + "snort_unix_connector_" + std::to_string(i);
+
+ UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig();
+ connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL;
+ connector_conf->async_receive = true;
+ connector_conf->conn_retries = config->conn_retries;
+ connector_conf->retry_interval = config->retry_interval_seconds;
+ connector_conf->max_retries = config->max_retries;
+ connector_conf->connect_timeout_seconds = config->connect_timeout_seconds;
+ connector_conf->paths.push_back(send_path);
+
+ auto connector = unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+
+ if (connector)
+ connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
+
+ side_channel->connector_receive = connector;
+ side_channel->connector_transmit = side_channel->connector_receive;
+ this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf));
+ }
+
+ this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this);
+}
+void MPUnixDomainTransport::cleanup_side_channels()
+{
+ std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+
+ for (uint i = 0; i < this->side_channels.size(); i++)
+ {
+ auto side_channel = this->side_channels[i];
+ delete side_channel;
+ }
+
+ this->side_channels.clear();
+}
+
+SideChannelHandle::~SideChannelHandle()
+{
+ if (side_channel)
+ {
+ if (side_channel->connector_receive)
+ delete side_channel->connector_receive;
+
+ delete side_channel;
+ }
+
+ if (connector_config)
+ delete connector_config;
+}
+void MPUnixDomainTransport::enable_logging()
+{
+ this->is_logging_enabled_flag = true;
+}
+
+void MPUnixDomainTransport::disable_logging()
+{
+ this->is_logging_enabled_flag = false;
+}
+
+bool MPUnixDomainTransport::is_logging_enabled()
+{
+ return this->is_logging_enabled_flag;
+}
+
+};
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef UNIX_TRANSPORT_H
+#define UNIX_TRANSPORT_H
+
+#include "connectors/unixdomain_connector/unixdomain_connector.h"
+#include "framework/mp_data_bus.h"
+#include "main/snort_types.h"
+#include "side_channel/side_channel.h"
+
+#include <atomic>
+#include <thread>
+#include <condition_variable>
+
+namespace snort
+{
+
+struct MPUnixDomainTransportConfig
+{
+ std::string unix_domain_socket_path;
+ uint16_t max_processes = 0;
+ bool conn_retries = true;
+ bool enable_logging = false;
+ uint32_t retry_interval_seconds = 5;
+ uint32_t max_retries = 5;
+ uint32_t connect_timeout_seconds = 30;
+ uint32_t consume_message_timeout_milliseconds = 100;
+ uint32_t consume_message_batch_size = 5;
+};
+
+struct SerializeFunctionHandle
+{
+ std::unordered_map<unsigned, MPHelperFunctions> serialize_functions;
+
+ MPHelperFunctions* get_function_set(unsigned event_id)
+ {
+ auto it = serialize_functions.find(event_id);
+ if(it == serialize_functions.end())
+ return nullptr;
+ return &it->second;
+ }
+};
+
+struct SideChannelHandle
+{
+ SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc) :
+ side_channel(sc), connector_config(cc)
+ { }
+
+ ~SideChannelHandle();
+
+ SideChannel* side_channel;
+ UnixDomainConnectorConfig* connector_config;
+};
+
+struct UnixAcceptorHandle
+{
+ UnixDomainConnectorConfig* connector_config = nullptr;
+ UnixDomainConnectorListener* listener = nullptr;
+};
+
+class MPUnixDomainTransport : public MPTransport
+{
+ public:
+
+ MPUnixDomainTransport(MPUnixDomainTransportConfig* c);
+ ~MPUnixDomainTransport() override;
+
+ bool configure(const SnortConfig*) override;
+ void thread_init() override;
+ void thread_term() override;
+ void init_connection() override;
+ bool send_to_transport(MPEventInfo& event) override;
+ void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) override;
+ void register_receive_handler(const TransportReceiveEventHandler& handler) override;
+ void unregister_receive_handler() override;
+ void enable_logging() override;
+ void disable_logging() override;
+ bool is_logging_enabled() override;
+ void cleanup();
+
+ MPUnixDomainTransportConfig* get_config()
+ { return config; }
+
+
+ private:
+
+ void init_side_channels();
+ void cleanup_side_channels();
+ void side_channel_receive_handler(SCMessage* msg);
+ void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg);
+ void process_messages_from_side_channels();
+ void notify_process_thread();
+ void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel);
+
+ MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id);
+ MPDeserializeFunc get_event_deserialization_function(unsigned pub_id, unsigned event_id);
+
+ TransportReceiveEventHandler transport_receive_handler = nullptr;
+ MPUnixDomainTransportConfig* config = nullptr;
+
+ std::vector<SideChannelHandle*> side_channels;
+ std::vector<UnixAcceptorHandle*> accept_handlers;
+ std::unordered_map<unsigned, SerializeFunctionHandle> event_helpers;
+
+ std::atomic<bool> is_running = false;
+ std::atomic<bool> is_logging_enabled_flag;
+ std::atomic<bool> consume_message_received = false;
+
+ std::thread* consume_thread = nullptr;
+ std::condition_variable consume_thread_cv;
+};
+
+}
+#endif
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport_module.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_unix_transport_module.h"
+
+#include "main/snort_config.h"
+#include "log/messages.h"
+
+#define DEFAULT_UNIX_DOMAIN_SOCKET_PATH "/tmp/snort_unix_connectors"
+
+using namespace snort;
+
+static const Parameter unix_transport_params[] =
+{
+ { "unix_domain_socket_path" , Parameter::PT_STRING, nullptr, DEFAULT_UNIX_DOMAIN_SOCKET_PATH, "unix socket folder" },
+ { "max_connect_retries", Parameter::PT_INT, nullptr, "5", "max connection retries" },
+ { "retry_interval_seconds", Parameter::PT_INT, nullptr, "30", "retry interval in seconds" },
+ { "connect_timeout_seconds", Parameter::PT_INT, nullptr, "30", "connect timeout in seconds" },
+ { "consume_message_timeout_milliseconds", Parameter::PT_INT, nullptr, "100", "consume message timeout in milliseconds" },
+ { "consume_message_batch_size", Parameter::PT_INT, nullptr, "5", "consume message batch size" },
+ { "enable_logging", Parameter::PT_BOOL, nullptr, "false", "enable logging" },
+ { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+MPUnixDomainTransportModule::MPUnixDomainTransportModule(): Module(MODULE_NAME, MODULE_HELP, unix_transport_params)
+{
+ config = nullptr;
+}
+
+bool MPUnixDomainTransportModule::begin(const char *, int, SnortConfig *sc)
+{
+ assert(sc);
+ assert(!config);
+ config = new MPUnixDomainTransportConfig;
+ config->max_processes = sc->max_procs;
+ return true;
+}
+
+bool MPUnixDomainTransportModule::set(const char *, Value & v, SnortConfig *)
+{
+ if (v.is("unix_domain_socket_path"))
+ {
+ config->unix_domain_socket_path = v.get_string();
+ }
+ else if (v.is("max_connect_retries"))
+ {
+ config->conn_retries = true;
+ config->max_retries = v.get_int32();
+ }
+ else if (v.is("retry_interval_seconds"))
+ {
+ config->retry_interval_seconds = v.get_int32();
+ }
+ else if (v.is("connect_timeout_seconds"))
+ {
+ config->connect_timeout_seconds = v.get_int32();
+ }
+ else if (v.is("consume_message_timeout_milliseconds"))
+ {
+ config->consume_message_timeout_milliseconds = v.get_int32();
+ }
+ else if (v.is("consume_message_batch_size"))
+ {
+ config->consume_message_batch_size = v.get_int32();
+ }
+ else if (v.is("enable_logging"))
+ {
+ config->enable_logging = v.get_bool();
+ }
+ else
+ {
+ WarningMessage("MPUnixDomainTransportModule: received unrecognized parameter %s\n", v.get_as_string().c_str());
+ return false;
+ }
+
+ return true;
+}
+
+static struct MPTransportApi mp_unixdomain_transport_api =
+{
+ {
+ PT_MP_TRANSPORT,
+ sizeof(MPTransportApi),
+ MP_TRANSPORT_API_VERSION,
+ 2,
+ API_RESERVED,
+ API_OPTIONS,
+ MODULE_NAME,
+ MODULE_HELP,
+ mod_ctor,
+ mod_dtor
+ },
+ 0,
+ nullptr,
+ nullptr,
+ nullptr,
+ nullptr,
+ mp_unixdomain_transport_ctor,
+ mp_unixdomain_transport_dtor
+};
+
+#ifdef BUILDING_SO
+SO_PUBLIC const BaseApi* snort_plugins[] =
+#else
+const BaseApi* mp_unix_transport[] =
+#endif
+{
+ &mp_unixdomain_transport_api.base,
+ nullptr
+};
--- /dev/null
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation. You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport_module.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_UNIX_TRANSPORT_MODULE_H
+#define MP_UNIX_TRANSPORT_MODULE_H
+
+#define MODULE_NAME "unix_transport"
+#define MODULE_HELP "manage the unix transport layer"
+
+#include "framework/module.h"
+#include "framework/mp_transport.h"
+#include "mp_unix_transport.h"
+
+namespace snort
+{
+
+class MPUnixDomainTransportModule : public Module
+{
+ public:
+
+ MPUnixDomainTransportModule();
+
+ ~MPUnixDomainTransportModule() override
+ { delete config; }
+
+ bool begin(const char*, int, SnortConfig*) override;
+ bool set(const char*, Value&, SnortConfig*) override;
+
+ Usage get_usage() const override
+ { return GLOBAL; }
+
+ MPUnixDomainTransportConfig* config;
+};
+
+static Module* mod_ctor()
+{
+ return new MPUnixDomainTransportModule;
+}
+
+static void mod_dtor(Module* m)
+{
+ delete m;
+}
+
+static MPTransport* mp_unixdomain_transport_ctor(Module* m)
+{
+ auto unix_tr_mod = (MPUnixDomainTransportModule*)m;
+ return new MPUnixDomainTransport(unix_tr_mod->config);
+}
+
+static void mp_unixdomain_transport_dtor(MPTransport* t)
+{
+ delete t;
+}
+
+}
+
+#endif
--- /dev/null
+
+add_cpputest( unix_transport_test
+ SOURCES
+ ../mp_unix_transport.cc
+ ../../../side_channel/side_channel.cc
+ ../../../side_channel/side_channel_format.cc
+ ../../../framework/module.cc
+ ../../../managers/connector_manager.cc
+ $<TARGET_OBJECTS:catch_tests>
+ LIBS
+ ${CMAKE_THREAD_LIBS_INIT}
+)
+
+ add_cpputest( unix_transport_module_test
+ SOURCES
+ ../mp_unix_transport_module.cc
+ ../../../framework/value.cc
+ ../../../sfip/sf_ip.cc
+ $<TARGET_OBJECTS:catch_tests>
+ ../../../framework/module.cc
+ LIBS
+ ${CMAKE_THREAD_LIBS_INIT}
+ )
+
--- /dev/null
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "../mp_unix_transport_module.h"
+
+#include "framework/value.h"
+#include "main/snort_config.h"
+#include "main/snort.h"
+#include "main/thread_config.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+
+static int warning_cnt = 0;
+static int destroy_cnt = 0;
+
+namespace snort
+{
+ void WarningMessage(const char*,...) { warning_cnt++; }
+
+ SnortConfig::SnortConfig(snort::SnortConfig const*, char const*)
+ {
+ max_procs = 2;
+ }
+ SnortConfig::~SnortConfig()
+ {
+
+ }
+
+ unsigned ThreadConfig::get_instance_max()
+ {
+ return 1;
+ }
+ unsigned Snort::get_process_id()
+ {
+ return 1;
+ }
+ unsigned get_instance_id()
+ {
+ return 1;
+ }
+
+ MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config) :
+ MPTransport()
+ {
+ this->config = config;
+ }
+ MPUnixDomainTransport::~MPUnixDomainTransport()
+ { destroy_cnt++; }
+ void MPUnixDomainTransport::thread_init()
+ {}
+ void MPUnixDomainTransport::thread_term()
+ {}
+ void MPUnixDomainTransport::init_connection()
+ {}
+ void MPUnixDomainTransport::cleanup()
+ {}
+ void MPUnixDomainTransport::register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&)
+ {}
+ bool MPUnixDomainTransport::send_to_transport(MPEventInfo&)
+ { return true; }
+ void MPUnixDomainTransport::unregister_receive_handler()
+ { }
+ void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler&)
+ {}
+ bool MPUnixDomainTransport::configure(const SnortConfig*)
+ { return true; }
+ bool MPUnixDomainTransport::is_logging_enabled()
+ { return false; }
+ void MPUnixDomainTransport::enable_logging()
+ {}
+ void MPUnixDomainTransport::disable_logging()
+ {}
+
+ char* snort_strdup(const char*)
+ {
+ return nullptr;
+ }
+};
+
+void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
+void show_stats(PegCount*, const PegInfo*, const std::vector<unsigned>&, const char*, FILE*) { }
+
+using namespace snort;
+
+MPUnixDomainTransportModule* mod = nullptr;
+
+TEST_GROUP(MPUnixDomainTransportModuleTests)
+{
+ void setup() override
+ {
+ mod = (MPUnixDomainTransportModule*)mod_ctor();
+ }
+
+ void teardown() override
+ {
+ mod_dtor(mod);
+ }
+};
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigBegin)
+{
+ SnortConfig sc;
+ auto res = mod->begin("test", 0, &sc);
+ CHECK(res == true);
+ CHECK(mod->config != nullptr);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigEnd)
+{
+ auto res = mod->end("test", 0, nullptr);
+ CHECK(res == true);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigSet)
+{
+ Parameter p{"unix_domain_socket_path", Parameter::PT_STRING, nullptr, "test_value", nullptr};
+ Parameter p2{"max_connect_retries", Parameter::PT_INT, nullptr, "15", nullptr};
+ Parameter p3{"retry_interval_seconds", Parameter::PT_INT, nullptr, "33", nullptr};
+ Parameter p4{"connect_timeout_seconds", Parameter::PT_INT, nullptr, "32", nullptr};
+ Parameter p5{"consume_message_timeout_milliseconds", Parameter::PT_INT, nullptr, "200", nullptr};
+ Parameter p6{"consume_message_batch_size", Parameter::PT_INT, nullptr, "20", nullptr};
+ Parameter p7{"enable_logging", Parameter::PT_BOOL, nullptr, "true", nullptr};
+ Value v("test_value");
+ v.set(&p);
+
+ SnortConfig sc;
+ mod->begin("test", 0, &sc);
+ auto res = mod->set(nullptr, v, nullptr);
+
+ CHECK(res == true);
+ CHECK(strcmp("test_value", mod->config->unix_domain_socket_path.c_str()) == 0);
+
+ v.set((double)15);
+ v.set(&p2);
+ res = mod->set(nullptr, v, nullptr);
+ CHECK(res == true);
+ CHECK(mod->config->max_retries == 15);
+
+ v.set((double)33);
+ v.set(&p3);
+ res = mod->set(nullptr, v, nullptr);
+ CHECK(res == true);
+ CHECK(mod->config->retry_interval_seconds == 33);
+
+ v.set((double)32);
+ v.set(&p4);
+ res = mod->set(nullptr, v, nullptr);
+ CHECK(res == true);
+ CHECK(mod->config->connect_timeout_seconds == 32);
+
+ v.set((double)200);
+ v.set(&p5);
+ res = mod->set(nullptr, v, nullptr);
+ CHECK(res == true);
+ CHECK(mod->config->consume_message_timeout_milliseconds == 200);
+
+ v.set((double)20);
+ v.set(&p6);
+ res = mod->set(nullptr, v, nullptr);
+ CHECK(res == true);
+ CHECK(mod->config->consume_message_batch_size == 20);
+
+ v.set(true);
+ v.set(&p7);
+ res = mod->set(nullptr, v, nullptr);
+ CHECK(res == true);
+ CHECK(mod->config->enable_logging == true);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigUnknownSet)
+{
+ warning_cnt = 0;
+
+ Parameter p{"unknown_value", Parameter::PT_STRING, nullptr, "/tmp/unx_dmn_sck", nullptr};
+ Value v("/tmp/unx_dmn_sck");
+ v.set(&p);
+
+ SnortConfig sc;
+ mod->begin("test", 0, &sc);
+ auto res = mod->set(nullptr, v, nullptr);
+
+ CHECK(res == false);
+ CHECK(1 == warning_cnt);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleGetUsage)
+{
+ auto res = mod->get_usage();
+ CHECK(res == Module::Usage::GLOBAL);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleCreateDestroyTransport)
+{
+ destroy_cnt = 0;
+ auto transport = mp_unixdomain_transport_ctor(mod);
+ CHECK(transport != nullptr);
+
+ mp_unixdomain_transport_dtor(transport);
+ CHECK(destroy_cnt == 1);
+}
+
+int main(int argc, char** argv)
+{
+ int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
+ return return_value;
+}
\ No newline at end of file
--- /dev/null
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "framework/mp_transport.h"
+#include "../mp_unix_transport.h"
+#include "framework/counts.h"
+#include "framework/mp_data_bus.h"
+#include "main/snort.h"
+#include "main/thread_config.h"
+#include "main/snort_config.h"
+#include "connectors/unixdomain_connector/unixdomain_connector.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+#include <mutex>
+#include <condition_variable>
+#include <vector>
+#include <iostream>
+
+static int snort_instance_id = 0;
+
+static int accept_cnt = 0;
+
+static int test_socket_calls = 0;
+static int test_bind_calls = 0;
+static int test_listen_calls = 0;
+static int test_accept_calls = 0;
+static int test_close_calls = 0;
+static int test_connect_calls = 0;
+static int test_call_sock_created = 0;
+static int test_serialize_calls = 0;
+static int test_deserialize_calls = 0;
+
+int accept (int, struct sockaddr*, socklen_t*)
+{
+ test_accept_calls++;
+ return accept_cnt--;
+}
+
+int close (int)
+{
+ test_close_calls++;
+ return 0;
+}
+
+void clear_test_calls()
+{
+ test_socket_calls = 0;
+ test_bind_calls = 0;
+ test_listen_calls = 0;
+ test_accept_calls = 0;
+ test_close_calls = 0;
+ test_connect_calls = 0;
+ test_call_sock_created = 0;
+ test_serialize_calls = 0;
+ test_deserialize_calls = 0;
+}
+
+namespace snort
+{
+ void ErrorMessage(const char*,...) { }
+ void WarningMessage(const char*,...) { }
+ void LogMessage(const char* s, ...) { }
+ void LogText(const char*, FILE*) {}
+ void ParseError(const char*, ...) { }
+
+ unsigned ThreadConfig::get_instance_max()
+ {
+ return 2; // Mock value for testing
+ }
+
+ SnortConfig::SnortConfig(snort::SnortConfig const*, char const*)
+ {
+ max_procs = 2;
+ }
+ SnortConfig::~SnortConfig()
+ {
+
+ }
+
+ unsigned Snort::get_process_id()
+ {
+ return snort_instance_id;
+ }
+
+ unsigned get_instance_id()
+ {
+ return snort_instance_id;
+ }
+};
+static int test_send_calls = 0;
+UnixDomainConnector* listen_connector = nullptr;
+UnixDomainConnector* call_connector = nullptr;
+
+void UnixDomainConnector::set_message_received_handler(std::function<void ()> h)
+{
+ message_received_handler = h;
+}
+
+static bool expect_update_change = false;
+std::function<void (UnixDomainConnector*,bool)> test_update_handler = nullptr;
+
+void UnixDomainConnector::set_update_handler(std::function<void (UnixDomainConnector*,bool)> h)
+{
+ if(expect_update_change)
+ test_update_handler = h;
+}
+
+static snort::ConnectorMsg* test_msg_answer = nullptr;
+static snort::ConnectorMsg* test_msg_call = nullptr;
+static uint8_t* test_msg_call_data = nullptr;
+static uint8_t* test_msg_answer_data = nullptr;
+UnixDomainConnectorListener::UnixDomainConnectorListener(char const*) // cppcheck-suppress uninitMemberVar
+{}
+UnixDomainConnectorListener::~UnixDomainConnectorListener()
+{
+}
+void UnixDomainConnectorListener::stop_accepting_connections()
+{
+ close(0);
+}
+void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnectorAcceptHandler h, UnixDomainConnectorConfig* cfg)
+{
+ socket(0,0,0);
+ while(accept_cnt > 0)
+ {
+ accept(0, nullptr, nullptr);
+ auto cfg_copy = new UnixDomainConnectorConfig(*cfg);
+ h(new UnixDomainConnector(*cfg_copy, 0, 0), cfg_copy);
+ }
+}
+
+bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg& m, const ID&)
+{
+ test_send_calls++;
+
+ if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+ {
+ test_msg_call_data = new uint8_t[m.get_length()];
+ memcpy(test_msg_call_data, m.get_data(), m.get_length());
+ test_msg_call = new snort::ConnectorMsg(test_msg_call_data, m.get_length());
+ if(!call_connector)
+ call_connector = this;
+ listen_connector->process_receive();
+ }
+ else
+ {
+ test_msg_answer_data = new uint8_t[m.get_length()];
+ memcpy(test_msg_answer_data, m.get_data(), m.get_length());
+ test_msg_answer = new snort::ConnectorMsg(test_msg_answer_data, m.get_length());
+ if(!listen_connector)
+ listen_connector = this;
+ call_connector->process_receive();
+ }
+
+ return true;
+}
+void UnixDomainConnector::process_receive()
+{
+ if (message_received_handler)
+ {
+ message_received_handler();
+ }
+}
+bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg&&, const ID&)
+{ return true; }
+snort::ConnectorMsg UnixDomainConnector::receive_message(bool)
+{
+ if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+ {
+ if (test_msg_answer)
+ {
+ snort::ConnectorMsg msg(test_msg_answer_data, test_msg_answer->get_length());
+ delete test_msg_answer;
+ test_msg_answer = nullptr;
+ return std::move(msg); // cppcheck-suppress returnStdMoveLocal
+ }
+ }
+ else
+ {
+ if (test_msg_call)
+ {
+ snort::ConnectorMsg msg(test_msg_call_data, test_msg_call->get_length());
+ delete test_msg_call;
+ test_msg_call = nullptr;
+ return std::move(msg); // cppcheck-suppress returnStdMoveLocal
+ }
+ }
+ return snort::ConnectorMsg();
+}
+UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx) : Connector(config) // cppcheck-suppress uninitMemberVar
+{ cfg = config; } // cppcheck-suppress useInitializationList
+UnixDomainConnector::~UnixDomainConnector()
+{
+ close(0);
+}
+
+UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler)
+{
+ if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+ {
+ test_call_sock_created++;
+ auto new_conn = new UnixDomainConnector(cfg, 0, idx);
+ call_connector = new_conn;
+ return new_conn;
+ }
+ assert(false);
+ return nullptr;
+}
+
+void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
+void show_stats(PegCount*, const PegInfo*, const std::vector<unsigned>&, const char*, FILE*) { }
+
+using namespace snort;
+
+
+static int s_socket_return = 1;
+static int s_bind_return = 0;
+static int s_listen_return = 0;
+static int s_connect_return = 1;
+
+#ifdef __GLIBC__
+int socket (int, int, int) __THROW { test_socket_calls++; return s_socket_return; }
+int bind (int, const struct sockaddr*, socklen_t) __THROW { test_bind_calls++; return s_bind_return; }
+int listen (int, int) __THROW { test_listen_calls++; return s_listen_return; }
+int connect (int, const struct sockaddr*, socklen_t) __THROW { test_connect_calls++; return s_connect_return; }
+int unlink (const char *__name) __THROW { return 0;};
+#else
+int socket (int, int, int) { test_socket_calls++; return s_socket_return; }
+int bind (int, const struct sockaddr*, socklen_t) { test_bind_calls++; return s_bind_return; }
+int listen (int, int) { test_listen_calls++; return s_listen_return; }
+int connect (int, const struct sockaddr*, socklen_t) { test_connect_calls++; return s_connect_return; }
+int unlink (const char *__name) { return 0;};
+#endif
+
+int fcntl (int __fd, int __cmd, ...) { return 0;}
+ssize_t send (int, const void*, size_t n, int) { return n; }
+
+
+std::mutex accept_mutex;
+std::condition_variable accept_cond;
+
+
+class TestDataEvent : public DataEvent
+{
+public:
+ TestDataEvent() {}
+ ~TestDataEvent() override {}
+};
+
+bool serialize_mock(DataEvent* event, char*& buffer, uint16_t* length)
+{
+ test_serialize_calls++;
+ buffer = new char[9];
+ *length = 9;
+ memcpy(buffer, "test_data", 9);
+ return true;
+}
+
+bool deserialize_mock(const char* buffer, uint16_t length, DataEvent*& event)
+{
+ test_deserialize_calls++;
+ event = new TestDataEvent();
+ return true;
+}
+
+MPHelperFunctions mp_helper_functions_mock(serialize_mock, deserialize_mock);
+
+static MPUnixDomainTransportConfig test_config;
+static MPUnixDomainTransport* test_transport = nullptr;
+
+static SnortConfig test_snort_config(nullptr, nullptr);
+
+TEST_GROUP(unix_transport_test_connectivity_group)
+{
+ void setup() override
+ {
+ test_snort_config.max_procs = 2;
+ test_transport = new MPUnixDomainTransport(&test_config);
+ test_transport->configure(&test_snort_config);
+ }
+
+ void teardown() override
+ {
+ delete test_transport;
+ test_transport = nullptr;
+ }
+};
+
+static MPUnixDomainTransportConfig test_config_message;
+
+static MPTransport* test_transport_message_1 = nullptr;
+static MPTransport* test_transport_message_2 = nullptr;
+
+static int reciveved_1_msg_cnt = 0;
+static int reciveved_2_msg_cnt = 0;
+
+TEST_GROUP(unix_transport_test_messaging)
+{
+ void setup() override
+ {
+ test_snort_config.max_procs = 2;
+
+ accept_cnt = 1;
+
+ test_config_message.unix_domain_socket_path = "/tmp";
+ test_config_message.max_processes = 2;
+ test_config_message.conn_retries = false;
+ test_config_message.retry_interval_seconds = 0;
+ test_config_message.max_retries = 0;
+ test_config_message.connect_timeout_seconds = 30;
+
+ test_transport_message_1 = new MPUnixDomainTransport(&test_config_message);
+ snort_instance_id = 1;
+ test_transport_message_1->configure(&test_snort_config);
+ test_transport_message_1->init_connection();
+ test_transport_message_1->register_receive_handler([reciveved_1_msg_cnt](const snort::MPEventInfo& e)
+ {
+ reciveved_1_msg_cnt++;
+ });
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+ test_transport_message_2 = new MPUnixDomainTransport(&test_config_message);
+ snort_instance_id = 2;
+ test_transport_message_2->configure(&test_snort_config);
+ test_transport_message_2->init_connection();
+ test_transport_message_2->register_receive_handler([reciveved_2_msg_cnt](const snort::MPEventInfo& e)
+ {
+ reciveved_2_msg_cnt++;
+ });
+ }
+
+ void teardown() override
+ {
+ delete test_transport_message_1;
+ test_transport_message_1 = nullptr;
+ delete test_transport_message_2;
+ test_transport_message_2 = nullptr;
+ delete[] test_msg_call_data;
+ test_msg_call_data = nullptr;
+ delete[] test_msg_answer_data;
+ test_msg_answer_data = nullptr;
+ }
+};
+
+TEST(unix_transport_test_connectivity_group, get_config)
+{
+ auto unix_transport = (MPUnixDomainTransport*)test_transport;
+ CHECK(unix_transport->get_config() == &test_config);
+};
+
+TEST(unix_transport_test_connectivity_group, set_logging_enabled_disabled)
+{
+ auto logging_status = test_transport->is_logging_enabled();
+ CHECK(logging_status == false);
+
+ test_transport->enable_logging();
+ logging_status = test_transport->is_logging_enabled();
+ CHECK(logging_status == true);
+
+ test_transport->disable_logging();
+ logging_status = test_transport->is_logging_enabled();
+ CHECK(logging_status == false);
+};
+
+TEST(unix_transport_test_connectivity_group, init_connection_single_snort_instance)
+{
+ clear_test_calls();
+ test_config.unix_domain_socket_path = "/tmp";
+ test_config.max_processes = 1;
+
+ test_transport->init_connection();
+
+ CHECK(test_socket_calls == 0);
+ CHECK(test_bind_calls == 0);
+ CHECK(test_listen_calls == 0);
+ CHECK(test_accept_calls == 0);
+ CHECK(test_close_calls == 0);
+
+ test_transport->cleanup();
+ CHECK(test_close_calls == 0);
+};
+
+TEST(unix_transport_test_connectivity_group, init_connection_first_snort_instance)
+{
+ clear_test_calls();
+ snort_instance_id = 1;
+
+ test_config.unix_domain_socket_path = "/tmp";
+ test_config.max_processes = 2;
+
+ accept_cnt = 1;
+
+ test_transport->init_connection();
+
+ CHECK(test_accept_calls == 1);
+
+ test_transport->cleanup();
+ CHECK(test_close_calls == 2);
+};
+
+TEST(unix_transport_test_connectivity_group, init_connection_second_snort_instance)
+{
+ clear_test_calls();
+ snort_instance_id = 2;
+ test_config.unix_domain_socket_path = "/tmp";
+ test_config.max_processes = 2;
+
+ test_transport->init_connection();
+
+ CHECK(test_bind_calls == 0);
+ CHECK(test_listen_calls == 0);
+ CHECK(test_accept_calls == 0);
+ CHECK(test_close_calls == 0);
+ CHECK(test_call_sock_created == 1);
+
+ test_transport->cleanup();
+ CHECK(test_close_calls == 1);
+};
+
+TEST(unix_transport_test_connectivity_group, connector_update_handler_call)
+{
+ clear_test_calls();
+
+ test_config.unix_domain_socket_path = "/tmp";
+ test_config.max_processes = 2;
+
+ accept_cnt = 1;
+ snort_instance_id = 1;
+
+ test_update_handler = nullptr;
+ expect_update_change = true;
+
+ test_transport->init_connection();
+
+ CHECK(test_update_handler != nullptr);
+
+ test_update_handler(nullptr, false);
+
+ CHECK(test_close_calls == 1);
+ expect_update_change = false;
+ test_update_handler = nullptr;
+};
+
+static TestDataEvent test_event;
+
+TEST(unix_transport_test_messaging, send_to_transport_biderectional)
+{
+ clear_test_calls();
+
+ test_transport_message_1->register_event_helpers(0, 0, mp_helper_functions_mock);
+ test_transport_message_2->register_event_helpers(0, 0, mp_helper_functions_mock);
+
+ MPEventInfo event(&test_event, 0, 0);
+
+ auto res = test_transport_message_1->send_to_transport(event);
+
+
+ CHECK(res == true);
+ CHECK(test_serialize_calls == 1);
+
+ res = test_transport_message_2->send_to_transport(event);
+
+ CHECK(res == true);
+ CHECK(test_serialize_calls == 2);
+
+ std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+ CHECK(test_deserialize_calls == 2);
+ CHECK(reciveved_1_msg_cnt == 1);
+ CHECK(reciveved_2_msg_cnt == 1);
+ CHECK(test_send_calls == 2);
+};
+
+TEST(unix_transport_test_messaging, send_to_transport_no_helpers)
+{
+ clear_test_calls();
+
+ MPEventInfo event(&test_event, 0, 0);
+
+ auto res = test_transport_message_1->send_to_transport(event);
+ CHECK(res == false);
+ CHECK(test_serialize_calls == 0);
+ CHECK(test_deserialize_calls == 0);
+ CHECK(reciveved_1_msg_cnt == 0);
+ CHECK(reciveved_2_msg_cnt == 0);
+ CHECK(test_send_calls == 0);
+
+ res = test_transport_message_2->send_to_transport(event);
+ CHECK(res == false);
+ CHECK(test_serialize_calls == 0);
+ CHECK(test_deserialize_calls == 0);
+ CHECK(reciveved_1_msg_cnt == 0);
+ CHECK(reciveved_2_msg_cnt == 0);
+ CHECK(test_send_calls == 0);
+}
+
+int main(int argc, char** argv)
+{
+ int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
+ return return_value;
+}
\ No newline at end of file
// return true iff we received any messages.
bool SideChannel::process(int max_messages)
{
+ if(!connector_receive)
+ return false;
+
bool received_message = false;
while (true)