]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #5119: mp_dbus: use lockless ring for events
authorOleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) <ostepano@cisco.com>
Wed, 4 Feb 2026 21:12:50 +0000 (21:12 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Wed, 4 Feb 2026 21:12:50 +0000 (21:12 +0000)
Merge in SNORT/snort3 from ~OSTEPANO/snort3:mp_ring_opt to master

Squashed commit of the following:

commit d918d17143484d7d84ed2852dc107290ea8e952a
Author: Oleksandr Stepanov <ostepano@cisco.com>
Date:   Fri Jan 9 06:36:46 2026 -0500

    mp_dbus: lockless event ring

15 files changed:
src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc
src/connectors/unixdomain_connector/unixdomain_connector.cc
src/connectors/unixdomain_connector/unixdomain_connector.h
src/framework/connector.h
src/framework/mp_data_bus.cc
src/framework/mp_data_bus.h
src/framework/test/mp_data_bus_test.cc
src/helpers/CMakeLists.txt
src/helpers/lockless_ring.h [new file with mode: 0644]
src/helpers/test/CMakeLists.txt
src/helpers/test/lockless_ring_test.cc [new file with mode: 0644]
src/main/modules.cc
src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc
src/mp_transport/mp_unix_transport/test/unix_transport_test.cc
src/side_channel/side_channel.cc

index 7063d87986210147e540e7b21cf20599b159ccd2..beb77d7370d929d36f25403f2775a33d20711f46 100644 (file)
@@ -535,24 +535,6 @@ TEST(unixdomain_connector_tinit_tterm_call, alloc_transmit)
     CHECK(unixdomainc->transmit_message(std::move(msg)) == true);
 }
 
-TEST(unixdomain_connector_tinit_tterm_call, alloc_transmit_header_fail)
-{
-    const uint32_t len = 40;
-    const uint8_t* data = new uint8_t[len];
-    UnixDomainConnector* unixdomainc = (UnixDomainConnector*)connector;
-    set_normal_status();
-
-    ConnectorMsg msg(data, len, true);
-
-    CHECK(msg.get_length() == len);
-    CHECK(msg.get_data() == data);
-
-    s_send_ret_header = sizeof(UnixDomainConnectorMsgHdr)-1;
-    s_send_ret_other = len;
-    CHECK(unixdomainc->transmit_message(msg) == false);
-    CHECK(unixdomainc->transmit_message(std::move(msg)) == false);
-}
-
 TEST(unixdomain_connector_tinit_tterm_call, alloc_transmit_body_fail)
 {
     const uint32_t len = 40;
index f183c4363f46dbe1e2686aad7c20f4ce4d0192a4..2e45e42cf69157992249267db7db1266ed9b00ff 100644 (file)
@@ -47,6 +47,38 @@ THREAD_LOCAL ProfileStats unixdomain_connector_perfstats;
 
 /* Module *****************************************************************/
 
+#define SOCKET_SEND_BUFFER_SIZE_MULTIPLIER 4
+
+static void increase_socket_send_buffer_size(int& sock_handle, uint8_t size_mult)
+{
+    char buf[4] = {};
+    socklen_t buf_len = 4;
+    auto get_opts = getsockopt(sock_handle, SOL_SOCKET, SO_SNDBUF, buf, &buf_len);
+    if (get_opts == 0 and buf_len <= 4)
+    {
+        LogMessage("UnixDomainC: Socket default send buffer size: %d ; Attempting to increase\n", *(int*)buf);
+        uint32_t socket_size = *(int*)buf * size_mult;
+
+        auto set_opt_res = setsockopt(sock_handle, SOL_SOCKET, SO_SNDBUF, (void*)&socket_size, sizeof(socket_size));
+
+        if (set_opt_res != 0)
+        {
+            LogMessage("UnixDomainC: Failed to update send buffer size, continuing with default: %s \n", strerror(errno));
+        }
+        else
+        {
+            get_opts = getsockopt(sock_handle, SOL_SOCKET, SO_SNDBUF, buf, &buf_len);
+            if (get_opts == 0)
+                LogMessage("UnixDomainC: Updated Socket send buffer size: %d \n", *(int*)buf);
+        }
+    }
+    else
+    {
+        LogMessage("UnixDomainC: Failed to get socket send buffer size: %s \n", strerror(errno));
+    }
+}
+
+
 static bool attempt_connection(int& sfd, const char* path, unsigned long timeout_sec) {
     sfd = socket(AF_UNIX, SOCK_STREAM, 0);
     if (sfd == -1) {
@@ -54,6 +86,8 @@ static bool attempt_connection(int& sfd, const char* path, unsigned long timeout
         return false;
     }
 
+    increase_socket_send_buffer_size(sfd, SOCKET_SEND_BUFFER_SIZE_MULTIPLIER);
+
     // Set the socket to non-blocking mode
     int flags = fcntl(sfd, F_GETFL, 0);
     if (flags == -1) {
@@ -172,7 +206,7 @@ static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx,
 
 UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& unixdomain_connector_config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper)
     : Connector(unixdomain_connector_config), sock_fd(sfd), run_thread(false), receive_thread(nullptr),
-      receive_ring(new ReceiveRing(50)), instance_id(idx), cfg(unixdomain_connector_config), reconnect_helper(reconnect_helper) {
+      receive_ring(new ReceiveRing(2000)), instance_id(idx), cfg(unixdomain_connector_config), reconnect_helper(reconnect_helper) {
     if (unixdomain_connector_config.async_receive) {
         start_receive_thread();
     }
@@ -376,20 +410,23 @@ bool UnixDomainConnector::internal_transmit_message(const ConnectorMsg& msg) {
         return false;
     }
 
-    UnixDomainConnectorMsgHdr unixdomainc_hdr(msg.get_length());
-
-    if ( send( sock_fd, (const char*)&unixdomainc_hdr, sizeof(unixdomainc_hdr), 0 ) != sizeof(unixdomainc_hdr) )
+    if (send(sock_fd, msg.get_data(), msg.get_length(), 0) != msg.get_length())
     {
-        ErrorMessage("UnixDomainC: failed to transmit header\n");
+        ErrorMessage("UnixDomainC: failed to transmit message, error = %s\n", strerror(errno));
         return false;
     }
 
-    if (send(sock_fd, msg.get_data(), msg.get_length(), 0) != msg.get_length())
-        return false;
-
     return true;
 }
 
+ConnectorMsg UnixDomainConnector::allocate_connector_message(uint32_t length)
+{
+    UnixDomainConnectorMsgHdr unixdomainc_hdr(length);
+    uint8_t* data = new uint8_t[length + sizeof(UnixDomainConnectorMsgHdr)];
+    memcpy(data, &unixdomainc_hdr, sizeof(UnixDomainConnectorMsgHdr));
+    return ConnectorMsg(data, length + sizeof(UnixDomainConnectorMsgHdr), true, sizeof(UnixDomainConnectorMsgHdr));
+}
+
 bool UnixDomainConnector::transmit_message(const ConnectorMsg& msg, const ID&) {
     return internal_transmit_message(msg);
 }
@@ -479,6 +516,7 @@ static UnixDomainConnector* unixdomain_connector_tinit_answer(const UnixDomainCo
         close(sfd);
         return nullptr;
     }
+    increase_socket_send_buffer_size(peer_sfd, SOCKET_SEND_BUFFER_SIZE_MULTIPLIER);
 
     LogMessage("UnixDomainC: Accepted connection from %s \n", path);
     return new UnixDomainConnector(cfg, peer_sfd, idx);
@@ -644,6 +682,7 @@ void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnecto
                 ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno));
                 continue;
             }
+            increase_socket_send_buffer_size(peer_sfd, SOCKET_SEND_BUFFER_SIZE_MULTIPLIER);
             error_count = 0;
             auto config_copy = new UnixDomainConnectorConfig(*config);
             auto unix_conn = new UnixDomainConnector(*config_copy, peer_sfd, 0);
index 5931effcfbac44d5b77c616d6a3212fff11eb5f0..10081c15c49ccc0880d795a072818e8eeffe4207 100644 (file)
@@ -61,6 +61,8 @@ public:
     UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr);
     ~UnixDomainConnector() override;
 
+    snort::ConnectorMsg allocate_connector_message(uint32_t length) override;
+
     bool transmit_message(const snort::ConnectorMsg&, const ID& = null) override;
     bool transmit_message(const snort::ConnectorMsg&&, const ID& = null) override;
 
index 8e9739700b536a1ec64a35170fc714cf3954cb80..e349665f75bd4bf852f6f616bd4f4b276f7c708f 100644 (file)
@@ -37,7 +37,7 @@
 namespace snort
 {
 // this is the current version of the api
-#define CONNECTOR_API_VERSION ((BASE_API_VERSION << 16) | 3)
+#define CONNECTOR_API_VERSION ((BASE_API_VERSION << 16) | 4)
 
 //-------------------------------------------------------------------------
 // api for class
@@ -52,8 +52,8 @@ class ConnectorMsg
 public:
     ConnectorMsg() = default;
 
-    ConnectorMsg(const uint8_t* data, uint32_t length, bool pass_ownership = false) :
-        data(data), length(length), owns(pass_ownership)
+    ConnectorMsg(const uint8_t* data, uint32_t length, bool pass_ownership = false, uint32_t content_offset = 0) :
+        data(data), content(const_cast<uint8_t*>(data) + content_offset), length(length), owns(pass_ownership)
     { }
 
     ~ConnectorMsg()
@@ -63,7 +63,7 @@ public:
     ConnectorMsg& operator=(ConnectorMsg& other) = delete;
 
     ConnectorMsg(ConnectorMsg&& other) :
-        data(other.data), length(other.length), owns(other.owns)
+        data(other.data), content(other.content), length(other.length), owns(other.owns)
     { other.owns = false; }
 
     ConnectorMsg& operator=(ConnectorMsg&& other)
@@ -72,6 +72,7 @@ public:
             delete[] data;
 
         data = other.data;
+        content = other.content;
         length = other.length;
         owns = other.owns;
 
@@ -83,11 +84,18 @@ public:
     const uint8_t* get_data() const
     { return data; }
 
+    uint8_t* get_content() const
+    { return content; }
+
     uint32_t get_length() const
     { return length; }
 
+    uint32_t get_content_length() const
+    { return length - (content - data); }
+
 private:
     const uint8_t* data = nullptr;
+    uint8_t* content = nullptr;
     uint32_t length = 0;
     bool owns = false;
 };
@@ -111,6 +119,12 @@ public:
     virtual const ID get_id(const char*) const
     { return null; }
 
+    virtual ConnectorMsg allocate_connector_message(uint32_t length)
+    {
+        const uint8_t* data = new uint8_t[length];
+        return ConnectorMsg(data, length, true);
+    }
+
     virtual bool transmit_message(const ConnectorMsg&, const ID& = null) = 0;
     virtual bool transmit_message(const ConnectorMsg&&, const ID& = null) = 0;
 
index 0234803dac4add0a6723e5215a42b4b259e5e850..215cb2fa56f9c89274d9696b79724989254ee2ca 100644 (file)
 #include "main/snort_types.h"
 #include "log/messages.h"
 #include "log/log_stats.h"
-#include "helpers/ring.h"
 #include "managers/mp_transport_manager.h"
 #include "managers/module_manager.h"
 #include "main/snort.h"
 #include "framework/module.h"
+#include "utils/util.h"
 
 using namespace snort;
 
@@ -76,10 +76,9 @@ void MPDataBusLog(const char* msg, ...)
 
 MPDataBus::MPDataBus() :
     run_thread(true),
-    worker_thread(nullptr),
-    mp_event_queue(nullptr)
+    worker_thread(nullptr)
 {
-    mp_event_queue = new Ring<std::shared_ptr<MPEventInfo>>(mp_max_eventq_size);
+    mp_event_queue = new MPEventQueue(mp_max_eventq_size);
     start_worker_thread();
 }
 
@@ -290,23 +289,23 @@ void MPDataBus::process_event_queue()
 
     // coverity[wait_not_in_locked_loop:FALSE]
     if( (std::cv_status::timeout == queue_cv.wait_for(u_lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP))) and
-        mp_event_queue->empty() )
+        mp_event_queue->is_empty() )
         return;
 
-    while (!mp_event_queue->empty()) {
-        std::shared_ptr<MPEventInfo> event_info = mp_event_queue->get(nullptr);
-        if (event_info) {
-            MPDataBusLog("Processing event for publisher ID %u \n",
-                        event_info->pub_id);
-
-            if (!transport_layer){
-                run_thread.store(false);
-                ErrorMessage("MPDataBus: Transport layer not initialized\n");
-                return;
-            }
-
+    
+    static std::shared_ptr<MPEventInfo> event_info;
+    
+    while (mp_event_queue->try_pop(event_info))
+    {
+        if (UNLIKELY(!transport_layer))
+        {
+            run_thread.store(false);
+            ErrorMessage("MPDataBus: Transport layer not initialized\n");
+            return;
+        }
+        
+            MPDataBusLog("Processing event for publisher ID %u \n", event_info->pub_id);
             auto send_res = transport_layer->send_to_transport(*event_info);
-
             {
                 std::lock_guard<std::mutex> lock(mp_stats_mutex);
                 mp_pub_stats[event_info->pub_id].total_messages_published++;
@@ -319,7 +318,7 @@ void MPDataBus::process_event_queue()
                     mp_pub_stats[event_info->pub_id].total_messages_sent++;
                 }
             }
-        }
+        
     }
 }
 
@@ -474,8 +473,8 @@ void MPDataBus::dump_stats(ControlConn *ctrlconn, const char *module_name)
 
 void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name)
 {
-    int current_read_idx = 0;
-    uint32_t ring_items = mp_event_queue->count();
+    uint32_t current_read_idx = 0;
+    uint32_t ring_items = mp_event_queue->size();
     if(ring_items == 0)
     {
         if (ctrlconn)
@@ -488,16 +487,7 @@ void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name)
         }
         return;
     }
-    auto event_queue_store = mp_event_queue->grab_store(current_read_idx);
-
-    if (current_read_idx == 0)
-    {
-        current_read_idx = mp_max_eventq_size - 1;
-    }
-    else
-    {
-        current_read_idx--;
-    }
+    auto event_queue_buffer = mp_event_queue->get_buffer(current_read_idx);
 
     for (uint32_t i = current_read_idx; i <= ring_items; i++)
     {
@@ -506,7 +496,7 @@ void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name)
             i = 0;
             ring_items -= mp_max_eventq_size;
         }
-        auto event_info = event_queue_store[i];
+        auto event_info = event_queue_buffer[i].data;
         if (event_info)
         {
             if (module_name)
@@ -625,7 +615,7 @@ bool MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f)
 
 bool snort::MPDataBus::_enqueue_event(std::shared_ptr<MPEventInfo> ev_info)
 {
-    bool res = mp_event_queue != nullptr and !mp_event_queue->full() and mp_event_queue->put(std::move(ev_info));
+    bool res = mp_event_queue != nullptr and mp_event_queue->try_push(std::move(ev_info));
     if(res) queue_cv.notify_one();
     return res;
 }
index ce2752a62772bf164684bfb1ffa87d99aecb4058..47c7ff4475a0d9d2f4c3a9592b7589df90d60ac4 100644 (file)
 #include "control/control.h"
 #include "framework/mp_transport.h"
 #include "framework/counts.h"
+#include "helpers/lockless_ring.h"
 #include "main/snort_types.h"
 #include "data_bus.h"
 
 #define DEFAULT_TRANSPORT "unix_transport"
-#define DEFAULT_MAX_EVENTQ_SIZE 1000
+#define DEFAULT_MAX_EVENTQ_SIZE 4096
 #define WORKER_THREAD_SLEEP 100
 
-template <typename T>
-class Ring;
-
 namespace snort
 {
 class Flow;
@@ -116,6 +114,8 @@ struct MPHelperFunctions {
         : serializer(s), deserializer(d) {}
 };
 
+using MPEventQueue = LocklessRing<std::shared_ptr<MPEventInfo>, false>;
+
 struct pair_hash
 {
     template <class T1, class T2>
@@ -175,9 +175,6 @@ public:
     // API for receiving the DataEvent and Event type from transport layer using EventInfo
     void receive_message(const MPEventInfo& event_info);
 
-    Ring<std::shared_ptr<MPEventInfo>>* get_event_queue()
-    { return mp_event_queue; }
-
     void set_debug_enabled(bool flag);
 
     MPDataBusStats get_stats_copy();
@@ -189,6 +186,10 @@ public:
     void dump_events(ControlConn* ctrlconn, const char* module_name);
     void show_channel_status(ControlConn* ctrlconn);
 
+#ifdef MP_DATA_BUS_UNIT_TEST
+    MPEventQueue* get_event_queue() { return mp_event_queue; }
+#endif
+
 private: 
     void _subscribe(unsigned pid, unsigned eid, DataHandler* h);
     void _subscribe(const PubKey& key, unsigned eid, DataHandler* h);
@@ -206,7 +207,7 @@ private:
     std::atomic<bool> run_thread;
     std::unique_ptr<std::thread> worker_thread;
 
-    Ring<std::shared_ptr<MPEventInfo>>* mp_event_queue;
+    MPEventQueue* mp_event_queue;
 
     std::condition_variable queue_cv;
     std::mutex queue_mutex;
index ee30f3cc78707638afede95f6b414005e88d5c0c..530c77c5c0d274fd2185ab419e70d9e0d01d2505 100644 (file)
@@ -21,6 +21,7 @@
 #include "config.h"
 #endif
 
+#define MP_DATA_BUS_UNIT_TEST
 
 #include "../mp_data_bus.h"
 #include "../main/snort_config.h"
@@ -271,13 +272,14 @@ TEST_GROUP(mp_data_bus_pub)
     }
 
     void teardown() override
-    { }
+    { 
+    }
 };
 
 TEST(mp_data_bus_pub, publish)
 {
-    CHECK_TRUE(mp_dbus->get_event_queue()->empty());
-    CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0);
+    CHECK_TRUE(mp_dbus->get_event_queue()->is_empty());
+    CHECK_TRUE(mp_dbus->get_event_queue()->size() == 0);
 
     std::shared_ptr<UTestEvent> event = std::make_shared<UTestEvent>(100);
 
@@ -304,8 +306,8 @@ TEST(mp_data_bus_pub, publish)
 
 TEST(mp_data_bus_pub, publish_fail_to_send)
 {
-    CHECK_TRUE(mp_dbus->get_event_queue()->empty());
-    CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0);
+    CHECK_TRUE(mp_dbus->get_event_queue()->is_empty());
+    CHECK_TRUE(mp_dbus->get_event_queue()->size() == 0);
 
     test_transport_send_result = false;
 
index 7d1a32d1fa3133747f133f3949dbe63042f2bc8b..2f43222a420b5d68c81f1673fd998d1bb8c5081b 100644 (file)
@@ -22,6 +22,7 @@ set (HELPERS_INCLUDES
     infractions.h
     json_stream.h
     literal_search.h
+    lockless_ring.h
     memcap_allocator.h
     ring2.h
     scratch_allocator.h
diff --git a/src/helpers/lockless_ring.h b/src/helpers/lockless_ring.h
new file mode 100644 (file)
index 0000000..68d8b29
--- /dev/null
@@ -0,0 +1,224 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2026 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// lockless_ring.h author Cisco
+
+#ifndef LOCKLESS_RING_H
+#define LOCKLESS_RING_H
+
+#include <atomic>
+#include <cstdint>
+
+static inline uint32_t round_up_to_power_of_2(uint32_t n)
+{
+    n--;
+    n |= n >> 1;
+    n |= n >> 2;
+    n |= n >> 4;
+    n |= n >> 8;
+    n |= n >> 16;
+    return n + 1;
+}
+
+enum SlotStatus : uint8_t
+{
+    SLOT_EMPTY = 0,
+    SLOT_STORING = 1,
+    SLOT_STORED = 2,
+    SLOT_LOADING = 3
+};
+
+
+template<typename T, bool USE_MOVE_TO_DEQUEUE = true>
+class LocklessRing
+{
+public:
+
+    struct Slot
+    {
+        std::atomic<SlotStatus> status;
+        T data;
+        
+        Slot() : status(SlotStatus::SLOT_EMPTY), data() {}
+    };
+
+    explicit LocklessRing(size_t capacity)
+        : buf_cap(round_up_to_power_of_2(capacity))
+        , buffer_(new Slot[buf_cap])
+        , head_index(0)
+        , tail_index(0)
+    {
+        for (size_t i = 0; i < buf_cap; ++i)
+        {
+            buffer_[i].status.store(SlotStatus::SLOT_EMPTY, std::memory_order_relaxed);
+        }
+    }
+
+    ~LocklessRing()
+    {
+        delete[] buffer_;
+    }
+
+    inline void push(T&& item)
+    {
+        uint32_t current_head = head_index.fetch_add(1, std::memory_order_relaxed);
+        do_push(std::forward<T>(item), &buffer_[current_head & (buf_cap - 1)]);
+    } 
+
+    inline bool try_push(const T& item)
+    {
+        auto current_head_index = head_index.load(std::memory_order_relaxed);
+
+        do
+        {
+            if (static_cast<int32_t>(current_head_index - tail_index.load(std::memory_order_relaxed)) >= static_cast<int32_t>(buf_cap))
+            {
+                return false; // Buffer is full
+            }
+        } while (!head_index.compare_exchange_weak(current_head_index, current_head_index + 1,
+                                               std::memory_order_relaxed,
+                                               std::memory_order_relaxed));
+
+        do_push(T(item), &buffer_[current_head_index & (buf_cap - 1)]);
+        return true;
+    }
+
+    inline bool try_push(T&& item)
+    {
+        auto current_head_index = head_index.load(std::memory_order_relaxed);
+
+        do
+        {
+            if (static_cast<int32_t>(current_head_index - tail_index.load(std::memory_order_relaxed)) >= static_cast<int32_t>(buf_cap))
+            {
+                return false; // Buffer is full
+            }
+        } while (!head_index.compare_exchange_weak(current_head_index, current_head_index + 1,
+                                               std::memory_order_relaxed,
+                                               std::memory_order_relaxed));
+
+        do_push(std::forward<T>(item), &buffer_[current_head_index & (buf_cap - 1)]);
+        return true;
+    }
+
+    inline auto pop()
+    {
+        uint32_t current_tail = tail_index.fetch_add(1, std::memory_order_relaxed);
+        return do_pop(&buffer_[current_tail & (buf_cap - 1)]);
+    }
+
+    inline bool try_pop(T& item)
+    {
+        auto current_tail_index = tail_index.load(std::memory_order_relaxed);
+
+        do
+        {
+            if (static_cast<int32_t>(head_index.load(std::memory_order_relaxed) - current_tail_index) <= 0)
+            {
+                return false; // Buffer is empty
+            }
+        }while(!tail_index.compare_exchange_weak(current_tail_index, current_tail_index + 1,
+                                               std::memory_order_relaxed,
+                                               std::memory_order_relaxed));
+
+        item = do_pop(&buffer_[current_tail_index & (buf_cap - 1)]);
+        return true;
+    }
+
+    // Get approximate size
+    uint32_t size() const
+    {
+        uint32_t head_pos = head_index.load(std::memory_order_relaxed);
+        uint32_t tail_pos = tail_index.load(std::memory_order_relaxed);
+        return head_pos >= tail_pos ? head_pos - tail_pos : head_pos + (UINT32_MAX - tail_pos);
+    }
+
+    // Check if empty
+    bool is_empty() const
+    {
+        return size() == 0;
+    }
+
+    Slot* get_buffer(uint32_t& cur_tail)
+    {
+        cur_tail = tail_index.load(std::memory_order_relaxed) & (buf_cap - 1);
+        return buffer_;
+    }
+
+    void reset()
+    {
+        head_index.store(0, std::memory_order_relaxed);
+        tail_index.store(0, std::memory_order_relaxed);
+        for (size_t i = 0; i < buf_cap; ++i)
+        {
+            buffer_[i].status.store(SlotStatus::SLOT_EMPTY, std::memory_order_relaxed);
+        }
+    }
+
+private:
+
+    inline T do_pop(Slot* slot)
+    {
+        for (;;)
+        {
+            SlotStatus expected_slot_status = SlotStatus::SLOT_STORED;
+            if (slot->status.compare_exchange_weak(expected_slot_status, SlotStatus::SLOT_LOADING,
+                                              std::memory_order_acquire,
+                                              std::memory_order_relaxed))
+            {
+                if (USE_MOVE_TO_DEQUEUE)
+                {
+                    T item = std::move(slot->data);
+                    slot->status.store(SlotStatus::SLOT_EMPTY, std::memory_order_release);
+                    return item;
+                }
+                else
+                {
+                    T item = slot->data;
+                    slot->status.store(SlotStatus::SLOT_EMPTY, std::memory_order_release);
+                    return item;
+                }
+            }
+        }
+    }
+
+    inline void do_push(T&& item, Slot* slot)
+    {
+        for (;;)
+        {
+            SlotStatus expected_slot_status = SlotStatus::SLOT_EMPTY;
+            if (slot->status.compare_exchange_weak(expected_slot_status, SlotStatus::SLOT_STORING,
+                                              std::memory_order_acquire,
+                                              std::memory_order_relaxed))
+            {
+                slot->data = std::move(item);
+                slot->status.store(SlotStatus::SLOT_STORED, std::memory_order_release);
+                return;
+            }
+        }
+    }
+
+    const uint32_t buf_cap;
+    Slot* buffer_;
+    
+    alignas(32) std::atomic<uint32_t> head_index;
+    alignas(32) std::atomic<uint32_t> tail_index;
+
+    static constexpr bool use_move_to_dequeue = USE_MOVE_TO_DEQUEUE;
+};
+
+#endif // LOCKLESS_RING_H
index c8bf69f3dc8845a70ada942b542cd73a33067066..43a4615537cf302de747f727b976dd7f3f634f6c 100644 (file)
@@ -45,6 +45,13 @@ add_catch_test( ring2_test
         ../ring2.h
 )
 
+add_catch_test( lockless_ring_test
+    SOURCES
+        ../lockless_ring.h
+    LIBS
+        ${CMAKE_THREAD_LIBS_INIT}
+)
+
 if (ENABLE_BENCHMARK_TESTS)
 
     add_catch_test( ring2_benchmark
diff --git a/src/helpers/test/lockless_ring_test.cc b/src/helpers/test/lockless_ring_test.cc
new file mode 100644 (file)
index 0000000..3f276ee
--- /dev/null
@@ -0,0 +1,172 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2025-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// ring2_test.cc author Cisco
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <vector>
+#include <thread>
+
+#include "catch/catch.hpp"
+
+#include "../lockless_ring.h"
+
+TEST_CASE("Basic", "LocklessRing")
+{
+    LocklessRing<uint32_t> llr(1024);
+
+    SECTION("Try read empty")
+    {
+        REQUIRE(0 == llr.size());
+
+        uint32_t value = 0;
+        REQUIRE(false == llr.try_pop(value));
+    }
+
+    SECTION("Try push and pop")
+    {
+        REQUIRE(0 == llr.size());
+
+        for (uint32_t i = 0; i < 512; ++i)
+        {
+            REQUIRE(true == llr.try_push(i));
+        }
+
+        REQUIRE(512 == llr.size());
+
+        for (uint32_t i = 0; i < 512; ++i)
+        {
+            uint32_t value = 0;
+            REQUIRE(true == llr.try_pop(value));
+            REQUIRE(i == value);
+        }
+
+        REQUIRE(0 == llr.size());
+    }
+
+    SECTION("Try push over capacity")
+    {
+        REQUIRE(0 == llr.size());
+
+        for (uint32_t i = 0; i < 1024; ++i)
+        {
+            REQUIRE(true == llr.try_push(i));
+        }
+
+        REQUIRE(1024 == llr.size());
+
+        REQUIRE(false == llr.try_push(1024));
+
+        REQUIRE(1024 == llr.size());
+
+        for (uint32_t i = 0; i < 1024; ++i)
+        {
+            uint32_t value = 0;
+            REQUIRE(true == llr.try_pop(value));
+            REQUIRE(i == value);
+        }
+        REQUIRE(0 == llr.size());
+    }
+
+    SECTION("Multiple producers")
+    {
+        const uint32_t num_producers = 4;
+        const uint32_t items_per_producer = 256;
+
+        std::vector<std::thread> producers;
+        for (uint32_t p = 0; p < num_producers; ++p)
+        {
+            producers.emplace_back([&, p]() {
+                for (uint32_t i = 0; i < items_per_producer; ++i)
+                {
+                    while (!llr.try_push(p * items_per_producer + i))
+                    {
+                        // busy wait
+                    }
+                }
+            });
+        }
+
+        for (auto& prod : producers)
+        {
+            prod.join();
+        }
+
+        REQUIRE(num_producers * items_per_producer == llr.size());
+
+        bool seen[num_producers * items_per_producer] = { false };
+        for (uint32_t i = 0; i < num_producers * items_per_producer; ++i)
+        {
+            uint32_t value = 0;
+            REQUIRE(true == llr.try_pop(value));
+            seen[value] = true;
+        }
+
+        for (uint32_t i = 0; i < num_producers * items_per_producer; ++i)
+        {
+            REQUIRE(true == seen[i]);
+        }
+
+        REQUIRE(0 == llr.size());
+    }
+
+    SECTION("Multiple consumers")
+    {
+        const uint32_t total_items = 1024;
+        const uint32_t num_consumers = 4;
+        const uint32_t items_per_consumer = 256;
+
+        for (uint32_t i = 0; i < total_items; ++i)
+        {
+            REQUIRE(true == llr.try_push(i));
+        }
+
+        REQUIRE(total_items == llr.size());
+
+        std::vector<std::thread> consumers;
+        bool seen[total_items] = { false };
+        for (uint32_t c = 0; c < num_consumers; ++c)
+        {
+            consumers.emplace_back([&, c]() {
+                for (uint32_t i = 0; i < items_per_consumer; ++i)
+                {
+                    uint32_t value = 0;
+                    while (!llr.try_pop(value))
+                    {
+                        // busy wait
+                    }
+                    seen[value] = true;
+                }
+            });
+        }
+
+        for (auto& cons : consumers)
+        {
+            cons.join();
+        }
+
+        for (uint32_t i = 0; i < total_items; ++i)
+        {
+            REQUIRE(true == seen[i]);
+        }
+
+        REQUIRE(0 == llr.size());
+    }
+}
\ No newline at end of file
index 7793d945da146c28f949d5d5bb2c32606fabacf3..0a89039d61220168de1e60bd07d2837c2732338e 100644 (file)
@@ -400,7 +400,7 @@ bool ClassificationsModule::set(const char*, Value& v, SnortConfig*)
 
 static const Parameter mp_data_bus_params[] =
 {
-    { "max_eventq_size", Parameter::PT_INT, "100:65535", "1000",
+    { "max_eventq_size", Parameter::PT_INT, "100:65535", "4096",
       "maximum events to queue" },
 
     { "transport", Parameter::PT_STRING, nullptr, nullptr,
index 207c943ebad6c0b56d1ddba51ab6ff0700163124..0003b4814793c8ab610d20c64b96d4f56451b00c 100644 (file)
@@ -49,7 +49,7 @@ static const PegInfo mp_unix_transport_pegs[] =
     { CountType::SUM, "sent_bytes", "mp_transport events bytes sent count" },
     { CountType::SUM, "receive_events", "mp_transport events received count" },
     { CountType::SUM, "receive_bytes", "mp_transport events bytes received count" },
-    { CountType::SUM, "sent_errors", "mp_transport events errors count" },
+    { CountType::SUM, "send_errors", "mp_transport events errors count" },
     { CountType::SUM, "successful_connections", "successful mp_transport connections count" },
     { CountType::SUM, "closed_connections", "closed mp_transport connections count" },
     { CountType::SUM, "connection_retries", "mp_transport connection retries count" },
index 133e9602a1c977ed24038c717818abada105d92a..fa2bbb1504c095c052bc8e7a401bf38da422fd89 100644 (file)
@@ -169,6 +169,12 @@ void UnixDomainConnector::process_receive()
         message_received_handler();
     }
 }
+snort::ConnectorMsg UnixDomainConnector::allocate_connector_message(uint32_t length)
+{
+    uint8_t* data = new uint8_t[length];
+    return snort::ConnectorMsg(data, length, true);
+}
+
 bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg&&, const ID&)
 { return true; }
 snort::ConnectorMsg UnixDomainConnector::receive_message(bool)
index abcc673dd55669ed02486c0d66ef7fa33178e0d5..b4a5c5dd4ef59bef7ba40fdb56162518a2a91b21 100644 (file)
@@ -182,24 +182,24 @@ bool SideChannel::process(int max_messages)
         // get message if one is available.
         ConnectorMsg connector_msg = connector_receive->receive_message(false);
 
-        if ( connector_msg.get_length() > 0 and msg_format == ScMsgFormat::TEXT )
+        if ( connector_msg.get_content_length() > 0 and msg_format == ScMsgFormat::TEXT )
         {
-            connector_msg = from_text((const char*)connector_msg.get_data(), connector_msg.get_length());
+            connector_msg = from_text((const char*)connector_msg.get_content(), connector_msg.get_content_length());
         }
 
         // if none or invalid, we are complete
-        if ( connector_msg.get_length() == 0 )
+        if ( connector_msg.get_content_length() == 0 )
             break;
 
         if ( receive_handler )
         {
             SCMessage* msg = new SCMessage(this, connector_receive, std::move(connector_msg));
 
-            msg->content = const_cast<uint8_t*>(msg->cmsg.get_data());
-            msg->content_length = msg->cmsg.get_length();
+            msg->content = msg->cmsg.get_content();
+            msg->content_length = msg->cmsg.get_content_length();
 
             // if the message is longer than the header, assume we have a header
-            if ( msg->cmsg.get_length() > sizeof(SCMsgHdr) )
+            if ( msg->cmsg.get_content_length() > sizeof(SCMsgHdr) )
             {
                 msg->content += sizeof(SCMsgHdr);
                 msg->content_length -= sizeof( SCMsgHdr );
@@ -241,6 +241,9 @@ SCMsgHdr SideChannel::get_header()
 
 SCMessage* SideChannel::alloc_transmit_message(uint32_t content_length)
 {
+    if (!connector_transmit)
+        return nullptr;
+    
     SCMessage* msg = nullptr;
     const SCMsgHdr sc_hdr = get_header();
 
@@ -248,14 +251,11 @@ SCMessage* SideChannel::alloc_transmit_message(uint32_t content_length)
     {
     case ScMsgFormat::BINARY:
     {
-        uint8_t* msg_data = new uint8_t[sizeof(SCMsgHdr) + content_length];
-
-        memcpy(msg_data, &sc_hdr, sizeof(SCMsgHdr));
+        ConnectorMsg connector_message = connector_transmit->allocate_connector_message(sizeof(SCMsgHdr) + content_length);
+        memcpy(connector_message.get_content(), &sc_hdr, sizeof(SCMsgHdr));
 
-        ConnectorMsg bin_cmsg(msg_data, sizeof(SCMsgHdr) + content_length, true);
-
-        msg = new SCMessage(this, connector_transmit, std::move(bin_cmsg));
-        msg->content = msg_data + sizeof(SCMsgHdr);
+        msg = new SCMessage(this, connector_transmit, std::move(connector_message));
+        msg->content = msg->cmsg.get_content() + sizeof(SCMsgHdr);
         msg->content_length = content_length;
 
         break;
@@ -269,14 +269,12 @@ SCMessage* SideChannel::alloc_transmit_message(uint32_t content_length)
             break;
 
         const uint32_t msg_len = hdr_text.size() + (content_length * TXT_UNIT_LEN);
-        uint8_t* msg_data = new uint8_t[msg_len];
-
-        memcpy(msg_data, hdr_text.c_str(), hdr_text.size());
+        ConnectorMsg connector_message = connector_transmit->allocate_connector_message(msg_len);
 
-        ConnectorMsg text_cmsg(msg_data, msg_len, true);
+        memcpy(connector_message.get_content(), hdr_text.c_str(), hdr_text.size());
 
-        msg = new SCMessage(this, connector_transmit, std::move(text_cmsg));
-        msg->content = msg_data + hdr_text.size();
+        msg = new SCMessage(this, connector_transmit, std::move(connector_message));
+        msg->content = msg->cmsg.get_content() + hdr_text.size();
         msg->content_length = content_length;
 
         break;