From: Oleksandr Stepanov -X (ostepano - SOFTSERVE INC at Cisco) Date: Wed, 4 Feb 2026 21:12:50 +0000 (+0000) Subject: Pull request #5119: mp_dbus: use lockless ring for events X-Git-Tag: 3.11.1.0~25 X-Git-Url: http://git.ipfire.org/?a=commitdiff_plain;h=c8abc96d086efe7d76f00f3ca756098db283cac2;p=thirdparty%2Fsnort3.git Pull request #5119: mp_dbus: use lockless ring for events Merge in SNORT/snort3 from ~OSTEPANO/snort3:mp_ring_opt to master Squashed commit of the following: commit d918d17143484d7d84ed2852dc107290ea8e952a Author: Oleksandr Stepanov Date: Fri Jan 9 06:36:46 2026 -0500 mp_dbus: lockless event ring --- diff --git a/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc b/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc index 7063d8798..beb77d737 100644 --- a/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc +++ b/src/connectors/unixdomain_connector/test/unixdomain_connector_test.cc @@ -535,24 +535,6 @@ TEST(unixdomain_connector_tinit_tterm_call, alloc_transmit) CHECK(unixdomainc->transmit_message(std::move(msg)) == true); } -TEST(unixdomain_connector_tinit_tterm_call, alloc_transmit_header_fail) -{ - const uint32_t len = 40; - const uint8_t* data = new uint8_t[len]; - UnixDomainConnector* unixdomainc = (UnixDomainConnector*)connector; - set_normal_status(); - - ConnectorMsg msg(data, len, true); - - CHECK(msg.get_length() == len); - CHECK(msg.get_data() == data); - - s_send_ret_header = sizeof(UnixDomainConnectorMsgHdr)-1; - s_send_ret_other = len; - CHECK(unixdomainc->transmit_message(msg) == false); - CHECK(unixdomainc->transmit_message(std::move(msg)) == false); -} - TEST(unixdomain_connector_tinit_tterm_call, alloc_transmit_body_fail) { const uint32_t len = 40; diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.cc b/src/connectors/unixdomain_connector/unixdomain_connector.cc index f183c4363..2e45e42cf 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.cc +++ b/src/connectors/unixdomain_connector/unixdomain_connector.cc @@ -47,6 +47,38 @@ THREAD_LOCAL ProfileStats unixdomain_connector_perfstats; /* Module *****************************************************************/ +#define SOCKET_SEND_BUFFER_SIZE_MULTIPLIER 4 + +static void increase_socket_send_buffer_size(int& sock_handle, uint8_t size_mult) +{ + char buf[4] = {}; + socklen_t buf_len = 4; + auto get_opts = getsockopt(sock_handle, SOL_SOCKET, SO_SNDBUF, buf, &buf_len); + if (get_opts == 0 and buf_len <= 4) + { + LogMessage("UnixDomainC: Socket default send buffer size: %d ; Attempting to increase\n", *(int*)buf); + uint32_t socket_size = *(int*)buf * size_mult; + + auto set_opt_res = setsockopt(sock_handle, SOL_SOCKET, SO_SNDBUF, (void*)&socket_size, sizeof(socket_size)); + + if (set_opt_res != 0) + { + LogMessage("UnixDomainC: Failed to update send buffer size, continuing with default: %s \n", strerror(errno)); + } + else + { + get_opts = getsockopt(sock_handle, SOL_SOCKET, SO_SNDBUF, buf, &buf_len); + if (get_opts == 0) + LogMessage("UnixDomainC: Updated Socket send buffer size: %d \n", *(int*)buf); + } + } + else + { + LogMessage("UnixDomainC: Failed to get socket send buffer size: %s \n", strerror(errno)); + } +} + + static bool attempt_connection(int& sfd, const char* path, unsigned long timeout_sec) { sfd = socket(AF_UNIX, SOCK_STREAM, 0); if (sfd == -1) { @@ -54,6 +86,8 @@ static bool attempt_connection(int& sfd, const char* path, unsigned long timeout return false; } + increase_socket_send_buffer_size(sfd, SOCKET_SEND_BUFFER_SIZE_MULTIPLIER); + // Set the socket to non-blocking mode int flags = fcntl(sfd, F_GETFL, 0); if (flags == -1) { @@ -172,7 +206,7 @@ static void start_retry_thread(const UnixDomainConnectorConfig& cfg, size_t idx, UnixDomainConnector::UnixDomainConnector(const UnixDomainConnectorConfig& unixdomain_connector_config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper) : Connector(unixdomain_connector_config), sock_fd(sfd), run_thread(false), receive_thread(nullptr), - receive_ring(new ReceiveRing(50)), instance_id(idx), cfg(unixdomain_connector_config), reconnect_helper(reconnect_helper) { + receive_ring(new ReceiveRing(2000)), instance_id(idx), cfg(unixdomain_connector_config), reconnect_helper(reconnect_helper) { if (unixdomain_connector_config.async_receive) { start_receive_thread(); } @@ -376,20 +410,23 @@ bool UnixDomainConnector::internal_transmit_message(const ConnectorMsg& msg) { return false; } - UnixDomainConnectorMsgHdr unixdomainc_hdr(msg.get_length()); - - if ( send( sock_fd, (const char*)&unixdomainc_hdr, sizeof(unixdomainc_hdr), 0 ) != sizeof(unixdomainc_hdr) ) + if (send(sock_fd, msg.get_data(), msg.get_length(), 0) != msg.get_length()) { - ErrorMessage("UnixDomainC: failed to transmit header\n"); + ErrorMessage("UnixDomainC: failed to transmit message, error = %s\n", strerror(errno)); return false; } - if (send(sock_fd, msg.get_data(), msg.get_length(), 0) != msg.get_length()) - return false; - return true; } +ConnectorMsg UnixDomainConnector::allocate_connector_message(uint32_t length) +{ + UnixDomainConnectorMsgHdr unixdomainc_hdr(length); + uint8_t* data = new uint8_t[length + sizeof(UnixDomainConnectorMsgHdr)]; + memcpy(data, &unixdomainc_hdr, sizeof(UnixDomainConnectorMsgHdr)); + return ConnectorMsg(data, length + sizeof(UnixDomainConnectorMsgHdr), true, sizeof(UnixDomainConnectorMsgHdr)); +} + bool UnixDomainConnector::transmit_message(const ConnectorMsg& msg, const ID&) { return internal_transmit_message(msg); } @@ -479,6 +516,7 @@ static UnixDomainConnector* unixdomain_connector_tinit_answer(const UnixDomainCo close(sfd); return nullptr; } + increase_socket_send_buffer_size(peer_sfd, SOCKET_SEND_BUFFER_SIZE_MULTIPLIER); LogMessage("UnixDomainC: Accepted connection from %s \n", path); return new UnixDomainConnector(cfg, peer_sfd, idx); @@ -644,6 +682,7 @@ void UnixDomainConnectorListener::start_accepting_connections(UnixDomainConnecto ErrorMessage("UnixDomainC: accept error: %s \n", strerror(errno)); continue; } + increase_socket_send_buffer_size(peer_sfd, SOCKET_SEND_BUFFER_SIZE_MULTIPLIER); error_count = 0; auto config_copy = new UnixDomainConnectorConfig(*config); auto unix_conn = new UnixDomainConnector(*config_copy, peer_sfd, 0); diff --git a/src/connectors/unixdomain_connector/unixdomain_connector.h b/src/connectors/unixdomain_connector/unixdomain_connector.h index 5931effcf..10081c15c 100644 --- a/src/connectors/unixdomain_connector/unixdomain_connector.h +++ b/src/connectors/unixdomain_connector/unixdomain_connector.h @@ -61,6 +61,8 @@ public: UnixDomainConnector(const UnixDomainConnectorConfig& config, int sfd, size_t idx, UnixDomainConnectorReconnectHelper* reconnect_helper = nullptr); ~UnixDomainConnector() override; + snort::ConnectorMsg allocate_connector_message(uint32_t length) override; + bool transmit_message(const snort::ConnectorMsg&, const ID& = null) override; bool transmit_message(const snort::ConnectorMsg&&, const ID& = null) override; diff --git a/src/framework/connector.h b/src/framework/connector.h index 8e9739700..e349665f7 100644 --- a/src/framework/connector.h +++ b/src/framework/connector.h @@ -37,7 +37,7 @@ namespace snort { // this is the current version of the api -#define CONNECTOR_API_VERSION ((BASE_API_VERSION << 16) | 3) +#define CONNECTOR_API_VERSION ((BASE_API_VERSION << 16) | 4) //------------------------------------------------------------------------- // api for class @@ -52,8 +52,8 @@ class ConnectorMsg public: ConnectorMsg() = default; - ConnectorMsg(const uint8_t* data, uint32_t length, bool pass_ownership = false) : - data(data), length(length), owns(pass_ownership) + ConnectorMsg(const uint8_t* data, uint32_t length, bool pass_ownership = false, uint32_t content_offset = 0) : + data(data), content(const_cast(data) + content_offset), length(length), owns(pass_ownership) { } ~ConnectorMsg() @@ -63,7 +63,7 @@ public: ConnectorMsg& operator=(ConnectorMsg& other) = delete; ConnectorMsg(ConnectorMsg&& other) : - data(other.data), length(other.length), owns(other.owns) + data(other.data), content(other.content), length(other.length), owns(other.owns) { other.owns = false; } ConnectorMsg& operator=(ConnectorMsg&& other) @@ -72,6 +72,7 @@ public: delete[] data; data = other.data; + content = other.content; length = other.length; owns = other.owns; @@ -83,11 +84,18 @@ public: const uint8_t* get_data() const { return data; } + uint8_t* get_content() const + { return content; } + uint32_t get_length() const { return length; } + uint32_t get_content_length() const + { return length - (content - data); } + private: const uint8_t* data = nullptr; + uint8_t* content = nullptr; uint32_t length = 0; bool owns = false; }; @@ -111,6 +119,12 @@ public: virtual const ID get_id(const char*) const { return null; } + virtual ConnectorMsg allocate_connector_message(uint32_t length) + { + const uint8_t* data = new uint8_t[length]; + return ConnectorMsg(data, length, true); + } + virtual bool transmit_message(const ConnectorMsg&, const ID& = null) = 0; virtual bool transmit_message(const ConnectorMsg&&, const ID& = null) = 0; diff --git a/src/framework/mp_data_bus.cc b/src/framework/mp_data_bus.cc index 0234803da..215cb2fa5 100644 --- a/src/framework/mp_data_bus.cc +++ b/src/framework/mp_data_bus.cc @@ -33,11 +33,11 @@ #include "main/snort_types.h" #include "log/messages.h" #include "log/log_stats.h" -#include "helpers/ring.h" #include "managers/mp_transport_manager.h" #include "managers/module_manager.h" #include "main/snort.h" #include "framework/module.h" +#include "utils/util.h" using namespace snort; @@ -76,10 +76,9 @@ void MPDataBusLog(const char* msg, ...) MPDataBus::MPDataBus() : run_thread(true), - worker_thread(nullptr), - mp_event_queue(nullptr) + worker_thread(nullptr) { - mp_event_queue = new Ring>(mp_max_eventq_size); + mp_event_queue = new MPEventQueue(mp_max_eventq_size); start_worker_thread(); } @@ -290,23 +289,23 @@ void MPDataBus::process_event_queue() // coverity[wait_not_in_locked_loop:FALSE] if( (std::cv_status::timeout == queue_cv.wait_for(u_lock, std::chrono::milliseconds(WORKER_THREAD_SLEEP))) and - mp_event_queue->empty() ) + mp_event_queue->is_empty() ) return; - while (!mp_event_queue->empty()) { - std::shared_ptr event_info = mp_event_queue->get(nullptr); - if (event_info) { - MPDataBusLog("Processing event for publisher ID %u \n", - event_info->pub_id); - - if (!transport_layer){ - run_thread.store(false); - ErrorMessage("MPDataBus: Transport layer not initialized\n"); - return; - } - + + static std::shared_ptr event_info; + + while (mp_event_queue->try_pop(event_info)) + { + if (UNLIKELY(!transport_layer)) + { + run_thread.store(false); + ErrorMessage("MPDataBus: Transport layer not initialized\n"); + return; + } + + MPDataBusLog("Processing event for publisher ID %u \n", event_info->pub_id); auto send_res = transport_layer->send_to_transport(*event_info); - { std::lock_guard lock(mp_stats_mutex); mp_pub_stats[event_info->pub_id].total_messages_published++; @@ -319,7 +318,7 @@ void MPDataBus::process_event_queue() mp_pub_stats[event_info->pub_id].total_messages_sent++; } } - } + } } @@ -474,8 +473,8 @@ void MPDataBus::dump_stats(ControlConn *ctrlconn, const char *module_name) void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name) { - int current_read_idx = 0; - uint32_t ring_items = mp_event_queue->count(); + uint32_t current_read_idx = 0; + uint32_t ring_items = mp_event_queue->size(); if(ring_items == 0) { if (ctrlconn) @@ -488,16 +487,7 @@ void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name) } return; } - auto event_queue_store = mp_event_queue->grab_store(current_read_idx); - - if (current_read_idx == 0) - { - current_read_idx = mp_max_eventq_size - 1; - } - else - { - current_read_idx--; - } + auto event_queue_buffer = mp_event_queue->get_buffer(current_read_idx); for (uint32_t i = current_read_idx; i <= ring_items; i++) { @@ -506,7 +496,7 @@ void MPDataBus::dump_events(ControlConn *ctrlconn, const char *module_name) i = 0; ring_items -= mp_max_eventq_size; } - auto event_info = event_queue_store[i]; + auto event_info = event_queue_buffer[i].data; if (event_info) { if (module_name) @@ -625,7 +615,7 @@ bool MPDataBus::_publish(unsigned pid, unsigned eid, DataEvent& e, Flow* f) bool snort::MPDataBus::_enqueue_event(std::shared_ptr ev_info) { - bool res = mp_event_queue != nullptr and !mp_event_queue->full() and mp_event_queue->put(std::move(ev_info)); + bool res = mp_event_queue != nullptr and mp_event_queue->try_push(std::move(ev_info)); if(res) queue_cv.notify_one(); return res; } diff --git a/src/framework/mp_data_bus.h b/src/framework/mp_data_bus.h index ce2752a62..47c7ff447 100644 --- a/src/framework/mp_data_bus.h +++ b/src/framework/mp_data_bus.h @@ -45,16 +45,14 @@ #include "control/control.h" #include "framework/mp_transport.h" #include "framework/counts.h" +#include "helpers/lockless_ring.h" #include "main/snort_types.h" #include "data_bus.h" #define DEFAULT_TRANSPORT "unix_transport" -#define DEFAULT_MAX_EVENTQ_SIZE 1000 +#define DEFAULT_MAX_EVENTQ_SIZE 4096 #define WORKER_THREAD_SLEEP 100 -template -class Ring; - namespace snort { class Flow; @@ -116,6 +114,8 @@ struct MPHelperFunctions { : serializer(s), deserializer(d) {} }; +using MPEventQueue = LocklessRing, false>; + struct pair_hash { template @@ -175,9 +175,6 @@ public: // API for receiving the DataEvent and Event type from transport layer using EventInfo void receive_message(const MPEventInfo& event_info); - Ring>* get_event_queue() - { return mp_event_queue; } - void set_debug_enabled(bool flag); MPDataBusStats get_stats_copy(); @@ -189,6 +186,10 @@ public: void dump_events(ControlConn* ctrlconn, const char* module_name); void show_channel_status(ControlConn* ctrlconn); +#ifdef MP_DATA_BUS_UNIT_TEST + MPEventQueue* get_event_queue() { return mp_event_queue; } +#endif + private: void _subscribe(unsigned pid, unsigned eid, DataHandler* h); void _subscribe(const PubKey& key, unsigned eid, DataHandler* h); @@ -206,7 +207,7 @@ private: std::atomic run_thread; std::unique_ptr worker_thread; - Ring>* mp_event_queue; + MPEventQueue* mp_event_queue; std::condition_variable queue_cv; std::mutex queue_mutex; diff --git a/src/framework/test/mp_data_bus_test.cc b/src/framework/test/mp_data_bus_test.cc index ee30f3cc7..530c77c5c 100644 --- a/src/framework/test/mp_data_bus_test.cc +++ b/src/framework/test/mp_data_bus_test.cc @@ -21,6 +21,7 @@ #include "config.h" #endif +#define MP_DATA_BUS_UNIT_TEST #include "../mp_data_bus.h" #include "../main/snort_config.h" @@ -271,13 +272,14 @@ TEST_GROUP(mp_data_bus_pub) } void teardown() override - { } + { + } }; TEST(mp_data_bus_pub, publish) { - CHECK_TRUE(mp_dbus->get_event_queue()->empty()); - CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0); + CHECK_TRUE(mp_dbus->get_event_queue()->is_empty()); + CHECK_TRUE(mp_dbus->get_event_queue()->size() == 0); std::shared_ptr event = std::make_shared(100); @@ -304,8 +306,8 @@ TEST(mp_data_bus_pub, publish) TEST(mp_data_bus_pub, publish_fail_to_send) { - CHECK_TRUE(mp_dbus->get_event_queue()->empty()); - CHECK_TRUE(mp_dbus->get_event_queue()->count() == 0); + CHECK_TRUE(mp_dbus->get_event_queue()->is_empty()); + CHECK_TRUE(mp_dbus->get_event_queue()->size() == 0); test_transport_send_result = false; diff --git a/src/helpers/CMakeLists.txt b/src/helpers/CMakeLists.txt index 7d1a32d1f..2f43222a4 100644 --- a/src/helpers/CMakeLists.txt +++ b/src/helpers/CMakeLists.txt @@ -22,6 +22,7 @@ set (HELPERS_INCLUDES infractions.h json_stream.h literal_search.h + lockless_ring.h memcap_allocator.h ring2.h scratch_allocator.h diff --git a/src/helpers/lockless_ring.h b/src/helpers/lockless_ring.h new file mode 100644 index 000000000..68d8b2967 --- /dev/null +++ b/src/helpers/lockless_ring.h @@ -0,0 +1,224 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2014-2026 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// lockless_ring.h author Cisco + +#ifndef LOCKLESS_RING_H +#define LOCKLESS_RING_H + +#include +#include + +static inline uint32_t round_up_to_power_of_2(uint32_t n) +{ + n--; + n |= n >> 1; + n |= n >> 2; + n |= n >> 4; + n |= n >> 8; + n |= n >> 16; + return n + 1; +} + +enum SlotStatus : uint8_t +{ + SLOT_EMPTY = 0, + SLOT_STORING = 1, + SLOT_STORED = 2, + SLOT_LOADING = 3 +}; + + +template +class LocklessRing +{ +public: + + struct Slot + { + std::atomic status; + T data; + + Slot() : status(SlotStatus::SLOT_EMPTY), data() {} + }; + + explicit LocklessRing(size_t capacity) + : buf_cap(round_up_to_power_of_2(capacity)) + , buffer_(new Slot[buf_cap]) + , head_index(0) + , tail_index(0) + { + for (size_t i = 0; i < buf_cap; ++i) + { + buffer_[i].status.store(SlotStatus::SLOT_EMPTY, std::memory_order_relaxed); + } + } + + ~LocklessRing() + { + delete[] buffer_; + } + + inline void push(T&& item) + { + uint32_t current_head = head_index.fetch_add(1, std::memory_order_relaxed); + do_push(std::forward(item), &buffer_[current_head & (buf_cap - 1)]); + } + + inline bool try_push(const T& item) + { + auto current_head_index = head_index.load(std::memory_order_relaxed); + + do + { + if (static_cast(current_head_index - tail_index.load(std::memory_order_relaxed)) >= static_cast(buf_cap)) + { + return false; // Buffer is full + } + } while (!head_index.compare_exchange_weak(current_head_index, current_head_index + 1, + std::memory_order_relaxed, + std::memory_order_relaxed)); + + do_push(T(item), &buffer_[current_head_index & (buf_cap - 1)]); + return true; + } + + inline bool try_push(T&& item) + { + auto current_head_index = head_index.load(std::memory_order_relaxed); + + do + { + if (static_cast(current_head_index - tail_index.load(std::memory_order_relaxed)) >= static_cast(buf_cap)) + { + return false; // Buffer is full + } + } while (!head_index.compare_exchange_weak(current_head_index, current_head_index + 1, + std::memory_order_relaxed, + std::memory_order_relaxed)); + + do_push(std::forward(item), &buffer_[current_head_index & (buf_cap - 1)]); + return true; + } + + inline auto pop() + { + uint32_t current_tail = tail_index.fetch_add(1, std::memory_order_relaxed); + return do_pop(&buffer_[current_tail & (buf_cap - 1)]); + } + + inline bool try_pop(T& item) + { + auto current_tail_index = tail_index.load(std::memory_order_relaxed); + + do + { + if (static_cast(head_index.load(std::memory_order_relaxed) - current_tail_index) <= 0) + { + return false; // Buffer is empty + } + }while(!tail_index.compare_exchange_weak(current_tail_index, current_tail_index + 1, + std::memory_order_relaxed, + std::memory_order_relaxed)); + + item = do_pop(&buffer_[current_tail_index & (buf_cap - 1)]); + return true; + } + + // Get approximate size + uint32_t size() const + { + uint32_t head_pos = head_index.load(std::memory_order_relaxed); + uint32_t tail_pos = tail_index.load(std::memory_order_relaxed); + return head_pos >= tail_pos ? head_pos - tail_pos : head_pos + (UINT32_MAX - tail_pos); + } + + // Check if empty + bool is_empty() const + { + return size() == 0; + } + + Slot* get_buffer(uint32_t& cur_tail) + { + cur_tail = tail_index.load(std::memory_order_relaxed) & (buf_cap - 1); + return buffer_; + } + + void reset() + { + head_index.store(0, std::memory_order_relaxed); + tail_index.store(0, std::memory_order_relaxed); + for (size_t i = 0; i < buf_cap; ++i) + { + buffer_[i].status.store(SlotStatus::SLOT_EMPTY, std::memory_order_relaxed); + } + } + +private: + + inline T do_pop(Slot* slot) + { + for (;;) + { + SlotStatus expected_slot_status = SlotStatus::SLOT_STORED; + if (slot->status.compare_exchange_weak(expected_slot_status, SlotStatus::SLOT_LOADING, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + if (USE_MOVE_TO_DEQUEUE) + { + T item = std::move(slot->data); + slot->status.store(SlotStatus::SLOT_EMPTY, std::memory_order_release); + return item; + } + else + { + T item = slot->data; + slot->status.store(SlotStatus::SLOT_EMPTY, std::memory_order_release); + return item; + } + } + } + } + + inline void do_push(T&& item, Slot* slot) + { + for (;;) + { + SlotStatus expected_slot_status = SlotStatus::SLOT_EMPTY; + if (slot->status.compare_exchange_weak(expected_slot_status, SlotStatus::SLOT_STORING, + std::memory_order_acquire, + std::memory_order_relaxed)) + { + slot->data = std::move(item); + slot->status.store(SlotStatus::SLOT_STORED, std::memory_order_release); + return; + } + } + } + + const uint32_t buf_cap; + Slot* buffer_; + + alignas(32) std::atomic head_index; + alignas(32) std::atomic tail_index; + + static constexpr bool use_move_to_dequeue = USE_MOVE_TO_DEQUEUE; +}; + +#endif // LOCKLESS_RING_H diff --git a/src/helpers/test/CMakeLists.txt b/src/helpers/test/CMakeLists.txt index c8bf69f3d..43a461553 100644 --- a/src/helpers/test/CMakeLists.txt +++ b/src/helpers/test/CMakeLists.txt @@ -45,6 +45,13 @@ add_catch_test( ring2_test ../ring2.h ) +add_catch_test( lockless_ring_test + SOURCES + ../lockless_ring.h + LIBS + ${CMAKE_THREAD_LIBS_INIT} +) + if (ENABLE_BENCHMARK_TESTS) add_catch_test( ring2_benchmark diff --git a/src/helpers/test/lockless_ring_test.cc b/src/helpers/test/lockless_ring_test.cc new file mode 100644 index 000000000..3f276eecf --- /dev/null +++ b/src/helpers/test/lockless_ring_test.cc @@ -0,0 +1,172 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2025-2025 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// ring2_test.cc author Cisco + +#ifdef HAVE_CONFIG_H +#include "config.h" +#endif + +#include +#include + +#include "catch/catch.hpp" + +#include "../lockless_ring.h" + +TEST_CASE("Basic", "LocklessRing") +{ + LocklessRing llr(1024); + + SECTION("Try read empty") + { + REQUIRE(0 == llr.size()); + + uint32_t value = 0; + REQUIRE(false == llr.try_pop(value)); + } + + SECTION("Try push and pop") + { + REQUIRE(0 == llr.size()); + + for (uint32_t i = 0; i < 512; ++i) + { + REQUIRE(true == llr.try_push(i)); + } + + REQUIRE(512 == llr.size()); + + for (uint32_t i = 0; i < 512; ++i) + { + uint32_t value = 0; + REQUIRE(true == llr.try_pop(value)); + REQUIRE(i == value); + } + + REQUIRE(0 == llr.size()); + } + + SECTION("Try push over capacity") + { + REQUIRE(0 == llr.size()); + + for (uint32_t i = 0; i < 1024; ++i) + { + REQUIRE(true == llr.try_push(i)); + } + + REQUIRE(1024 == llr.size()); + + REQUIRE(false == llr.try_push(1024)); + + REQUIRE(1024 == llr.size()); + + for (uint32_t i = 0; i < 1024; ++i) + { + uint32_t value = 0; + REQUIRE(true == llr.try_pop(value)); + REQUIRE(i == value); + } + REQUIRE(0 == llr.size()); + } + + SECTION("Multiple producers") + { + const uint32_t num_producers = 4; + const uint32_t items_per_producer = 256; + + std::vector producers; + for (uint32_t p = 0; p < num_producers; ++p) + { + producers.emplace_back([&, p]() { + for (uint32_t i = 0; i < items_per_producer; ++i) + { + while (!llr.try_push(p * items_per_producer + i)) + { + // busy wait + } + } + }); + } + + for (auto& prod : producers) + { + prod.join(); + } + + REQUIRE(num_producers * items_per_producer == llr.size()); + + bool seen[num_producers * items_per_producer] = { false }; + for (uint32_t i = 0; i < num_producers * items_per_producer; ++i) + { + uint32_t value = 0; + REQUIRE(true == llr.try_pop(value)); + seen[value] = true; + } + + for (uint32_t i = 0; i < num_producers * items_per_producer; ++i) + { + REQUIRE(true == seen[i]); + } + + REQUIRE(0 == llr.size()); + } + + SECTION("Multiple consumers") + { + const uint32_t total_items = 1024; + const uint32_t num_consumers = 4; + const uint32_t items_per_consumer = 256; + + for (uint32_t i = 0; i < total_items; ++i) + { + REQUIRE(true == llr.try_push(i)); + } + + REQUIRE(total_items == llr.size()); + + std::vector consumers; + bool seen[total_items] = { false }; + for (uint32_t c = 0; c < num_consumers; ++c) + { + consumers.emplace_back([&, c]() { + for (uint32_t i = 0; i < items_per_consumer; ++i) + { + uint32_t value = 0; + while (!llr.try_pop(value)) + { + // busy wait + } + seen[value] = true; + } + }); + } + + for (auto& cons : consumers) + { + cons.join(); + } + + for (uint32_t i = 0; i < total_items; ++i) + { + REQUIRE(true == seen[i]); + } + + REQUIRE(0 == llr.size()); + } +} \ No newline at end of file diff --git a/src/main/modules.cc b/src/main/modules.cc index 7793d945d..0a89039d6 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -400,7 +400,7 @@ bool ClassificationsModule::set(const char*, Value& v, SnortConfig*) static const Parameter mp_data_bus_params[] = { - { "max_eventq_size", Parameter::PT_INT, "100:65535", "1000", + { "max_eventq_size", Parameter::PT_INT, "100:65535", "4096", "maximum events to queue" }, { "transport", Parameter::PT_STRING, nullptr, nullptr, diff --git a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc index 207c943eb..0003b4814 100644 --- a/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc +++ b/src/mp_transport/mp_unix_transport/mp_unix_transport_module.cc @@ -49,7 +49,7 @@ static const PegInfo mp_unix_transport_pegs[] = { CountType::SUM, "sent_bytes", "mp_transport events bytes sent count" }, { CountType::SUM, "receive_events", "mp_transport events received count" }, { CountType::SUM, "receive_bytes", "mp_transport events bytes received count" }, - { CountType::SUM, "sent_errors", "mp_transport events errors count" }, + { CountType::SUM, "send_errors", "mp_transport events errors count" }, { CountType::SUM, "successful_connections", "successful mp_transport connections count" }, { CountType::SUM, "closed_connections", "closed mp_transport connections count" }, { CountType::SUM, "connection_retries", "mp_transport connection retries count" }, diff --git a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc index 133e9602a..fa2bbb150 100644 --- a/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc +++ b/src/mp_transport/mp_unix_transport/test/unix_transport_test.cc @@ -169,6 +169,12 @@ void UnixDomainConnector::process_receive() message_received_handler(); } } +snort::ConnectorMsg UnixDomainConnector::allocate_connector_message(uint32_t length) +{ + uint8_t* data = new uint8_t[length]; + return snort::ConnectorMsg(data, length, true); +} + bool UnixDomainConnector::transmit_message(const snort::ConnectorMsg&&, const ID&) { return true; } snort::ConnectorMsg UnixDomainConnector::receive_message(bool) diff --git a/src/side_channel/side_channel.cc b/src/side_channel/side_channel.cc index abcc673dd..b4a5c5dd4 100644 --- a/src/side_channel/side_channel.cc +++ b/src/side_channel/side_channel.cc @@ -182,24 +182,24 @@ bool SideChannel::process(int max_messages) // get message if one is available. ConnectorMsg connector_msg = connector_receive->receive_message(false); - if ( connector_msg.get_length() > 0 and msg_format == ScMsgFormat::TEXT ) + if ( connector_msg.get_content_length() > 0 and msg_format == ScMsgFormat::TEXT ) { - connector_msg = from_text((const char*)connector_msg.get_data(), connector_msg.get_length()); + connector_msg = from_text((const char*)connector_msg.get_content(), connector_msg.get_content_length()); } // if none or invalid, we are complete - if ( connector_msg.get_length() == 0 ) + if ( connector_msg.get_content_length() == 0 ) break; if ( receive_handler ) { SCMessage* msg = new SCMessage(this, connector_receive, std::move(connector_msg)); - msg->content = const_cast(msg->cmsg.get_data()); - msg->content_length = msg->cmsg.get_length(); + msg->content = msg->cmsg.get_content(); + msg->content_length = msg->cmsg.get_content_length(); // if the message is longer than the header, assume we have a header - if ( msg->cmsg.get_length() > sizeof(SCMsgHdr) ) + if ( msg->cmsg.get_content_length() > sizeof(SCMsgHdr) ) { msg->content += sizeof(SCMsgHdr); msg->content_length -= sizeof( SCMsgHdr ); @@ -241,6 +241,9 @@ SCMsgHdr SideChannel::get_header() SCMessage* SideChannel::alloc_transmit_message(uint32_t content_length) { + if (!connector_transmit) + return nullptr; + SCMessage* msg = nullptr; const SCMsgHdr sc_hdr = get_header(); @@ -248,14 +251,11 @@ SCMessage* SideChannel::alloc_transmit_message(uint32_t content_length) { case ScMsgFormat::BINARY: { - uint8_t* msg_data = new uint8_t[sizeof(SCMsgHdr) + content_length]; - - memcpy(msg_data, &sc_hdr, sizeof(SCMsgHdr)); + ConnectorMsg connector_message = connector_transmit->allocate_connector_message(sizeof(SCMsgHdr) + content_length); + memcpy(connector_message.get_content(), &sc_hdr, sizeof(SCMsgHdr)); - ConnectorMsg bin_cmsg(msg_data, sizeof(SCMsgHdr) + content_length, true); - - msg = new SCMessage(this, connector_transmit, std::move(bin_cmsg)); - msg->content = msg_data + sizeof(SCMsgHdr); + msg = new SCMessage(this, connector_transmit, std::move(connector_message)); + msg->content = msg->cmsg.get_content() + sizeof(SCMsgHdr); msg->content_length = content_length; break; @@ -269,14 +269,12 @@ SCMessage* SideChannel::alloc_transmit_message(uint32_t content_length) break; const uint32_t msg_len = hdr_text.size() + (content_length * TXT_UNIT_LEN); - uint8_t* msg_data = new uint8_t[msg_len]; - - memcpy(msg_data, hdr_text.c_str(), hdr_text.size()); + ConnectorMsg connector_message = connector_transmit->allocate_connector_message(msg_len); - ConnectorMsg text_cmsg(msg_data, msg_len, true); + memcpy(connector_message.get_content(), hdr_text.c_str(), hdr_text.size()); - msg = new SCMessage(this, connector_transmit, std::move(text_cmsg)); - msg->content = msg_data + hdr_text.size(); + msg = new SCMessage(this, connector_transmit, std::move(connector_message)); + msg->content = msg->cmsg.get_content() + hdr_text.size(); msg->content_length = content_length; break;