]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4695: mp_unix_transport: mp_transport plugin type, implementation of...
authorOleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) <ostepano@cisco.com>
Thu, 24 Apr 2025 18:16:27 +0000 (18:16 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Thu, 24 Apr 2025 18:16:27 +0000 (18:16 +0000)
Merge in SNORT/snort3 from ~OSTEPANO/snort3:mp_transport_layer to master

Squashed commit of the following:

commit edb3158929808ca911049623f5e676554134eab7
Author: Oleksandr Stepanov <ostepano@cisco.com>
Date:   Thu Mar 27 16:06:10 2025 -0400

    mp_unix_transport: mp_transport plugin type, implementation of unix domain name based mp transport

32 files changed:
src/CMakeLists.txt
src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc
src/connectors/unixdomain_connector/unixdomain_connector.cc
src/connectors/unixdomain_connector/unixdomain_connector.h
src/connectors/unixdomain_connector/unixdomain_connector_config.h
src/framework/CMakeLists.txt
src/framework/base_api.h
src/framework/mp_data_bus.cc
src/framework/mp_data_bus.h
src/framework/mp_transport.h [new file with mode: 0644]
src/main/help.cc
src/main/snort.cc
src/managers/CMakeLists.txt
src/managers/mp_transport_manager.cc [new file with mode: 0644]
src/managers/mp_transport_manager.h [new file with mode: 0644]
src/managers/plugin_manager.cc
src/managers/test/CMakeLists.txt
src/managers/test/mp_transport_manager_test.cc [new file with mode: 0644]
src/mp_transport/CMakeLists.txt [new file with mode: 0644]
src/mp_transport/dev_notes.txt [new file with mode: 0644]
src/mp_transport/mp_transports.cc [new file with mode: 0644]
src/mp_transport/mp_transports.h [new file with mode: 0644]
src/mp_transport/mp_unix_transport/CMakeLists.txt [new file with mode: 0644]
src/mp_transport/mp_unix_transport/dev_notes.txt [new file with mode: 0644]
src/mp_transport/mp_unix_transport/mp_unix_transport.cc [new file with mode: 0644]
src/mp_transport/mp_unix_transport/mp_unix_transport.h [new file with mode: 0644]
src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc [new file with mode: 0644]
src/mp_transport/mp_unix_transport/mp_unix_transport_module.h [new file with mode: 0644]
src/mp_transport/mp_unix_transport/test/CMakeLists.txt [new file with mode: 0644]
src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc [new file with mode: 0644]
src/mp_transport/mp_unix_transport/test/unix_transport_test.cc [new file with mode: 0644]
src/side_channel/side_channel.cc

index bd0803cf5d40bdc70b085a839288976eedb98d79..de2df2eb4ca60583bd45990b01839a4f2e8d47ab 100644 (file)
@@ -137,6 +137,7 @@ add_subdirectory(policy_selectors)
 add_subdirectory(search_engines)
 add_subdirectory(side_channel)
 add_subdirectory(connectors)
+add_subdirectory(mp_transport)
 
 
 # FIXIT-L Delegate building out the target objects list to subdirectories
@@ -198,8 +199,10 @@ add_executable( snort
     $<TARGET_OBJECTS:target_based>
     $<TARGET_OBJECTS:tcp_connector>
     $<TARGET_OBJECTS:unixdomain_connector>
+    $<TARGET_OBJECTS:mp_unix_transport>
     $<TARGET_OBJECTS:time>
     $<TARGET_OBJECTS:trace>
+    $<TARGET_OBJECTS:mp_transports>
     $<TARGET_OBJECTS:utils>
     ${STATIC_CODEC_PLUGINS}
     ${STATIC_NETWORK_INSPECTOR_PLUGINS}
index eba38357090c05c78007d87366e06ae5ece6f51b..269e54782f67492a67369f9ba7dc1f5a4612f885 100644 (file)
@@ -148,8 +148,19 @@ int socket (int, int, int) { return s_socket_return; }
 int bind (int, const struct sockaddr*, socklen_t) { return s_bind_return; }
 int listen (int, int) { return s_listen_return; }
 #endif
-
-int accept (int, struct sockaddr*, socklen_t*) { return s_accept_return; }
+static bool use_test_accept_counter = false;
+static uint test_accept_counter = 0;
+int accept (int, struct sockaddr*, socklen_t*)
+{
+    if ( use_test_accept_counter )
+    {
+        if ( test_accept_counter == 0 )
+            return -1;
+        else
+            test_accept_counter--;
+    }
+    return s_accept_return;
+}
 int close (int) { return 0; }
 
 static void set_normal_status()
@@ -753,6 +764,65 @@ TEST(unixdomain_connector_no_tinit_tterm_call, receive_recv_body_closed)
     delete[] message;
 }
 
+static const char* test_listener_path = "/tmp/test_path";
+UnixDomainConnectorListener* test_listener = nullptr;
+
+TEST_GROUP(unixdomain_connector_listener)
+{
+    void setup() override
+    {
+        test_listener = new UnixDomainConnectorListener(test_listener_path);
+    }
+
+    void teardown() override
+    {
+        if (test_listener)
+        {
+            delete test_listener;
+            test_listener = nullptr;
+        }
+    }
+};
+
+UnixDomainConnector* test_listener_connector = nullptr;
+UnixDomainConnectorConfig* test_listener_config = nullptr;
+
+void connection_callback(UnixDomainConnector* c, UnixDomainConnectorConfig* conf)
+{
+    assert(c != nullptr);
+    assert(conf != nullptr);
+
+    test_listener_connector = c;
+    test_listener_config = conf;
+}
+
+TEST(unixdomain_connector_listener, listener_accept_stop)
+{
+    UnixDomainConnectorConfig cfg;
+    cfg.direction = Connector::CONN_DUPLEX;
+    cfg.connector_name = "unixdomain";
+    cfg.paths.push_back(test_listener_path);
+    cfg.setup = UnixDomainConnectorConfig::Setup::ANSWER;
+    cfg.async_receive = true;
+
+    use_test_accept_counter = true;
+    test_accept_counter = 1;
+
+    test_listener->start_accepting_connections(&connection_callback, &cfg);
+    std::this_thread::sleep_for(std::chrono::milliseconds(100));
+    test_listener->stop_accepting_connections();
+
+    CHECK(test_listener_connector != nullptr);
+    CHECK(test_listener_config != nullptr);
+    CHECK(test_listener_connector->get_connector_direction() == Connector::CONN_DUPLEX);
+    CHECK(test_listener_config->async_receive == true);
+
+    delete test_listener_connector;
+    test_listener_connector = nullptr;
+    delete test_listener_config;
+    test_listener_config = nullptr;
+}
+
 int main(int argc, char** argv)
 {
     int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
index 19f7ae8faa79821960e0b317230fa8c5d1dad32d..f3e9b382b39b6bac5837357be980bddcfd51bed6 100644 (file)
@@ -83,37 +83,61 @@ static bool attempt_connection(int& sfd, const char* path) {
 }
 
 // Function to handle connection retries
-static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx) {
-    ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr);
+static void connection_retry_handler(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
+    if(update_handler)
+        update_handler(nullptr, (cfg.conn_retries > 0));
+    else
+        ConnectorManager::update_thread_connector(cfg.connector_name, idx, nullptr);
 
-    if ( cfg.setup == UnixDomainConnectorConfig::Setup::CALL and cfg.conn_retries) {
 
+    if (cfg.conn_retries)
+    {
         const auto& paths = cfg.paths;
 
         if (idx >= paths.size())
             return;
-        
-        uint32_t retry_count = 0; 
+
         const char* path = paths[idx].c_str();
 
-        while (retry_count < cfg.max_retries) {
-            int sfd;
-            if (attempt_connection(sfd, path)) {
-                // Connection successful
-                UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
-                LogMessage("UnixDomainC: Connected to %s", path);
-                ConnectorManager::update_thread_connector(cfg.connector_name, idx, unixdomain_conn);
-                break;
+        if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+        {
+            LogMessage("UnixDomainC: Attempting to reconnect to %s\n", cfg.paths[idx].c_str());
+
+            uint32_t retry_count = 0;
+
+            while (retry_count < cfg.max_retries) {
+                int sfd;
+                if (attempt_connection(sfd, path)) {
+                    // Connection successful
+                    UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
+                    LogMessage("UnixDomainC: Connected to %s", path);
+                    if(update_handler)
+                    {
+                        unixdomain_conn->set_update_handler(update_handler);
+                        update_handler(unixdomain_conn, false);
+                    }
+                    else
+                        ConnectorManager::update_thread_connector(cfg.connector_name, idx, unixdomain_conn);
+                    break;
+                }
+            
+                std::this_thread::sleep_for(std::chrono::seconds(cfg.retry_interval));
+                retry_count++;
             }
-
-            std::this_thread::sleep_for(std::chrono::seconds(cfg.retry_interval));
-            retry_count++;
+        }
+        else if (cfg.setup == UnixDomainConnectorConfig::Setup::ANSWER)
+        {
+            return;
+        }
+        else
+        {
+            LogMessage("UnixDomainC: Unexpected setup type at retry connection\n");
         }
     }
 }
 
-static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx) {
-    std::thread retry_thread(connection_retry_handler, cfg, idx);
+static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnectorUpdateHandler update_handler = nullptr) {
+    std::thread retry_thread(connection_retry_handler, cfg, idx, update_handler);
     retry_thread.detach();
 }
 
@@ -254,7 +278,7 @@ void UnixDomainConnector::process_receive() {
             sock_fd = -1;
         }   
 
-        start_retry_thread(cfg, instance_id);
+        start_retry_thread(cfg, instance_id, update_handler);
         return;
     } 
     else if (rval > 0 && pfds[0].revents & POLLIN) {
@@ -263,9 +287,23 @@ void UnixDomainConnector::process_receive() {
             ErrorMessage("UnixDomainC: Input Thread: overrun\n");
             delete connector_msg;
         }
+        if(message_received_handler)
+        {
+            message_received_handler();
+        }
     }
 }
 
+void UnixDomainConnector::set_update_handler(UnixDomainConnectorUpdateHandler handler)
+{
+    update_handler = std::move(handler);
+}
+
+void UnixDomainConnector::set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler)
+{
+    message_received_handler = std::move(handler);
+}
+
 void UnixDomainConnector::receive_processing_thread() {
     while (run_thread.load(std::memory_order_relaxed)) {
         process_receive();
@@ -347,12 +385,12 @@ static void mod_dtor(Module* m) {
     delete m;
 }
 
-static UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx) {
+UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler) {
     int sfd;
     if (!attempt_connection(sfd, path)) {
         if (cfg.conn_retries) {
             // Spawn a new thread to handle connection retries
-            start_retry_thread(cfg, idx);
+            start_retry_thread(cfg, idx, update_handler);
 
             return nullptr; // Return nullptr as the connection is not yet established
         } else {
@@ -362,6 +400,10 @@ static UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConn
     }
     LogMessage("UnixDomainC: Connected to %s", path);
     UnixDomainConnector* unixdomain_conn = new UnixDomainConnector(cfg, sfd, idx);
+    unixdomain_conn->set_update_handler(update_handler);
+    if(update_handler)
+        update_handler(unixdomain_conn, false);
+    
     return unixdomain_conn;
 }
 
@@ -400,7 +442,7 @@ static UnixDomainConnector* unixdomain_connector_tinit_answer(const UnixDomainCo
 
     LogMessage("UnixDomainC: Accepted connection from %s \n", path);
     return new UnixDomainConnector(cfg, peer_sfd, idx);
-} 
+}
 
 static bool is_valid_path(const std::string& path) {
     if (path.empty()) {
@@ -491,3 +533,92 @@ const BaseApi* unixdomain_connector[] =
     &unixdomain_connector_api.base,
     nullptr
 };
+
+UnixDomainConnectorListener::UnixDomainConnectorListener(const char *path)
+{
+    assert(path);
+    
+    sock_path = strdup(path);
+    sock_fd = 0;
+    accept_thread = nullptr;
+    should_accept = false;
+}
+
+UnixDomainConnectorListener::~UnixDomainConnectorListener()
+{
+    stop_accepting_connections();
+    free(sock_path);
+    sock_path = nullptr;
+}
+
+void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnectorAcceptHandler handler, UnixDomainConnectorConfig* config)
+{
+    assert(accept_thread == nullptr);
+    assert(sock_path);
+
+    should_accept = true;
+    accept_thread = new std::thread([this, handler, config]()
+    {
+        sock_fd = socket(AF_UNIX, SOCK_STREAM, 0);
+        if (sock_fd == -1) {
+            ErrorMessage("UnixDomainC: socket error: %s \n", strerror(errno));
+            return;
+        }
+
+        struct sockaddr_un addr;
+        memset(&addr, 0, sizeof(struct sockaddr_un));
+        addr.sun_family = AF_UNIX;
+        strncpy(addr.sun_path, sock_path, sizeof(addr.sun_path) - 1);
+
+        unlink(sock_path);
+
+        if (bind(sock_fd, (struct sockaddr*)&addr, sizeof(struct sockaddr_un)) == -1) {
+            ErrorMessage("UnixDomainC: bind error: %s \n", strerror(errno));
+            close(sock_fd);
+            return;
+        }
+
+        if (listen(sock_fd, 10) == -1) {
+            ErrorMessage("UnixDomainC: listen error: %s \n", strerror(errno));
+            close(sock_fd);
+            return;
+        }
+
+        ushort error_count = 0;
+
+        while (should_accept) {
+            if(error_count > 10)
+            {
+                ErrorMessage("UnixDomainC: Too many errors, stopping accept thread\n");
+                close(sock_fd);
+                return;
+            }
+            int peer_sfd = accept(sock_fd, nullptr, nullptr);
+            if (peer_sfd == -1) 
+            {
+                error_count++;
+                ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno));
+                continue;
+            }
+            error_count = 0;
+            auto config_copy = new UnixDomainConnectorConfig(*config);
+            auto unix_conn = new UnixDomainConnector(*config_copy, peer_sfd, 0);
+            handler(unix_conn, config_copy);
+        }
+    });
+
+}
+
+void UnixDomainConnectorListener::stop_accepting_connections()
+{
+    if(should_accept)
+    {
+        should_accept = false;
+        close(sock_fd);
+        if (accept_thread && accept_thread->joinable()) {
+            accept_thread->join();
+        }
+        delete accept_thread;
+        accept_thread = nullptr;
+    }
+}
index 28abd1b65a2d5aff6a46f436730fe76e2fd44757..08339e30f5068a3c0dac2b609e9986997ffdd27a 100644 (file)
@@ -23,6 +23,7 @@
 
 #include <atomic>
 #include <thread>
+#include <functional>
 
 #include "framework/connector.h"
 #include "managers/connector_manager.h"
@@ -48,6 +49,11 @@ public:
     uint16_t connector_msg_length;
 };
 
+class UnixDomainConnector;
+
+typedef std::function<void (UnixDomainConnector*,bool)> UnixDomainConnectorUpdateHandler;
+typedef std::function<void ()> UnixDomainConnectorMessageReceivedHandler;
+
 class UnixDomainConnector :  public snort::Connector 
 {
 public:
@@ -60,6 +66,9 @@ public:
     snort::ConnectorMsg receive_message(bool) override;
     void process_receive();
 
+    void set_update_handler(UnixDomainConnectorUpdateHandler handler);
+    void set_message_received_handler(UnixDomainConnectorMessageReceivedHandler handler);
+
     int sock_fd;
 
 private:
@@ -77,7 +86,31 @@ private:
     ReceiveRing* receive_ring;
     size_t instance_id;
     UnixDomainConnectorConfig cfg;
+
+    UnixDomainConnectorUpdateHandler update_handler;
+    UnixDomainConnectorMessageReceivedHandler message_received_handler;
+};
+
+typedef std::function<void (UnixDomainConnector*, UnixDomainConnectorConfig*)> UnixDomainConnectorAcceptHandler;
+
+class UnixDomainConnectorListener
+{
+public:
+    UnixDomainConnectorListener(const char* path);
+    ~UnixDomainConnectorListener();
+
+    void start_accepting_connections(UnixDomainConnectorAcceptHandler handler, UnixDomainConnectorConfig* config);
+    void stop_accepting_connections();
+
+    private:
+    char* sock_path;
+    int sock_fd;
+    std::thread* accept_thread;
+    std::atomic<bool> should_accept;
+
 };
 
+extern SO_PUBLIC UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler = nullptr);
+
 #endif // UNIXDOMAIN_CONNECTOR_H
 
index 931395f3e25ff0f36c3c7449190cbf2c29053bec..57d56abede7054d0fd34cb92febb5f8b77966643 100644 (file)
@@ -27,11 +27,23 @@ public:
     UnixDomainConnectorConfig()
     { direction = snort::Connector::CONN_DUPLEX; async_receive = true; }
 
+    UnixDomainConnectorConfig(const UnixDomainConnectorConfig& other) :
+        paths(other.paths), setup(other.setup),
+        conn_retries(other.conn_retries),
+        retry_interval(other.retry_interval),
+        max_retries(other.max_retries),
+        connect_timeout_seconds(other.connect_timeout_seconds),
+        async_receive(other.async_receive)
+    {
+        direction = other.direction;
+    }
+
     std::vector<std::string> paths; 
     Setup setup = {};
     bool conn_retries = false;
     uint32_t retry_interval = 4;
     uint32_t max_retries = 5;
+    uint32_t connect_timeout_seconds = 30;
     bool async_receive;
 };
 
index 109b20d35ab63b11b1b8a2af1ad331fcc730c3f1..d3eb4a7da90afcc45a2eb6f98906ae2f99c2e48e 100644 (file)
@@ -25,6 +25,7 @@ set (FRAMEWORK_INCLUDES
     range.h
     so_rule.h
     value.h
+    mp_transport.h
 )
 
 add_library ( framework OBJECT
index c806b6a61e001f63e0a8e2b8c142e69475526abc..4306687da2920273a6ac46eef99261a0c5498439 100644 (file)
@@ -54,6 +54,7 @@ enum PlugType
     PT_LOGGER,
     PT_CONNECTOR,
     PT_POLICY_SELECTOR,
+    PT_MP_TRANSPORT,
     PT_MAX
 };
 
index ef860647102dae5eb16cb07e0c6373dacf5f539f..1263b2e864060e830158901269635982214c8e21 100644 (file)
@@ -31,6 +31,8 @@
 #include "pub_sub/intrinsic_event_ids.h"
 #include "utils/stats.h"
 #include "main/snort_types.h"
+#include "managers/mp_transport_manager.h"
+#include "log/messages.h"
 
 using namespace snort;
 
@@ -105,6 +107,7 @@ void MPDataBus::receive_message(const MPEventInfo& event_info)
     UNUSED(event_info);
 }
 
+
 //--------------------------------------------------------------------------
 // private methods
 //--------------------------------------------------------------------------
index 0fb93aee76050ffeabae79efe640098e82eac72f..d69918c94b710e6050b53d01e7829c17f118a466 100644 (file)
@@ -39,6 +39,7 @@
 
 #include "main/snort_types.h"
 #include "data_bus.h"
+#include "framework/mp_transport.h"
 #include <bitset>
 
 namespace snort
@@ -47,8 +48,8 @@ class Flow;
 struct Packet;
 struct SnortConfig;
 
-typedef bool (*MPSerializeFunc)(const DataEvent& event, char** buffer, size_t* length);
-typedef bool (*MPDeserializeFunc)(const char* buffer, size_t length, DataEvent* event);
+typedef bool (*MPSerializeFunc)(DataEvent* event, char*& buffer, uint16_t* length);
+typedef bool (*MPDeserializeFunc)(const char* buffer, uint16_t length, DataEvent*& event);
 
 // Similar to the DataBus class, the MPDataBus class uses uses a combination of PubKey and event ID
 // for event subscriptions and publishing. New MP-specific event type enums should be added to the
@@ -62,16 +63,16 @@ struct MPEventInfo
 {
     MPEventType type;
     unsigned pub_id;
-    DataEvent event;
-    MPEventInfo(const DataEvent& e, MPEventType t, unsigned id = 0)
+    DataEvent* event;
+    MPEventInfo(DataEvent* e, MPEventType t, unsigned id = 0)
         : type(t), pub_id(id), event(e) {}
 };
 
 struct MPHelperFunctions {
-    MPSerializeFunc* serializer;
-    MPDeserializeFunc* deserializer;
+    MPSerializeFunc serializer;
+    MPDeserializeFunc deserializer;
     
-    MPHelperFunctions(MPSerializeFunc* s, MPDeserializeFunc* d) 
+    MPHelperFunctions(MPSerializeFunc s, MPDeserializeFunc d) 
         : serializer(s), deserializer(d) {}
 };
 
diff --git a/src/framework/mp_transport.h b/src/framework/mp_transport.h
new file mode 100644 (file)
index 0000000..4b2decb
--- /dev/null
@@ -0,0 +1,81 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transport.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_TRANSPORT_H
+#define MP_TRANSPORT_H
+
+#include "main/snort_types.h"
+#include "framework/base_api.h"
+
+#include <functional>
+
+namespace snort
+{
+
+#define MP_TRANSPORT_API_VERSION ((BASE_API_VERSION << 16) | 1)
+
+struct SnortConfig;
+struct MPEventInfo;
+struct MPHelperFunctions;
+
+typedef std::function<void (const MPEventInfo& event_info)> TransportReceiveEventHandler;
+
+class MPTransport
+{
+    public:
+
+    MPTransport() = default;
+    virtual ~MPTransport() = default;
+
+    virtual bool configure(const SnortConfig*) = 0;
+    virtual void thread_init() = 0;
+    virtual void thread_term() = 0;
+    virtual void init_connection() = 0;
+    virtual bool send_to_transport(MPEventInfo& event) = 0;
+    virtual void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) = 0;
+    virtual void register_receive_handler(const TransportReceiveEventHandler& handler) = 0;
+    virtual void unregister_receive_handler() = 0;
+    virtual void enable_logging() = 0;
+    virtual void disable_logging() = 0;
+    virtual bool is_logging_enabled() = 0;
+};
+
+
+typedef MPTransport* (* MPTransportNewFunc)(Module*);
+typedef void (* MPTransportDelFunc)(MPTransport*);
+typedef void (* MPTransportThreadInitFunc)(MPTransport*);
+typedef void (* MPTransportThreadTermFunc)(MPTransport*);
+typedef void (* MPTransportFunc)();
+
+struct MPTransportApi
+{
+    BaseApi base;
+    unsigned flags;
+
+    MPTransportFunc pinit;     // plugin init
+    MPTransportFunc pterm;     // cleanup pinit()
+    MPTransportThreadInitFunc tinit;     // thread local init
+    MPTransportThreadTermFunc tterm;     // cleanup tinit()
+
+    MPTransportNewFunc ctor;
+    MPTransportDelFunc dtor;
+};
+
+}
+#endif // MP_TRANSPORT_H
index 7f3fafbfbcb9849a801c970f958cf83a74b5dcbe..8d6f7eb42dd26c00ac0c875726957b6cf8747a1c 100644 (file)
@@ -34,6 +34,7 @@
 #include "managers/plugin_manager.h"
 #include "managers/script_manager.h"
 #include "managers/so_manager.h"
+#include "managers/mp_transport_manager.h"
 #include "packet_io/sfdaq.h"
 #include "utils/util.h"
 
@@ -211,6 +212,7 @@ enum HelpType
     ModuleManager::term();
     PluginManager::release_plugins();
     ScriptManager::release_scripts();
+    MPTransportManager::term();
     delete SnortConfig::get_conf();
     exit(0);
 }
index e9e5d41c558d9bbb88153f27b26ead285b78e528..2f889ddc21a23154f11e300205c3df6be81d9504 100644 (file)
@@ -60,6 +60,7 @@
 #include "managers/plugin_manager.h"
 #include "managers/policy_selector_manager.h"
 #include "managers/script_manager.h"
+#include "managers/mp_transport_manager.h"
 #include "memory/memory_cap.h"
 #include "network_inspectors/network_inspectors.h"
 #include "packet_io/active.h"
@@ -79,6 +80,7 @@
 #include "trace/trace_api.h"
 #include "trace/trace_config.h"
 #include "trace/trace_logger.h"
+#include "mp_transport/mp_transports.h"
 #include "utils/stats.h"
 #include "utils/util.h"
 
@@ -132,6 +134,7 @@ void Snort::init(int argc, char** argv)
     load_stream_inspectors();
     load_network_inspectors();
     load_service_inspectors();
+    load_mp_transports();
 
     snort_cmd_line_conf = parse_cmd_line(argc, argv);
     SnortConfig::set_conf(snort_cmd_line_conf);
@@ -371,6 +374,7 @@ void Snort::term()
     host_cache.term();
     PluginManager::release_plugins();
     ScriptManager::release_scripts();
+    MPTransportManager::term();
     memory::MemoryCap::term();
     detection_filter_term();
 
index d748f99e4128e21af35f2d314f3e36f9318d37e6..f72ddfb7eb67c8e3939188d1c27852b2dcdacccc 100644 (file)
@@ -38,6 +38,8 @@ add_library( managers OBJECT
     so_manager.h
     connector_manager.cc
     connector_manager.h
+    mp_transport_manager.cc
+    mp_transport_manager.h
 )
 
 add_custom_command (
diff --git a/src/managers/mp_transport_manager.cc b/src/managers/mp_transport_manager.cc
new file mode 100644 (file)
index 0000000..9dc73af
--- /dev/null
@@ -0,0 +1,105 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transport_manager.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_transport_manager.h"
+
+#include <unordered_map>
+
+using namespace snort;
+
+struct MPTransportHandler
+{
+    MPTransportHandler(MPTransport* transport, const MPTransportApi* api)
+        : transport(transport), api(api) {}
+    MPTransport* transport;
+    const MPTransportApi* api;
+};
+
+static std::unordered_map<std::string, MPTransportHandler*> transports_map;
+
+void MPTransportManager::instantiate(const MPTransportApi *api, Module *mod, SnortConfig*)
+{
+    if(transports_map.find(api->base.name) != transports_map.end())
+    {
+        return;
+    }
+
+    transports_map.insert(std::make_pair(api->base.name, new MPTransportHandler(api->ctor(mod), api)));
+}
+
+MPTransport *MPTransportManager::get_transport(const std::string &name)
+{
+    auto it = transports_map.find(name);
+    if (it != transports_map.end())
+    {
+        return it->second->transport;
+    }
+    return nullptr;
+}
+
+void MPTransportManager::add_plugin(const MPTransportApi *api)
+{
+    if (api->pinit)
+    {
+        api->pinit();
+    }
+}
+
+void MPTransportManager::thread_init()
+{
+    for (auto &transport : transports_map)
+    {
+        if (transport.second->api->tinit)
+        {
+            transport.second->api->tinit(transport.second->transport);
+        }
+    }
+}
+
+void MPTransportManager::thread_term()
+{
+    for (auto &transport : transports_map)
+    {
+        if (transport.second->api->tterm)
+        {
+            transport.second->api->tterm(transport.second->transport);
+        }
+    }
+}
+
+void MPTransportManager::term()
+{
+    for (auto &transport : transports_map)
+    {
+        if (transport.second->api->dtor)
+        {
+            transport.second->api->dtor(transport.second->transport);
+        }
+        if (transport.second->api->pterm)
+        {
+            transport.second->api->pterm();
+        }
+        delete transport.second;
+    }
+    transports_map.clear();
+}
diff --git a/src/managers/mp_transport_manager.h b/src/managers/mp_transport_manager.h
new file mode 100644 (file)
index 0000000..c1e6503
--- /dev/null
@@ -0,0 +1,51 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transport_manager.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_TRANSPORT_MANAGER_H
+#define MP_TRANSPORT_MANAGER_H
+
+// Manager for multiprocess layer objects
+
+#include <string>
+#include "framework/mp_transport.h"
+
+namespace snort
+{
+class Module;
+struct SnortConfig;
+
+
+//-------------------------------------------------------------------------
+
+class MPTransportManager
+{
+public:
+    static void instantiate(const MPTransportApi* api, Module* mod, SnortConfig*);
+    static MPTransport* get_transport(const std::string& name);
+
+    static void add_plugin(const MPTransportApi* api);
+
+    static void thread_init();
+    static void thread_term();
+
+    static void term();
+};
+}
+#endif
+
index 6f0f4ec09c80aa1261817c4e50f0002c011d0a1e..e87ab9b3dba22b4ad980e332697f2bc4d25456ef 100644 (file)
@@ -42,6 +42,7 @@
 #include "ips_manager.h"
 #include "module_manager.h"
 #include "mpse_manager.h"
+#include "mp_transport_manager.h"
 #include "policy_selector_manager.h"
 #include "script_manager.h"
 #include "so_manager.h"
@@ -72,6 +73,7 @@ static Symbol symbols[PT_MAX] =
     { "logger", LOGAPI_VERSION, sizeof(LogApi) },
     { "connector", CONNECTOR_API_VERSION, sizeof(ConnectorApi) },
     { "policy_selector", POLICY_SELECTOR_API_VERSION, sizeof(PolicySelectorApi) },
+    { "mp_transport", MP_TRANSPORT_API_VERSION, sizeof(MPTransportApi) },
 };
 #else
 // this gets around the sequence issue with some compilers
@@ -88,7 +90,9 @@ static Symbol symbols[PT_MAX] =
     [PT_LOGGER] = { stringify(PT_LOGGER), LOGAPI_VERSION, sizeof(LogApi) },
     [PT_CONNECTOR] = { stringify(PT_CONNECTOR), CONNECTOR_API_VERSION, sizeof(ConnectorApi) },
     [PT_POLICY_SELECTOR] = { stringify(PT_POLICY_SELECTOR), POLICY_SELECTOR_API_VERSION,
-        sizeof(PolicySelectorApi) }
+        sizeof(PolicySelectorApi) },
+    [PT_MP_TRANSPORT] = { stringify(PT_MP_TRANSPORT), MP_TRANSPORT_API_VERSION,
+        sizeof(MPTransportApi) }
 };
 #endif
 
@@ -321,6 +325,10 @@ static void add_plugin(Plugin& p)
     case PT_POLICY_SELECTOR:
         PolicySelectorManager::add_plugin((const PolicySelectorApi*)p.api);
         break;
+    
+    case PT_MP_TRANSPORT:
+        MPTransportManager::add_plugin((const MPTransportApi*)p.api);
+        break;
 
     default:
         assert(false);
@@ -546,6 +554,10 @@ void PluginManager::instantiate(
         //IpsManager::instantiate((SoApi*)api, mod, sc);
         break;
 
+    case PT_MP_TRANSPORT:
+        MPTransportManager::instantiate((const MPTransportApi*)api, mod, sc);
+        break;
+
     case PT_LOGGER:
         EventManager::instantiate((const LogApi*)api, mod, sc);
         break;
index cecd450f16ef089c86bdb4cba74082af309d7afa..059dbf51bc2e40a1e08b3ed26c607a59dc2784f5 100644 (file)
@@ -3,3 +3,10 @@ add_cpputest(get_inspector_test
         get_inspector_stubs.h
         ../inspector_manager.cc
 )
+
+add_cpputest(mp_transport_manager_test
+    SOURCES
+        mp_transport_manager_test.cc
+        ../mp_transport_manager.cc
+        ../../framework/module.cc
+)
diff --git a/src/managers/test/mp_transport_manager_test.cc b/src/managers/test/mp_transport_manager_test.cc
new file mode 100644 (file)
index 0000000..5438ac5
--- /dev/null
@@ -0,0 +1,197 @@
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "framework/mp_data_bus.h"
+#include "framework/mp_transport.h"
+#include "framework/module.h"
+#include "main/thread_config.h"
+
+#include "../mp_transport_manager.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+#include <unordered_map>
+
+#define MODULE_NAME "mock_transport"
+#define MODULE_HELP "mock transport for testing"
+
+static int test_transport_ctor_calls = 0;
+static int test_transport_dtor_calls = 0;
+static int test_transport_pinit_calls = 0;
+static int test_transport_pterm_calls = 0;
+static int test_transport_tinit_calls = 0;
+static int test_transport_tterm_calls = 0;
+
+namespace snort
+{
+
+class MockTransport : public MPTransport
+{
+    public:
+    MockTransport() : MPTransport()
+    { }
+    virtual ~MockTransport() override
+    { }
+    virtual bool send_to_transport(MPEventInfo& event) override
+    { return true; }
+    virtual void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) override
+    { }
+    virtual void init_connection() override
+    { }
+    virtual void register_receive_handler(const TransportReceiveEventHandler& handler) override
+    { }
+    virtual void unregister_receive_handler() override
+    { }
+    virtual void thread_init() override
+    { }
+    virtual void thread_term() override
+    { }
+    virtual bool configure(const SnortConfig*) override
+    { return true; }
+    virtual void enable_logging() override
+    { }
+    virtual void disable_logging() override
+    { }
+    virtual bool is_logging_enabled() override
+    { return false; }
+};
+
+unsigned get_instance_id() { return 0; }
+unsigned ThreadConfig::get_instance_max() { return 1; }
+};
+
+
+using namespace snort;
+
+void show_stats(unsigned long*, PegInfo const*, std::vector<unsigned int, std::allocator<unsigned int> > const&, char const*, _IO_FILE*)
+{}
+void show_stats(unsigned long*, PegInfo const*, unsigned int, char const*)
+{}
+
+static void mock_transport_tinit(MPTransport* t)
+{
+    test_transport_tinit_calls++;
+}
+static void mock_transport_tterm(MPTransport* t)
+{
+    test_transport_tterm_calls++;
+}
+static MPTransport* mock_transport_ctor(Module* m)
+{   
+    test_transport_ctor_calls++;
+    return new MockTransport();
+}
+static void mock_transport_dtor(MPTransport* t)
+{
+    test_transport_dtor_calls++;
+    delete t;
+}
+
+static void mock_transport_pinit()
+{
+    // Mock plugin init
+    test_transport_pinit_calls++;
+}
+static void mock_transport_pterm()
+{
+    // Mock plugin term
+    test_transport_pterm_calls++;
+}
+
+static void clear_test_calls()
+{
+    test_transport_ctor_calls = 0;
+    test_transport_dtor_calls = 0;
+    test_transport_pinit_calls = 0;
+    test_transport_pterm_calls = 0;
+    test_transport_tinit_calls = 0;
+    test_transport_tterm_calls = 0;
+}
+
+static struct MPTransportApi mock_transport_api =
+{
+    {
+        PT_MP_TRANSPORT,
+        sizeof(MPTransportApi),
+        MP_TRANSPORT_API_VERSION,
+        2,
+        API_RESERVED,
+        API_OPTIONS,
+        MODULE_NAME,
+        MODULE_HELP,
+        nullptr,
+        nullptr
+    },
+    0,
+    mock_transport_pinit,
+    mock_transport_pterm,
+    mock_transport_tinit,
+    mock_transport_tterm,
+    mock_transport_ctor,
+    mock_transport_dtor
+};
+
+TEST_GROUP(mp_transport_manager_test_group)
+{
+    void setup() override
+    {
+        clear_test_calls();
+    }
+
+    void teardown() override
+    {
+        MPTransportManager::term();
+    }
+};
+
+TEST(mp_transport_manager_test_group, instantiate_transport_object)
+{
+    MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+    CHECK(test_transport_ctor_calls == 1);
+}
+
+TEST(mp_transport_manager_test_group, get_transport_object)
+{
+    MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+    MPTransport* transport = MPTransportManager::get_transport(MODULE_NAME);
+    CHECK(transport != nullptr);
+
+    transport = MPTransportManager::get_transport("non_existent_transport");
+    CHECK(transport == nullptr);
+
+    MPTransportManager::term();
+
+    transport = MPTransportManager::get_transport(MODULE_NAME);
+    CHECK(transport == nullptr);
+}
+
+TEST(mp_transport_manager_test_group, add_plugin)
+{
+    MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+    MPTransportManager::add_plugin(&mock_transport_api);
+    CHECK(test_transport_pinit_calls == 1);
+    
+    MPTransportManager::term();
+    CHECK(test_transport_pterm_calls == 1);
+}
+
+TEST(mp_transport_manager_test_group, thread_init_term)
+{
+    MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+    MPTransportManager::thread_init();
+    CHECK(test_transport_tinit_calls == 1);
+
+    MPTransportManager::thread_term();
+    CHECK(test_transport_tterm_calls == 1);
+}
+
+int main(int argc, char** argv)
+{
+    // Allocate object in internal MPTransportManager unordered_map to prevent false-positive memory leak detection ( buckets allocation at first insert )
+    MPTransportManager::instantiate(&mock_transport_api, nullptr, nullptr);
+    MPTransportManager::term();
+    int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
+    return return_value;
+}
\ No newline at end of file
diff --git a/src/mp_transport/CMakeLists.txt b/src/mp_transport/CMakeLists.txt
new file mode 100644 (file)
index 0000000..5c42545
--- /dev/null
@@ -0,0 +1,6 @@
+add_subdirectory(mp_unix_transport)
+
+add_library( mp_transports OBJECT
+    mp_transports.cc
+    mp_transports.h
+)
\ No newline at end of file
diff --git a/src/mp_transport/dev_notes.txt b/src/mp_transport/dev_notes.txt
new file mode 100644 (file)
index 0000000..583d8b3
--- /dev/null
@@ -0,0 +1,5 @@
+The `mp_transport` files provide the foundation for message-passing transport mechanisms, enabling communication 
+between distributed components or processes in a modular and extensible manner. These files abstract the complexities 
+of underlying communication protocols, offering a unified API for sending, receiving, and managing messages. This 
+abstraction ensures that developers can focus on higher-level application logic without worrying about protocol-specific 
+details. These transports are utilized by `mp_data_bus.h` to transfer events between distributed Snort instances.
diff --git a/src/mp_transport/mp_transports.cc b/src/mp_transport/mp_transports.cc
new file mode 100644 (file)
index 0000000..c656109
--- /dev/null
@@ -0,0 +1,36 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transports.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_transports.h"
+
+#include "framework/mp_transport.h"
+#include "managers/plugin_manager.h"
+
+using namespace snort;
+
+extern const BaseApi* mp_unix_transport[];
+
+void load_mp_transports()
+{
+    PluginManager::load_plugins(mp_unix_transport);
+}
diff --git a/src/mp_transport/mp_transports.h b/src/mp_transport/mp_transports.h
new file mode 100644 (file)
index 0000000..4416d89
--- /dev/null
@@ -0,0 +1,25 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_transports.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_TRANSPORTS_H
+#define MP_TRANSPORTS_H
+
+void load_mp_transports();
+
+#endif // MP_TRANSPORTS_H
diff --git a/src/mp_transport/mp_unix_transport/CMakeLists.txt b/src/mp_transport/mp_unix_transport/CMakeLists.txt
new file mode 100644 (file)
index 0000000..11b8026
--- /dev/null
@@ -0,0 +1,24 @@
+
+set( MP_TRANSPORT_INCLUDES
+    mp_unix_transport.h
+    mp_unix_transport_module.h
+)
+
+add_library( mp_unix_transport OBJECT
+    ${MP_TRANSPORT_INCLUDES}
+    mp_unix_transport.cc
+    mp_unix_transport_module.cc
+)
+
+install (FILES ${MIME_INCLUDES}
+    DESTINATION "${INCLUDE_INSTALL_PATH}/mp_unix_transport"
+)
+
+add_dependencies( mp_unix_transport framework )
+add_dependencies( mp_unix_transport unixdomain_connector )
+
+find_package(Threads REQUIRED)
+target_link_libraries(mp_unix_transport PRIVATE Threads::Threads)
+target_link_libraries(mp_unix_transport PRIVATE $<TARGET_OBJECTS:unixdomain_connector>)
+
+add_subdirectory( test )
\ No newline at end of file
diff --git a/src/mp_transport/mp_unix_transport/dev_notes.txt b/src/mp_transport/mp_unix_transport/dev_notes.txt
new file mode 100644 (file)
index 0000000..9f4af4d
--- /dev/null
@@ -0,0 +1,13 @@
+The MP Unix Domain Transport provides an implementation of the Multi-Process (MP) Transport
+interface using existing `UnixDomainConnector` infrastructure. This transport enables
+inter-process communication (IPC) between multiple Snort instances running on the same host,
+allowing them to exchange events and data.
+
+ * MPUnixDomainTransport - Main class implementing the MPTransport interface with Unix domain socket functionality
+ * MPUnixDomainTransportModule - Module class that handles configuration parameters and provides an API for MPUnixDomainTransport creation
+ * MPUnixDomainTransportConfig - Configuration structure for socket paths and connection parameters
+
+Connention between snort establishes in next sequence:
+* First Snort instance acts as a server socket accepting connections from other instances
+* Additional instances connect to existing socket paths
+* Dynamic re-connection handling with configurable retry parameters
diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport.cc
new file mode 100644 (file)
index 0000000..6b15aca
--- /dev/null
@@ -0,0 +1,447 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_unix_transport.h"
+
+#include <cstring>
+#include <fcntl.h>
+#include <iostream>
+#include <poll.h>
+#include <sys/socket.h>
+#include <sys/un.h>
+#include <unistd.h>
+
+#include "framework/mp_data_bus.h"
+#include "log/messages.h"
+#include "main/snort.h"
+#include "main/snort_config.h"
+
+static std::mutex _receive_mutex;
+static std::mutex _update_connectors_mutex;
+
+#define UNIX_SOCKET_NAME_PREFIX "/snort_unix_connector_"
+
+#define MP_TRANSPORT_LOG_LABEL "MPUnixTransport"
+
+#define MP_TRANSPORT_LOG(msg, ...) do { \
+        if (!this->is_logging_enabled_flag) \
+            break; \
+        LogMessage(msg, __VA_ARGS__); \
+    } while (0)
+
+namespace snort
+{
+
+#pragma pack(push, 1) 
+    enum MPTransportMessageType 
+    { 
+        EVENT_MESSAGE = 0,
+        MAX_TYPE
+    };
+
+struct MPTransportMessageHeader
+{
+    MPTransportMessageType type;// Type of the message 
+    int32_t pub_id;             // Identifier for the module sending or receiving the message 
+    int32_t event_id;           // Identifier for the specific event 
+    uint16_t data_length;       // Length of the data payload 
+};
+
+struct MPTransportMessage
+{
+    MPTransportMessageHeader header; // Header containing metadata about the message
+    char* data;                      // Placeholder for the actual data payload 
+}; 
+#pragma pack(pop) 
+
+void MPUnixDomainTransport::side_channel_receive_handler(SCMessage* msg)
+{
+    if (transport_receive_handler and msg)
+    {
+        if (msg->content_length < sizeof(MPTransportMessage))
+        {
+            MP_TRANSPORT_LOG("%s: Incomplete message received\n", MP_TRANSPORT_LOG_LABEL);
+            return;
+        }
+
+        MPTransportMessageHeader* transport_message_header = (MPTransportMessageHeader*)msg->content;
+        
+        if (transport_message_header->type >= MAX_TYPE)
+        {
+            MP_TRANSPORT_LOG("%s: Invalid message type received\n", MP_TRANSPORT_LOG_LABEL);
+            return;
+        }
+
+        auto deserialize_func = get_event_deserialization_function(transport_message_header->pub_id, transport_message_header->event_id);
+        if (!deserialize_func)
+        {
+            MP_TRANSPORT_LOG("%s: No deserialization function found for event: type %d, id %d\n", MP_TRANSPORT_LOG_LABEL, transport_message_header->type, transport_message_header->event_id);
+            return;
+        }
+
+        DataEvent* internal_event = nullptr;
+        (deserialize_func)((const char*)(msg->content + sizeof(MPTransportMessageHeader)), transport_message_header->data_length, internal_event);
+        MPEventInfo event(internal_event, transport_message_header->event_id, transport_message_header->pub_id);
+
+        (transport_receive_handler)(event);
+
+        delete internal_event;
+    }
+    delete msg;
+}
+
+void MPUnixDomainTransport::handle_new_connection(UnixDomainConnector *connector, UnixDomainConnectorConfig* cfg)
+{
+    assert(connector);
+    assert(cfg);
+
+    std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+
+    auto side_channel = new SideChannel(ScMsgFormat::BINARY);
+    side_channel->connector_receive = connector;
+    side_channel->connector_transmit = side_channel->connector_receive;
+    side_channel->register_receive_handler(std::bind(&MPUnixDomainTransport::side_channel_receive_handler, this, std::placeholders::_1));
+    connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
+    this->side_channels.push_back(new SideChannelHandle(side_channel, cfg));
+    connector->set_update_handler(std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+}
+
+MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig *c) : MPTransport(), 
+    config(c)
+{
+    this->is_logging_enabled_flag = c->enable_logging;
+}
+
+MPUnixDomainTransport::~MPUnixDomainTransport()
+{
+    cleanup();
+}
+
+bool MPUnixDomainTransport::send_to_transport(MPEventInfo &event)
+{
+    auto serialize_func = get_event_serialization_function(event.pub_id, event.type);
+
+    if (!serialize_func)
+    {
+        MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event.type);
+        return false;
+    }
+
+    MPTransportMessage transport_message;
+    transport_message.header.type = EVENT_MESSAGE;
+    transport_message.header.pub_id = event.pub_id;
+    transport_message.header.event_id = event.type;
+
+    
+    (serialize_func)(event.event, transport_message.data, &transport_message.header.data_length);
+    for (auto &&sc_handler : this->side_channels)
+    {
+        auto msg = sc_handler->side_channel->alloc_transmit_message(sizeof(MPTransportMessageHeader) + transport_message.header.data_length);
+        memcpy(msg->content, &transport_message, sizeof(MPTransportMessageHeader));
+        memcpy(msg->content + sizeof(MPTransportMessageHeader), transport_message.data, transport_message.header.data_length);
+        auto send_result = sc_handler->side_channel->transmit_message(msg);
+        if (!send_result)
+        {
+            MP_TRANSPORT_LOG("%s: Failed to send message to side channel\n", MP_TRANSPORT_LOG_LABEL);
+        }
+    }
+
+    delete[] transport_message.data;
+    
+    return true;
+}
+
+void MPUnixDomainTransport::register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper)
+{
+    assert(helper.deserializer);
+    assert(helper.serializer);
+    
+    this->event_helpers[pub_id] = SerializeFunctionHandle();
+    this->event_helpers[pub_id].serialize_functions.insert({event_id, helper});
+}
+
+void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler& handler)
+{
+    this->transport_receive_handler = handler;
+}
+
+void MPUnixDomainTransport::unregister_receive_handler()
+{
+    this->transport_receive_handler = nullptr;
+}
+
+void MPUnixDomainTransport::process_messages_from_side_channels()
+{
+    std::unique_lock<std::mutex> lock(_receive_mutex);
+    do
+    {
+        if ( (std::cv_status::timeout == this->consume_thread_cv.wait_for(lock, std::chrono::milliseconds(config->consume_message_timeout_milliseconds)) )
+            and this->consume_message_received == false )
+        {
+            continue;
+        }
+
+        {
+            std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+            bool messages_left;
+
+            do
+            {
+                messages_left = false;
+                for (auto &&sc_handler : this->side_channels)
+                {
+                    messages_left |= sc_handler->side_channel->process(config->consume_message_batch_size);
+                }
+            } while (messages_left);
+        }
+
+        this->consume_message_received = false;
+
+    } while (this->is_running);
+}
+
+void MPUnixDomainTransport::notify_process_thread()
+{
+    this->consume_thread_cv.notify_all();
+    this->consume_message_received = true;
+}
+
+void MPUnixDomainTransport::connector_update_handler(UnixDomainConnector *connector, bool is_recconecting, SideChannel *side_channel)
+{
+    std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+    if (side_channel->connector_receive)
+    {
+        delete side_channel->connector_receive;
+        side_channel->connector_receive = side_channel->connector_transmit = nullptr;
+    }
+
+    if (connector)
+    {
+        side_channel->connector_receive = side_channel->connector_transmit = connector;
+    }
+    else
+    {
+        if (is_recconecting == false)
+        {
+            MP_TRANSPORT_LOG("%s: Accepted connection interrupted, removing handle\n", MP_TRANSPORT_LOG_LABEL);
+            for(auto it = this->side_channels.begin(); it != this->side_channels.end(); ++it)
+            {
+                if ((*it)->side_channel == side_channel)
+                {
+                    delete *it;
+                    this->side_channels.erase(it);
+                    break;
+                }
+            }
+        }
+    }
+}
+
+MPSerializeFunc MPUnixDomainTransport::get_event_serialization_function(unsigned pub_id, unsigned event_id)
+{
+    auto helper_it = this->event_helpers.find(pub_id);
+    if (helper_it == this->event_helpers.end())
+    {
+        MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id);
+        return nullptr;
+    }
+    auto helper_functions = helper_it->second.get_function_set(event_id);
+    if (!helper_functions)
+    {
+        MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id);
+        return nullptr;
+    }
+    return helper_functions->serializer;
+}
+
+MPDeserializeFunc MPUnixDomainTransport::get_event_deserialization_function(unsigned pub_id, unsigned event_id)
+{
+    auto helper_it = this->event_helpers.find(pub_id);
+    if (helper_it == this->event_helpers.end())
+    {
+        MP_TRANSPORT_LOG("%s: No available helper functions is registered for %d\n", MP_TRANSPORT_LOG_LABEL, pub_id);
+        return nullptr;
+    }
+    auto helper_functions = helper_it->second.get_function_set(event_id);
+    if (!helper_functions)
+    {
+        MP_TRANSPORT_LOG("%s: No serialize function found for event %d\n", MP_TRANSPORT_LOG_LABEL, event_id);
+        return nullptr;
+    }
+    return helper_functions->deserializer;
+}
+
+void MPUnixDomainTransport::init_connection()
+{
+    init_side_channels();
+}
+
+void MPUnixDomainTransport::thread_init()
+{
+}
+
+void MPUnixDomainTransport::thread_term()
+{
+}
+
+bool MPUnixDomainTransport::configure(const SnortConfig *c)
+{
+    config->max_processes = c->max_procs;
+    return true;
+}
+
+void MPUnixDomainTransport::cleanup()
+{
+    this->is_running = false;
+    MPUnixDomainTransport::unregister_receive_handler();
+    if (this->consume_thread)
+    {
+        this->consume_thread_cv.notify_all();
+        this->consume_thread->join();
+        delete this->consume_thread;
+        this->consume_thread = nullptr;
+    }
+    cleanup_side_channels();
+    for (auto &&ac_handler : this->accept_handlers)
+    {
+        ac_handler->listener->stop_accepting_connections();
+        delete ac_handler->listener;
+        delete ac_handler->connector_config;
+        delete ac_handler;
+    }
+    this->accept_handlers.clear();
+}
+
+void MPUnixDomainTransport::init_side_channels()
+{
+    assert(config);
+    if (config->max_processes < 2)
+        return;
+
+    auto instance_id = Snort::get_process_id();//Snort instance id
+    auto max_processes = config->max_processes;
+
+    this->is_running = true;
+
+    for (ushort i = instance_id; i < max_processes; i++)
+    {
+        auto listen_path = config->unix_domain_socket_path + UNIX_SOCKET_NAME_PREFIX + std::to_string(i);
+        auto unix_listener = new UnixDomainConnectorListener(listen_path.c_str());
+        
+        UnixDomainConnectorConfig* unix_config = new UnixDomainConnectorConfig();
+        unix_config->setup = UnixDomainConnectorConfig::Setup::ANSWER;
+        unix_config->async_receive = true;
+        if (config->conn_retries)
+        {
+            unix_config->conn_retries = config->conn_retries;
+            unix_config->retry_interval = config->retry_interval_seconds;
+            unix_config->max_retries = config->max_retries;
+            unix_config->connect_timeout_seconds = config->connect_timeout_seconds;
+        }
+        else
+        {
+            unix_config->conn_retries = false;
+            unix_config->retry_interval = 0;
+            unix_config->max_retries = 0;
+            unix_config->connect_timeout_seconds = 0;
+        }
+        unix_config->paths.push_back(listen_path);
+
+        unix_listener->start_accepting_connections( std::bind(&MPUnixDomainTransport::handle_new_connection, this, std::placeholders::_1, std::placeholders::_2), unix_config);
+        
+        auto unix_listener_handle = new UnixAcceptorHandle();
+        unix_listener_handle->connector_config = unix_config;
+        unix_listener_handle->listener = unix_listener;
+        this->accept_handlers.push_back(unix_listener_handle);
+    }
+
+    for (ushort i = 1; i < instance_id; i++)
+    {
+        auto side_channel = new SideChannel(ScMsgFormat::BINARY);
+        side_channel->register_receive_handler([this](SCMessage* msg) { this->side_channel_receive_handler(msg); });
+
+        auto send_path = config->unix_domain_socket_path + "/" + "snort_unix_connector_" + std::to_string(i);
+
+        UnixDomainConnectorConfig* connector_conf = new UnixDomainConnectorConfig();
+        connector_conf->setup = UnixDomainConnectorConfig::Setup::CALL;
+        connector_conf->async_receive = true;
+        connector_conf->conn_retries = config->conn_retries;
+        connector_conf->retry_interval = config->retry_interval_seconds;
+        connector_conf->max_retries = config->max_retries;
+        connector_conf->connect_timeout_seconds = config->connect_timeout_seconds;
+        connector_conf->paths.push_back(send_path);
+
+        auto connector = unixdomain_connector_tinit_call(*connector_conf, send_path.c_str(), 0, std::bind(&MPUnixDomainTransport::connector_update_handler, this, std::placeholders::_1, std::placeholders::_2, side_channel));
+
+        if (connector)
+            connector->set_message_received_handler(std::bind(&MPUnixDomainTransport::notify_process_thread, this));
+
+        side_channel->connector_receive = connector;
+        side_channel->connector_transmit = side_channel->connector_receive;
+        this->side_channels.push_back( new SideChannelHandle(side_channel, connector_conf));
+    }
+
+    this->consume_thread = new std::thread(&MPUnixDomainTransport::process_messages_from_side_channels, this);
+}
+void MPUnixDomainTransport::cleanup_side_channels()
+{
+    std::lock_guard<std::mutex> guard(_update_connectors_mutex);
+
+    for (uint i = 0; i < this->side_channels.size(); i++)
+    {
+        auto side_channel = this->side_channels[i];
+        delete side_channel;
+    }
+
+    this->side_channels.clear();
+}
+
+SideChannelHandle::~SideChannelHandle()
+{
+    if (side_channel)
+    {
+        if (side_channel->connector_receive)
+            delete side_channel->connector_receive;
+
+        delete side_channel;
+    }
+    
+    if (connector_config)
+        delete connector_config;
+}
+void MPUnixDomainTransport::enable_logging()
+{
+    this->is_logging_enabled_flag = true;
+}
+
+void MPUnixDomainTransport::disable_logging()
+{
+    this->is_logging_enabled_flag = false;
+}
+
+bool MPUnixDomainTransport::is_logging_enabled()
+{
+    return this->is_logging_enabled_flag;
+}
+
+};
diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport.h b/src/mp_transport/mp_unix_transport/mp_unix_transport.h
new file mode 100644 (file)
index 0000000..8208e03
--- /dev/null
@@ -0,0 +1,132 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef UNIX_TRANSPORT_H
+#define UNIX_TRANSPORT_H
+
+#include "connectors/unixdomain_connector/unixdomain_connector.h"
+#include "framework/mp_data_bus.h"
+#include "main/snort_types.h"
+#include "side_channel/side_channel.h"
+
+#include <atomic>
+#include <thread>
+#include <condition_variable>
+
+namespace snort
+{
+
+struct MPUnixDomainTransportConfig
+{
+    std::string unix_domain_socket_path;
+    uint16_t max_processes = 0;
+    bool conn_retries = true;
+    bool enable_logging = false;
+    uint32_t retry_interval_seconds = 5;
+    uint32_t max_retries = 5;
+    uint32_t connect_timeout_seconds = 30;
+    uint32_t consume_message_timeout_milliseconds = 100;
+    uint32_t consume_message_batch_size = 5;
+};
+
+struct SerializeFunctionHandle
+{
+    std::unordered_map<unsigned, MPHelperFunctions> serialize_functions;
+
+    MPHelperFunctions* get_function_set(unsigned event_id)
+    {
+        auto it = serialize_functions.find(event_id);
+        if(it == serialize_functions.end())
+            return nullptr;
+        return &it->second;
+    }
+};
+
+struct SideChannelHandle
+{
+    SideChannelHandle(SideChannel* sc, UnixDomainConnectorConfig* cc) :
+        side_channel(sc), connector_config(cc)
+    { }
+
+    ~SideChannelHandle();
+
+    SideChannel* side_channel;
+    UnixDomainConnectorConfig* connector_config;
+};
+
+struct UnixAcceptorHandle
+{
+    UnixDomainConnectorConfig* connector_config = nullptr;
+    UnixDomainConnectorListener* listener = nullptr;
+};
+
+class MPUnixDomainTransport : public MPTransport
+{
+    public:
+
+    MPUnixDomainTransport(MPUnixDomainTransportConfig* c);
+    ~MPUnixDomainTransport() override;
+
+    bool configure(const SnortConfig*) override;
+    void thread_init() override;
+    void thread_term() override;
+    void init_connection() override;
+    bool send_to_transport(MPEventInfo& event) override;
+    void register_event_helpers(const unsigned& pub_id, const unsigned& event_id, MPHelperFunctions& helper) override;
+    void register_receive_handler(const TransportReceiveEventHandler& handler) override;
+    void unregister_receive_handler() override;
+    void enable_logging() override;
+    void disable_logging() override;
+    bool is_logging_enabled() override;
+    void cleanup();
+
+    MPUnixDomainTransportConfig* get_config()
+    { return config; }
+
+
+    private:
+
+    void init_side_channels();
+    void cleanup_side_channels();
+    void side_channel_receive_handler(SCMessage* msg);
+    void handle_new_connection(UnixDomainConnector* connector, UnixDomainConnectorConfig* cfg);
+    void process_messages_from_side_channels();
+    void notify_process_thread();
+    void connector_update_handler(UnixDomainConnector* connector, bool is_recconecting, SideChannel* side_channel);
+
+    MPSerializeFunc get_event_serialization_function(unsigned pub_id, unsigned event_id);
+    MPDeserializeFunc get_event_deserialization_function(unsigned pub_id, unsigned event_id);
+
+    TransportReceiveEventHandler transport_receive_handler = nullptr;
+    MPUnixDomainTransportConfig* config = nullptr;
+
+    std::vector<SideChannelHandle*> side_channels;
+    std::vector<UnixAcceptorHandle*> accept_handlers;
+    std::unordered_map<unsigned, SerializeFunctionHandle> event_helpers;
+
+    std::atomic<bool> is_running = false;
+    std::atomic<bool> is_logging_enabled_flag;
+    std::atomic<bool> consume_message_received = false;
+
+    std::thread* consume_thread = nullptr;
+    std::condition_variable consume_thread_cv;
+};
+
+}
+#endif
diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc
new file mode 100644 (file)
index 0000000..b0c8c5f
--- /dev/null
@@ -0,0 +1,130 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport_module.cc author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "mp_unix_transport_module.h"
+
+#include "main/snort_config.h"
+#include "log/messages.h"
+
+#define DEFAULT_UNIX_DOMAIN_SOCKET_PATH "/tmp/snort_unix_connectors"
+
+using namespace snort;
+
+static const Parameter unix_transport_params[] =
+{
+    { "unix_domain_socket_path" , Parameter::PT_STRING, nullptr, DEFAULT_UNIX_DOMAIN_SOCKET_PATH, "unix socket folder" },
+    { "max_connect_retries", Parameter::PT_INT, nullptr, "5", "max connection retries" },
+    { "retry_interval_seconds", Parameter::PT_INT, nullptr, "30", "retry interval in seconds" },
+    { "connect_timeout_seconds", Parameter::PT_INT, nullptr, "30", "connect timeout in seconds" },
+    { "consume_message_timeout_milliseconds", Parameter::PT_INT, nullptr, "100", "consume message timeout in milliseconds" },
+    { "consume_message_batch_size", Parameter::PT_INT, nullptr, "5", "consume message batch size" },
+    { "enable_logging", Parameter::PT_BOOL, nullptr, "false", "enable logging" },
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
+MPUnixDomainTransportModule::MPUnixDomainTransportModule(): Module(MODULE_NAME, MODULE_HELP, unix_transport_params)
+{ 
+    config = nullptr;
+}
+
+bool MPUnixDomainTransportModule::begin(const char *, int, SnortConfig *sc)
+{
+    assert(sc);
+    assert(!config);
+    config = new MPUnixDomainTransportConfig;
+    config->max_processes = sc->max_procs;
+    return true;
+}
+
+bool MPUnixDomainTransportModule::set(const char *, Value & v, SnortConfig *)
+{
+    if (v.is("unix_domain_socket_path"))
+    {
+        config->unix_domain_socket_path = v.get_string();
+    }
+    else if (v.is("max_connect_retries"))
+    {
+        config->conn_retries = true;
+        config->max_retries = v.get_int32();
+    }
+    else if (v.is("retry_interval_seconds"))
+    {
+        config->retry_interval_seconds = v.get_int32();    
+    }
+    else if (v.is("connect_timeout_seconds"))
+    {
+        config->connect_timeout_seconds = v.get_int32();
+    }
+    else if (v.is("consume_message_timeout_milliseconds"))
+    {
+        config->consume_message_timeout_milliseconds = v.get_int32();
+    }
+    else if (v.is("consume_message_batch_size"))
+    {
+        config->consume_message_batch_size = v.get_int32();
+    }
+    else if (v.is("enable_logging"))
+    {
+        config->enable_logging = v.get_bool();
+    }
+    else
+    {
+        WarningMessage("MPUnixDomainTransportModule: received unrecognized parameter %s\n", v.get_as_string().c_str());
+        return false;
+    }
+
+    return true;
+}
+
+static struct MPTransportApi mp_unixdomain_transport_api =
+{
+    {
+        PT_MP_TRANSPORT,
+        sizeof(MPTransportApi),
+        MP_TRANSPORT_API_VERSION,
+        2,
+        API_RESERVED,
+        API_OPTIONS,
+        MODULE_NAME,
+        MODULE_HELP,
+        mod_ctor,
+        mod_dtor
+    },
+    0,
+    nullptr,
+    nullptr,
+    nullptr,
+    nullptr,
+    mp_unixdomain_transport_ctor,
+    mp_unixdomain_transport_dtor
+};
+
+#ifdef BUILDING_SO
+SO_PUBLIC const BaseApi* snort_plugins[] =
+#else
+const BaseApi* mp_unix_transport[] =
+#endif
+{
+    &mp_unixdomain_transport_api.base,
+    nullptr
+};
diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.h
new file mode 100644 (file)
index 0000000..464b4f3
--- /dev/null
@@ -0,0 +1,74 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// mp_unix_transport_module.h author Oleksandr Stepanov <ostepano@cisco.com>
+
+#ifndef MP_UNIX_TRANSPORT_MODULE_H
+#define MP_UNIX_TRANSPORT_MODULE_H
+
+#define MODULE_NAME "unix_transport"
+#define MODULE_HELP "manage the unix transport layer"
+
+#include "framework/module.h"
+#include "framework/mp_transport.h"
+#include "mp_unix_transport.h"
+
+namespace snort
+{
+
+class MPUnixDomainTransportModule : public Module
+{
+    public:
+
+    MPUnixDomainTransportModule();
+
+    ~MPUnixDomainTransportModule() override
+    { delete config; }
+
+    bool begin(const char*, int, SnortConfig*) override;
+    bool set(const char*, Value&, SnortConfig*) override;
+
+    Usage get_usage() const override
+    { return GLOBAL; }
+
+    MPUnixDomainTransportConfig* config;
+};
+
+static Module* mod_ctor()
+{
+    return new MPUnixDomainTransportModule;
+}
+
+static void mod_dtor(Module* m)
+{
+    delete m;
+}
+
+static MPTransport* mp_unixdomain_transport_ctor(Module* m)
+{
+    auto unix_tr_mod = (MPUnixDomainTransportModule*)m;
+    return new MPUnixDomainTransport(unix_tr_mod->config);
+}
+
+static void mp_unixdomain_transport_dtor(MPTransport* t)
+{
+    delete t;
+}
+
+}
+
+#endif
diff --git a/src/mp_transport/mp_unix_transport/test/CMakeLists.txt b/src/mp_transport/mp_unix_transport/test/CMakeLists.txt
new file mode 100644 (file)
index 0000000..4d58bb8
--- /dev/null
@@ -0,0 +1,24 @@
+
+add_cpputest( unix_transport_test
+    SOURCES
+        ../mp_unix_transport.cc
+        ../../../side_channel/side_channel.cc
+        ../../../side_channel/side_channel_format.cc
+        ../../../framework/module.cc
+        ../../../managers/connector_manager.cc
+        $<TARGET_OBJECTS:catch_tests>
+    LIBS
+        ${CMAKE_THREAD_LIBS_INIT}
+)
+
+ add_cpputest( unix_transport_module_test
+    SOURCES
+        ../mp_unix_transport_module.cc
+        ../../../framework/value.cc
+        ../../../sfip/sf_ip.cc
+        $<TARGET_OBJECTS:catch_tests>
+        ../../../framework/module.cc
+    LIBS
+        ${CMAKE_THREAD_LIBS_INIT}
+ )
diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_module_test.cc
new file mode 100644 (file)
index 0000000..c00ff41
--- /dev/null
@@ -0,0 +1,209 @@
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "../mp_unix_transport_module.h"
+
+#include "framework/value.h"
+#include "main/snort_config.h"
+#include "main/snort.h"
+#include "main/thread_config.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+
+static int warning_cnt = 0;
+static int destroy_cnt = 0;
+
+namespace snort
+{
+    void WarningMessage(const char*,...) { warning_cnt++; }
+
+    SnortConfig::SnortConfig(snort::SnortConfig const*, char const*)
+    {
+        max_procs = 2;
+    }
+    SnortConfig::~SnortConfig()
+    {
+
+    }
+
+    unsigned ThreadConfig::get_instance_max()
+    {
+        return 1;
+    }
+    unsigned Snort::get_process_id()
+    {
+        return 1;
+    }
+    unsigned get_instance_id()
+    {
+        return 1;
+    }
+
+    MPUnixDomainTransport::MPUnixDomainTransport(MPUnixDomainTransportConfig* config) :
+        MPTransport()
+    {
+        this->config = config;
+    }
+    MPUnixDomainTransport::~MPUnixDomainTransport()
+    { destroy_cnt++; }
+    void MPUnixDomainTransport::thread_init()
+    {}
+    void MPUnixDomainTransport::thread_term()
+    {}
+    void MPUnixDomainTransport::init_connection()
+    {}
+    void MPUnixDomainTransport::cleanup()
+    {}
+    void MPUnixDomainTransport::register_event_helpers(const unsigned&, const unsigned&, MPHelperFunctions&)
+    {}
+    bool MPUnixDomainTransport::send_to_transport(MPEventInfo&)
+    { return true; }
+    void MPUnixDomainTransport::unregister_receive_handler()
+    { }
+    void MPUnixDomainTransport::register_receive_handler(const TransportReceiveEventHandler&)
+    {}
+    bool MPUnixDomainTransport::configure(const SnortConfig*)
+    { return true; }
+    bool MPUnixDomainTransport::is_logging_enabled()
+    { return false; }
+    void MPUnixDomainTransport::enable_logging()
+    {}
+    void MPUnixDomainTransport::disable_logging()
+    {}
+
+    char* snort_strdup(const char*)
+    {
+        return nullptr;
+    }
+};
+
+void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
+void show_stats(PegCount*, const PegInfo*, const std::vector<unsigned>&, const char*, FILE*) { }
+
+using namespace snort;
+
+MPUnixDomainTransportModule* mod = nullptr;
+
+TEST_GROUP(MPUnixDomainTransportModuleTests)
+{
+    void setup() override
+    {
+        mod = (MPUnixDomainTransportModule*)mod_ctor();
+    }
+
+    void teardown() override
+    {
+        mod_dtor(mod);
+    }
+};
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigBegin)
+{
+    SnortConfig sc;
+    auto res = mod->begin("test", 0, &sc);
+    CHECK(res == true);
+    CHECK(mod->config != nullptr);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigEnd)
+{
+    auto res = mod->end("test", 0, nullptr);
+    CHECK(res == true);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigSet)
+{
+    Parameter p{"unix_domain_socket_path", Parameter::PT_STRING, nullptr, "test_value", nullptr};
+    Parameter p2{"max_connect_retries", Parameter::PT_INT, nullptr, "15", nullptr};
+    Parameter p3{"retry_interval_seconds", Parameter::PT_INT, nullptr, "33", nullptr};
+    Parameter p4{"connect_timeout_seconds", Parameter::PT_INT, nullptr, "32", nullptr};
+    Parameter p5{"consume_message_timeout_milliseconds", Parameter::PT_INT, nullptr, "200", nullptr};
+    Parameter p6{"consume_message_batch_size", Parameter::PT_INT, nullptr, "20", nullptr};
+    Parameter p7{"enable_logging", Parameter::PT_BOOL, nullptr, "true", nullptr};
+    Value v("test_value");
+    v.set(&p);
+
+    SnortConfig sc;
+    mod->begin("test", 0, &sc);
+    auto res = mod->set(nullptr, v, nullptr);
+    
+    CHECK(res == true);
+    CHECK(strcmp("test_value", mod->config->unix_domain_socket_path.c_str()) == 0);
+
+    v.set((double)15);
+    v.set(&p2);
+    res = mod->set(nullptr, v, nullptr);
+    CHECK(res == true);
+    CHECK(mod->config->max_retries == 15);
+
+    v.set((double)33);
+    v.set(&p3);
+    res = mod->set(nullptr, v, nullptr);
+    CHECK(res == true);
+    CHECK(mod->config->retry_interval_seconds == 33);
+
+    v.set((double)32);
+    v.set(&p4);
+    res = mod->set(nullptr, v, nullptr);
+    CHECK(res == true);
+    CHECK(mod->config->connect_timeout_seconds == 32);
+
+    v.set((double)200);
+    v.set(&p5);
+    res = mod->set(nullptr, v, nullptr);
+    CHECK(res == true);
+    CHECK(mod->config->consume_message_timeout_milliseconds == 200);
+
+    v.set((double)20);
+    v.set(&p6);
+    res = mod->set(nullptr, v, nullptr);
+    CHECK(res == true);
+    CHECK(mod->config->consume_message_batch_size == 20);
+
+    v.set(true);
+    v.set(&p7);
+    res = mod->set(nullptr, v, nullptr);
+    CHECK(res == true);
+    CHECK(mod->config->enable_logging == true);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleConfigUnknownSet)
+{
+    warning_cnt = 0;
+    
+    Parameter p{"unknown_value", Parameter::PT_STRING, nullptr, "/tmp/unx_dmn_sck", nullptr};
+    Value v("/tmp/unx_dmn_sck");
+    v.set(&p);
+
+    SnortConfig sc;
+    mod->begin("test", 0, &sc);
+    auto res = mod->set(nullptr, v, nullptr);
+    
+    CHECK(res == false);
+    CHECK(1 == warning_cnt);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleGetUsage)
+{
+    auto res = mod->get_usage();
+    CHECK(res == Module::Usage::GLOBAL);
+}
+
+TEST(MPUnixDomainTransportModuleTests, MPUnixDomainTransportModuleCreateDestroyTransport)
+{
+    destroy_cnt = 0;
+    auto transport = mp_unixdomain_transport_ctor(mod);
+    CHECK(transport != nullptr);
+
+    mp_unixdomain_transport_dtor(transport);
+    CHECK(destroy_cnt == 1);
+}
+
+int main(int argc, char** argv)
+{
+    int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
+    return return_value;
+}
\ No newline at end of file
diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc
new file mode 100644 (file)
index 0000000..d611e15
--- /dev/null
@@ -0,0 +1,505 @@
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "framework/mp_transport.h"
+#include "../mp_unix_transport.h"
+#include "framework/counts.h"
+#include "framework/mp_data_bus.h"
+#include "main/snort.h"
+#include "main/thread_config.h"
+#include "main/snort_config.h"
+#include "connectors/unixdomain_connector/unixdomain_connector.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+#include <mutex>
+#include <condition_variable>
+#include <vector>
+#include <iostream>
+
+static int snort_instance_id = 0;
+
+static int accept_cnt = 0;
+
+static int test_socket_calls = 0;
+static int test_bind_calls = 0;
+static int test_listen_calls = 0;
+static int test_accept_calls = 0;
+static int test_close_calls = 0;
+static int test_connect_calls = 0;
+static int test_call_sock_created = 0;
+static int test_serialize_calls = 0;
+static int test_deserialize_calls = 0;
+
+int accept (int, struct sockaddr*, socklen_t*)
+{
+    test_accept_calls++;
+    return accept_cnt--;
+}
+
+int close (int)
+{ 
+    test_close_calls++;
+    return 0;
+}
+
+void clear_test_calls()
+{
+    test_socket_calls = 0;
+    test_bind_calls = 0;
+    test_listen_calls = 0;
+    test_accept_calls = 0;
+    test_close_calls = 0;
+    test_connect_calls = 0;
+    test_call_sock_created = 0;
+    test_serialize_calls = 0;
+    test_deserialize_calls = 0;
+}
+
+namespace snort
+{
+    void ErrorMessage(const char*,...) { }
+    void WarningMessage(const char*,...) { }
+    void LogMessage(const char* s, ...) { }
+    void LogText(const char*, FILE*) {}
+    void ParseError(const char*, ...) { }
+
+    unsigned ThreadConfig::get_instance_max()
+    {
+        return 2; // Mock value for testing
+    }
+
+    SnortConfig::SnortConfig(snort::SnortConfig const*, char const*)
+    {
+        max_procs = 2;
+    }
+    SnortConfig::~SnortConfig()
+    {
+
+    }
+    
+    unsigned Snort::get_process_id()
+    {
+        return snort_instance_id;
+    }
+
+    unsigned get_instance_id()
+    {
+        return snort_instance_id;
+    }
+};
+static int test_send_calls = 0;
+UnixDomainConnector* listen_connector = nullptr;
+UnixDomainConnector* call_connector = nullptr;
+
+void UnixDomainConnector::set_message_received_handler(std::function<void ()> h)
+{
+    message_received_handler = h;
+}
+
+static bool expect_update_change = false;
+std::function<void (UnixDomainConnector*,bool)> test_update_handler = nullptr;
+
+void UnixDomainConnector::set_update_handler(std::function<void (UnixDomainConnector*,bool)> h)
+{
+    if(expect_update_change)
+        test_update_handler = h;
+}
+
+static snort::ConnectorMsg* test_msg_answer = nullptr;
+static snort::ConnectorMsg* test_msg_call = nullptr;
+static uint8_t* test_msg_call_data = nullptr;
+static uint8_t* test_msg_answer_data = nullptr;
+UnixDomainConnectorListener::UnixDomainConnectorListener(char const*) // cppcheck-suppress uninitMemberVar
+{}
+UnixDomainConnectorListener::~UnixDomainConnectorListener()
+{
+}
+void UnixDomainConnectorListener::stop_accepting_connections()
+{
+    close(0);
+}
+void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnectorAcceptHandler h, UnixDomainConnectorConfig* cfg)
+{
+    socket(0,0,0);
+    while(accept_cnt > 0)
+    {
+        accept(0, nullptr, nullptr);
+        auto cfg_copy = new UnixDomainConnectorConfig(*cfg);
+        h(new UnixDomainConnector(*cfg_copy, 0, 0), cfg_copy);
+    }
+}
+
+bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg& m, const ID&)
+{
+    test_send_calls++;
+    
+    if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+    {
+        test_msg_call_data = new uint8_t[m.get_length()];
+        memcpy(test_msg_call_data, m.get_data(), m.get_length());
+        test_msg_call = new snort::ConnectorMsg(test_msg_call_data, m.get_length());
+        if(!call_connector)
+            call_connector = this;
+        listen_connector->process_receive();
+    }
+    else
+    {
+        test_msg_answer_data = new uint8_t[m.get_length()];
+        memcpy(test_msg_answer_data, m.get_data(), m.get_length());
+        test_msg_answer = new snort::ConnectorMsg(test_msg_answer_data, m.get_length());
+        if(!listen_connector)
+            listen_connector = this;
+        call_connector->process_receive();
+    }
+    
+    return true;
+}
+void UnixDomainConnector::process_receive()
+{
+    if (message_received_handler)
+    {
+        message_received_handler();
+    }
+}
+bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg&&, const ID&)
+{ return true; }
+snort::ConnectorMsg UnixDomainConnector::receive_message(bool)
+{
+    if (cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+    {
+        if (test_msg_answer)
+        {
+            snort::ConnectorMsg msg(test_msg_answer_data, test_msg_answer->get_length());
+            delete test_msg_answer;
+            test_msg_answer = nullptr;
+            return std::move(msg); // cppcheck-suppress returnStdMoveLocal
+        }
+    }
+    else
+    {
+        if (test_msg_call)
+        {
+            snort::ConnectorMsg msg(test_msg_call_data, test_msg_call->get_length());
+            delete test_msg_call;
+            test_msg_call = nullptr;
+            return std::move(msg); // cppcheck-suppress returnStdMoveLocal
+        }
+    }
+    return snort::ConnectorMsg();
+}
+UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx) : Connector(config) // cppcheck-suppress uninitMemberVar
+{ cfg  = config; } // cppcheck-suppress useInitializationList
+UnixDomainConnector::~UnixDomainConnector()
+{
+    close(0);
+}
+
+UnixDomainConnector* unixdomain_connector_tinit_call(const UnixDomainConnectorConfig& cfg, const char* path, size_t idx, const UnixDomainConnectorUpdateHandler& update_handler)
+{
+    if(cfg.setup == UnixDomainConnectorConfig::Setup::CALL)
+    {
+        test_call_sock_created++;
+        auto new_conn = new UnixDomainConnector(cfg, 0, idx);
+        call_connector = new_conn;
+        return new_conn;
+    }
+    assert(false);
+    return nullptr;
+}
+
+void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
+void show_stats(PegCount*, const PegInfo*, const std::vector<unsigned>&, const char*, FILE*) { }
+
+using namespace snort;
+
+
+static int s_socket_return = 1;
+static int s_bind_return = 0;
+static int s_listen_return = 0;
+static int s_connect_return = 1;
+
+#ifdef __GLIBC__
+int socket (int, int, int) __THROW { test_socket_calls++; return s_socket_return; }
+int bind (int, const struct sockaddr*, socklen_t) __THROW { test_bind_calls++; return s_bind_return; }
+int listen (int, int) __THROW { test_listen_calls++; return s_listen_return; }
+int connect (int, const struct sockaddr*, socklen_t) __THROW { test_connect_calls++; return s_connect_return; }
+int unlink (const char *__name) __THROW { return 0;};
+#else
+int socket (int, int, int) { test_socket_calls++; return s_socket_return; }
+int bind (int, const struct sockaddr*, socklen_t) { test_bind_calls++; return s_bind_return; }
+int listen (int, int) { test_listen_calls++; return s_listen_return; }
+int connect (int, const struct sockaddr*, socklen_t) { test_connect_calls++; return s_connect_return; }
+int unlink (const char *__name) { return 0;};
+#endif
+
+int fcntl (int __fd, int __cmd, ...) { return 0;}
+ssize_t send (int, const void*, size_t n, int) { return n; }
+
+
+std::mutex accept_mutex;
+std::condition_variable accept_cond;
+
+
+class TestDataEvent : public DataEvent
+{
+public:
+    TestDataEvent() {}
+    ~TestDataEvent() override {}
+};
+
+bool serialize_mock(DataEvent* event, char*& buffer, uint16_t* length)
+{
+    test_serialize_calls++;
+    buffer = new char[9];
+    *length = 9;
+    memcpy(buffer, "test_data", 9);
+    return true;
+}
+
+bool deserialize_mock(const char* buffer, uint16_t length, DataEvent*& event)
+{
+    test_deserialize_calls++;
+    event = new TestDataEvent();
+    return true;
+}
+
+MPHelperFunctions mp_helper_functions_mock(serialize_mock, deserialize_mock);
+
+static MPUnixDomainTransportConfig test_config;
+static MPUnixDomainTransport* test_transport = nullptr;
+
+static SnortConfig test_snort_config(nullptr, nullptr);
+
+TEST_GROUP(unix_transport_test_connectivity_group)
+{
+    void setup() override
+    {
+        test_snort_config.max_procs = 2;
+        test_transport = new MPUnixDomainTransport(&test_config);
+        test_transport->configure(&test_snort_config);
+    }
+
+    void teardown() override
+    {
+        delete test_transport;
+        test_transport = nullptr;
+    }
+};
+
+static MPUnixDomainTransportConfig test_config_message;
+
+static MPTransport* test_transport_message_1 = nullptr;
+static MPTransport* test_transport_message_2 = nullptr;
+
+static int reciveved_1_msg_cnt = 0;
+static int reciveved_2_msg_cnt = 0;
+
+TEST_GROUP(unix_transport_test_messaging)
+{
+    void setup() override
+    {
+        test_snort_config.max_procs = 2;
+
+        accept_cnt = 1;
+
+        test_config_message.unix_domain_socket_path = "/tmp";
+        test_config_message.max_processes = 2;
+        test_config_message.conn_retries = false;
+        test_config_message.retry_interval_seconds = 0;
+        test_config_message.max_retries = 0;
+        test_config_message.connect_timeout_seconds = 30;
+
+        test_transport_message_1 = new MPUnixDomainTransport(&test_config_message);
+        snort_instance_id = 1;
+        test_transport_message_1->configure(&test_snort_config);
+        test_transport_message_1->init_connection();
+        test_transport_message_1->register_receive_handler([reciveved_1_msg_cnt](const snort::MPEventInfo& e)
+        {
+            reciveved_1_msg_cnt++;
+        });
+        
+        std::this_thread::sleep_for(std::chrono::milliseconds(100));
+
+        test_transport_message_2 = new MPUnixDomainTransport(&test_config_message);
+        snort_instance_id = 2;
+        test_transport_message_2->configure(&test_snort_config);
+        test_transport_message_2->init_connection();
+        test_transport_message_2->register_receive_handler([reciveved_2_msg_cnt](const snort::MPEventInfo& e)
+        {
+            reciveved_2_msg_cnt++;
+        });
+    }
+
+    void teardown() override
+    {
+        delete test_transport_message_1;
+        test_transport_message_1 = nullptr;
+        delete test_transport_message_2;
+        test_transport_message_2 = nullptr;
+        delete[] test_msg_call_data;
+        test_msg_call_data = nullptr;
+        delete[] test_msg_answer_data;
+        test_msg_answer_data = nullptr;
+    }
+};
+
+TEST(unix_transport_test_connectivity_group, get_config)
+{
+    auto unix_transport = (MPUnixDomainTransport*)test_transport;
+    CHECK(unix_transport->get_config() == &test_config);
+};
+
+TEST(unix_transport_test_connectivity_group, set_logging_enabled_disabled)
+{
+    auto logging_status = test_transport->is_logging_enabled();
+    CHECK(logging_status == false);
+
+    test_transport->enable_logging();
+    logging_status = test_transport->is_logging_enabled();
+    CHECK(logging_status == true);
+
+    test_transport->disable_logging();
+    logging_status = test_transport->is_logging_enabled();
+    CHECK(logging_status == false);
+};
+
+TEST(unix_transport_test_connectivity_group, init_connection_single_snort_instance)
+{
+    clear_test_calls();
+    test_config.unix_domain_socket_path = "/tmp";
+    test_config.max_processes = 1;
+
+    test_transport->init_connection();
+    
+    CHECK(test_socket_calls == 0);
+    CHECK(test_bind_calls == 0);
+    CHECK(test_listen_calls == 0);
+    CHECK(test_accept_calls == 0);
+    CHECK(test_close_calls == 0);
+
+    test_transport->cleanup();
+    CHECK(test_close_calls == 0);
+};
+
+TEST(unix_transport_test_connectivity_group, init_connection_first_snort_instance)
+{
+    clear_test_calls();
+    snort_instance_id = 1;
+
+    test_config.unix_domain_socket_path = "/tmp";
+    test_config.max_processes = 2;
+
+    accept_cnt = 1;
+
+    test_transport->init_connection();
+
+    CHECK(test_accept_calls == 1);
+
+    test_transport->cleanup();
+    CHECK(test_close_calls == 2);
+};
+
+TEST(unix_transport_test_connectivity_group, init_connection_second_snort_instance)
+{
+    clear_test_calls();
+    snort_instance_id = 2;
+    test_config.unix_domain_socket_path = "/tmp";
+    test_config.max_processes = 2;
+
+    test_transport->init_connection();
+    
+    CHECK(test_bind_calls == 0);
+    CHECK(test_listen_calls == 0);
+    CHECK(test_accept_calls == 0);
+    CHECK(test_close_calls == 0);
+    CHECK(test_call_sock_created == 1);
+
+    test_transport->cleanup();
+    CHECK(test_close_calls == 1);
+};
+
+TEST(unix_transport_test_connectivity_group, connector_update_handler_call)
+{
+    clear_test_calls();
+    
+    test_config.unix_domain_socket_path = "/tmp";
+    test_config.max_processes = 2;
+
+    accept_cnt = 1;
+    snort_instance_id = 1;
+
+    test_update_handler = nullptr;
+    expect_update_change = true;
+
+    test_transport->init_connection();
+
+    CHECK(test_update_handler != nullptr);
+
+    test_update_handler(nullptr, false);
+
+    CHECK(test_close_calls == 1);
+    expect_update_change = false;
+    test_update_handler = nullptr;
+};
+
+static TestDataEvent test_event;
+
+TEST(unix_transport_test_messaging, send_to_transport_biderectional)
+{
+    clear_test_calls();
+
+    test_transport_message_1->register_event_helpers(0, 0, mp_helper_functions_mock);
+    test_transport_message_2->register_event_helpers(0, 0, mp_helper_functions_mock);
+
+    MPEventInfo event(&test_event, 0, 0);
+
+    auto res = test_transport_message_1->send_to_transport(event);
+    
+
+    CHECK(res == true);
+    CHECK(test_serialize_calls == 1);
+
+    res = test_transport_message_2->send_to_transport(event);
+    
+    CHECK(res == true);
+    CHECK(test_serialize_calls == 2);
+
+    std::this_thread::sleep_for(std::chrono::milliseconds(500));
+
+    CHECK(test_deserialize_calls == 2);
+    CHECK(reciveved_1_msg_cnt == 1);
+    CHECK(reciveved_2_msg_cnt == 1);
+    CHECK(test_send_calls == 2);
+};
+
+TEST(unix_transport_test_messaging, send_to_transport_no_helpers)
+{
+    clear_test_calls();
+
+    MPEventInfo event(&test_event, 0, 0);
+
+    auto res = test_transport_message_1->send_to_transport(event);
+    CHECK(res == false);
+    CHECK(test_serialize_calls == 0);
+    CHECK(test_deserialize_calls == 0);
+    CHECK(reciveved_1_msg_cnt == 0);
+    CHECK(reciveved_2_msg_cnt == 0);
+    CHECK(test_send_calls == 0);
+
+    res = test_transport_message_2->send_to_transport(event);
+    CHECK(res == false);
+    CHECK(test_serialize_calls == 0);
+    CHECK(test_deserialize_calls == 0);
+    CHECK(reciveved_1_msg_cnt == 0);
+    CHECK(reciveved_2_msg_cnt == 0);
+    CHECK(test_send_calls == 0);
+}
+
+int main(int argc, char** argv)
+{
+    int return_value = CommandLineTestRunner::RunAllTests(argc, argv);
+    return return_value;
+}
\ No newline at end of file
index 941c478591a3a80ddaa79c2934cde0b31e936f85..3fe5a0f80f7a822f48bf80b236c0f476efb21ffc 100644 (file)
@@ -172,6 +172,9 @@ SideChannel::SideChannel(ScMsgFormat fmt) : msg_format(fmt)
 // return true iff we received any messages.
 bool SideChannel::process(int max_messages)
 {
+    if(!connector_receive)
+        return false;
+    
     bool received_message = false;
 
     while (true)