From: Raza Shafiq (rshafiq) Date: Mon, 16 Oct 2023 22:12:20 +0000 (+0000) Subject: Pull request #4054: host_cache: added segmented cache X-Git-Tag: 3.1.73.0~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=b7550aa3e5f0f8479a74381089169f00e5a43cfd;p=thirdparty%2Fsnort3.git Pull request #4054: host_cache: added segmented cache Merge in SNORT/snort3 from ~RSHAFIQ/snort3:segmented_atr_cache to master Squashed commit of the following: commit d5e597e210b8c9a8c1d8e3dad6d675ecd9c5bcda Author: rshafiq Date: Wed Oct 11 19:15:09 2023 -0400 host_cache: added segmented cache --- diff --git a/src/hash/CMakeLists.txt b/src/hash/CMakeLists.txt index 72605f005..72e3139d7 100644 --- a/src/hash/CMakeLists.txt +++ b/src/hash/CMakeLists.txt @@ -6,6 +6,7 @@ set (HASH_INCLUDES hash_key_operations.h lru_cache_local.h lru_cache_shared.h + lru_segmented_cache_shared.h xhash.h ) diff --git a/src/hash/dev_notes.txt b/src/hash/dev_notes.txt index e4cff8b75..ef57c4ed4 100644 --- a/src/hash/dev_notes.txt +++ b/src/hash/dev_notes.txt @@ -31,4 +31,17 @@ other data. The utilization of this feature is optional. During initialization, the number of LRUs to be created can be specified. If not specified, a single LRU will be created by default. - +Segmented Shared LRU Cache +The SegmentedLruCache class is a layer built atop the existing +LruCacheShared class, designed to mitigate bottlenecks in +multi-threaded environments, thereby bolstering scalability. +Without altering the core caching logic, it divides the cache +into multiple segments, defaulting to four. This structure drastically +reduces contention among threads, allowing for improved performance. +The segmented approach is generic and configurable, enabling easy +adaptation for different modules while preserving the fundamental +LRU cache behavior. Through this strategic modification, +the pathway for enhanced scalability and future advancements is +significantly broadened, making the caching mechanism more robust +and adaptable to evolving computational demands. +check host_attributes.cc for example usage. \ No newline at end of file diff --git a/src/hash/lru_segmented_cache_shared.h b/src/hash/lru_segmented_cache_shared.h new file mode 100644 index 000000000..c6d4f96ae --- /dev/null +++ b/src/hash/lru_segmented_cache_shared.h @@ -0,0 +1,197 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2016-2023 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// lru_segmented_cache_shared.h author Raza Shafiq + +#ifndef LRU_SEGMENTED_CACHE_SHARED_H +#define LRU_SEGMENTED_CACHE_SHARED_H + +#include +#include + +#include "lru_cache_shared.h" + +#define DEFAULT_SEGMENT_COUNT 4 + +template, typename Eq = std::equal_to> +class SegmentedLruCache +{ +public: + + using LruCacheType = LruCacheShared; + using Data = typename LruCacheType::Data; + + SegmentedLruCache(const size_t initial_size, std::size_t segment_count = DEFAULT_SEGMENT_COUNT) + :segment_count(segment_count) + { + assert( segment_count > 0 && ( segment_count & (segment_count - 1)) == 0 ); + + segments.resize(segment_count); + for( auto& segment : segments ) + segment = std::make_unique(initial_size/segment_count); + + assert( segment_count == segments.size() ); + } + + virtual ~SegmentedLruCache() = default; + + Data find(const Key& key) + { + std::size_t segment_idx = get_segment_idx(key); + return segments[segment_idx]->find(key); + } + + Data operator[](const Key& key) + { + std::size_t segment_idx = get_segment_idx(key); + return (*segments[segment_idx])[key]; + } + + bool remove(const Key& key) + { + std::size_t segment_idx = get_segment_idx(key); + return segments[segment_idx]->remove(key); + } + + bool remove(const Key& key, Data& data) + { + std::size_t idx = get_segment_idx(key); + return segments[idx]->remove(key, data); + } + + Data find_else_create(const Key& key, bool* new_data) + { + std::size_t segment_idx = get_segment_idx(key); + return segments[segment_idx]->find_else_create(key, new_data); + } + + bool find_else_insert(const Key& key, std::shared_ptr& data, bool replace = false) + { + std::size_t segment_idx = get_segment_idx(key); + return segments[segment_idx]->find_else_insert(key, data, replace); + } + + std::shared_ptr find_else_insert(const Key& key, std::shared_ptr& data, LcsInsertStatus* status, bool replace = false) + { + std::size_t segment_idx = get_segment_idx(key); + return segments[segment_idx]->find_else_insert(key, data, status, replace); + } + + bool set_max_size(size_t max_size) + { + bool success = true; + size_t memcap_per_segment = max_size / segment_count; + for ( const auto& segment : segments ) + { + if ( !segment->set_max_size(memcap_per_segment) ) + success = false; + } + return success; + } + + std::vector>> get_all_data() + { + std::vector>> all_data; + + for ( const auto& cache : segments ) + { + auto cache_data = cache->get_all_data(); + all_data.insert(all_data.end(), cache_data.begin(), cache_data.end()); + } + return all_data; + } + + size_t mem_size() + { + size_t mem_size = 0; + for ( const auto& cache : segments ) + { + mem_size += cache->mem_size(); + } + return mem_size; + } + + const PegInfo* get_pegs() + { + return lru_cache_shared_peg_names; + } + + PegCount* get_counts() + { + PegCount* pcs = (PegCount*)&counts; + const PegInfo* pegs = get_pegs(); + + for ( int i = 0; pegs[i].type != CountType::END; i++ ) + { + PegCount c = 0; + for ( const auto& cache : segments ) + { + c += cache->get_counts()[i]; + } + pcs[i] = c; + } + return (PegCount*)&counts; + } + + size_t size() + { + size_t total_size = 0; + for ( const auto& cache : segments ) + { + total_size += cache->size(); + } + return total_size; + } + + size_t get_max_size() + { + size_t max_size = 0; + for ( const auto& cache : segments ) + { + max_size += cache->get_max_size(); + } + return max_size; + } + + size_t get_segment_count() const + { + return segment_count; + } + +protected: + std::size_t segment_count = DEFAULT_SEGMENT_COUNT; + +private: + std::vector> segments; + struct LruCacheSharedStats counts; + + //derived class can implement their own get_segment_idx if needed + virtual std::size_t get_segment_idx(Key val) + { + if ( segment_count == 1 ) + return 0; + const uint8_t* bytes = reinterpret_cast(&val); + uint8_t result = 0; + for ( size_t i = 0; i < sizeof(Key); ++i ) + result ^= bytes[i]; + // Assumes segment_count is a power of 2 always + // This is a fast way to do a modulo operation + return result & (segment_count - 1); + } +}; + +#endif // LRU_SEGMENTED_CACHE_SHARED_H diff --git a/src/hash/test/CMakeLists.txt b/src/hash/test/CMakeLists.txt index d140fc3d9..caf173944 100644 --- a/src/hash/test/CMakeLists.txt +++ b/src/hash/test/CMakeLists.txt @@ -6,6 +6,11 @@ add_cpputest( lru_cache_shared_test SOURCES ../lru_cache_shared.cc ) +add_cpputest( lru_seg_cache_shared_test + SOURCES ../lru_segmented_cache_shared.h + ../lru_cache_shared.cc +) + add_cpputest( hash_lru_cache_test SOURCES ../hash_lru_cache.cc ) diff --git a/src/hash/test/lru_seg_cache_shared_test.cc b/src/hash/test/lru_seg_cache_shared_test.cc new file mode 100644 index 000000000..9779e9e07 --- /dev/null +++ b/src/hash/test/lru_seg_cache_shared_test.cc @@ -0,0 +1,223 @@ +//-------------------------------------------------------------------------- +// Copyright (C) 2016-2023 Cisco and/or its affiliates. All rights reserved. +// +// This program is free software; you can redistribute it and/or modify it +// under the terms of the GNU General Public License Version 2 as published +// by the Free Software Foundation. You may not use, modify or distribute +// this program under any other version of the GNU General Public License. +// +// This program is distributed in the hope that it will be useful, but +// WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +// General Public License for more details. +// +// You should have received a copy of the GNU General Public License along +// with this program; if not, write to the Free Software Foundation, Inc., +// 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. +//-------------------------------------------------------------------------- +// lru_seg_cache_shared_test.cc author Raza Shafiq + +#include "hash/lru_segmented_cache_shared.h" +#include +#include + +TEST_GROUP(segmented_lru_cache) +{ +}; + +// Test SegmentedLruCache constructor and member access. +TEST(segmented_lru_cache, constructor_test) +{ + SegmentedLruCache lru_cache(8); + + CHECK(lru_cache.get_max_size() == 8); + CHECK(lru_cache.size() == 0); +} + +// Test SegmentedLruCache insert and find functions. +TEST(segmented_lru_cache, insert_test) +{ + SegmentedLruCache lru_cache(4); + + auto data = lru_cache[0]; + CHECK(data == lru_cache.find(0)); + data->assign("zero"); + + data = lru_cache[1]; + CHECK(data == lru_cache.find(1)); + data->assign("one"); + + data = lru_cache[2]; + CHECK(data == lru_cache.find(2)); + data->assign("two"); + + data = lru_cache[0]; + data->assign("new_zero"); + + CHECK(nullptr == lru_cache.find(3)); + + data = lru_cache[3]; + CHECK(data == lru_cache.find(3)); + data->assign("three"); + + const auto&& vec = lru_cache.get_all_data(); + CHECK(vec.size() == 4); +} + +// Test mem_size function +TEST(segmented_lru_cache, mem_size_test) +{ + SegmentedLruCache cache(5); + std::shared_ptr data(new std::string("hello")); + cache.find_else_insert(1, data); + + // Assuming each segment's metadata takes some space, hence checking for non-zero size + CHECK(0 != cache.mem_size()); +} + +// Test get_counts function +TEST(segmented_lru_cache, get_counts_test) +{ + SegmentedLruCache cache(5); + // Assuming get_counts function returns a non-null pointer + CHECK(nullptr != cache.get_counts()); +} + +// Test size function +TEST(segmented_lru_cache, size_test) +{ + SegmentedLruCache cache(5); + std::shared_ptr data(new std::string("hello")); + cache.find_else_insert(1, data); + + CHECK(1 == cache.size()); +} + +// Test set/get max size. +TEST(segmented_lru_cache, max_size) +{ + SegmentedLruCache lru_cache(16); + + size_t sz = lru_cache.get_max_size(); + CHECK(sz == 16); + + CHECK(lru_cache.set_max_size(8) == true); + CHECK(lru_cache.get_max_size() == 8); +} + +// Test the remove functions. +TEST(segmented_lru_cache, remove_test) +{ + SegmentedLruCache lru_cache(4); + + for ( int i = 0; i < 4; i++ ) + { + std::shared_ptr data(new std::string(std::to_string(i))); + CHECK(false == lru_cache.find_else_insert(i, data)); + CHECK(true == lru_cache.find_else_insert(i, data)); + CHECK(true == lru_cache.remove(i)); + CHECK(lru_cache.find(i) == nullptr); + } + + CHECK(0 == lru_cache.size()); + + std::shared_ptr data_ptr; + std::shared_ptr data(new std::string("one")); + lru_cache.find_else_insert(1,data); + CHECK(1 == lru_cache.size()); + CHECK(true == lru_cache.remove(1, data_ptr)); + CHECK(*data_ptr == "one"); + CHECK(0 == lru_cache.size()); +} + +// Test the find_else_insert function. +TEST(segmented_lru_cache, find_else_insert) +{ + std::shared_ptr data(new std::string("12345")); + SegmentedLruCache lru_cache(8); + + CHECK(false == lru_cache.find_else_insert(1, data)); + CHECK(1 == lru_cache.size()); + + CHECK(true == lru_cache.find_else_insert(1, data)); + CHECK(1 == lru_cache.size()); +} + +// Test 8 segments +TEST (segmented_lru_cache, segements_8) +{ + SegmentedLruCache lru_cache(1024,8); + std::shared_ptr data(new std::string("12345")); + + CHECK(false == lru_cache.find_else_insert(1, data)); + CHECK(1 == lru_cache.size()); + + CHECK(true == lru_cache.find_else_insert(1, data)); + CHECK(1 == lru_cache.size()); + + CHECK (8 == lru_cache.get_segment_count()); + +} + +// Test statistics counters. +TEST(segmented_lru_cache, stats_test) +{ + SegmentedLruCache lru_cache(8); + + for (int i = 0; i < 10; i++) + lru_cache[i]; + + lru_cache.find(7); + lru_cache.find(8); + lru_cache.find(9); + lru_cache.find(10); + lru_cache.find(11); + + CHECK(lru_cache.set_max_size(16) == true); + + lru_cache.remove(7); + lru_cache.remove(8); + lru_cache.remove(11); // not in cache + + PegCount* stats = lru_cache.get_counts(); + + CHECK(stats[0] == 10); // adds + CHECK(stats[1] == 2); // alloc_prunes + CHECK(stats[2] == 0); // bytes_in_use + CHECK(stats[3] == 0); // items_in_use + CHECK(stats[4] == 3); // find_hits + CHECK(stats[5] == 12); // find_misses + CHECK(stats[6] == 0); // reload_prunes + CHECK(stats[7] == 2); // removes + CHECK(stats[8] == 0); // replaced + +} + +// Test the find_else_insert method for item replacement +TEST(segmented_lru_cache, find_else_insert_replace) +{ + SegmentedLruCache lru_cache(8); + std::shared_ptr data(new std::string("hello")); + LcsInsertStatus status; + + lru_cache.find_else_insert(1, data, &status, false); // initial insert + CHECK(status == LcsInsertStatus::LCS_ITEM_INSERTED); // Check status for initial insert + + std::shared_ptr newData(new std::string("world")); + std::shared_ptr returnedData = lru_cache.find_else_insert(1, newData, &status, true); // replace existing item + CHECK(returnedData != nullptr); + CHECK(status == LcsInsertStatus::LCS_ITEM_REPLACED); // Check status for item replacement + CHECK(*returnedData == "world"); // Check data for item replacement + + returnedData = lru_cache.find_else_insert(1, newData, &status, false); // attempt insert without replace flag + CHECK(returnedData != nullptr); + CHECK(status == LcsInsertStatus::LCS_ITEM_PRESENT); // Check status for existing item + CHECK(*returnedData == "world"); // Data should remain unchanged +} + + + +int main(int argc, char** argv) +{ + return CommandLineTestRunner::RunAllTests(argc, argv); +} diff --git a/src/hash/xhash.h b/src/hash/xhash.h index f77059b4d..7e057e237 100644 --- a/src/hash/xhash.h +++ b/src/hash/xhash.h @@ -60,7 +60,7 @@ public: void* get_user_data(); void* get_user_data(const void* key, uint8_t type = 0); void release(uint8_t type = 0); - int release_node(const void* key, u_int8_t type = 0); + int release_node(const void* key, uint8_t type = 0); int release_node(HashNode* node, uint8_t type = 0); void* get_mru_user_data(uint8_t type = 0); void* get_lru_user_data(uint8_t type = 0); diff --git a/src/main/modules.cc b/src/main/modules.cc index abb5633b2..9917e76f7 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -799,6 +799,9 @@ static const Parameter attribute_table_params[] = { "max_hosts", Parameter::PT_INT, "32:max53", "1024", "maximum number of hosts in attribute table" }, + { "segments", Parameter::PT_INT, "1:32", "4", + "number of segments of hosts attribute table. It must be power of 2." }, + { "max_services_per_host", Parameter::PT_INT, "1:65535", "8", "maximum number of services per host entry in attribute table" }, @@ -830,6 +833,23 @@ bool AttributeTableModule::set(const char*, Value& v, SnortConfig* sc) else if ( v.is("max_hosts") ) sc->max_attribute_hosts = v.get_uint32(); + else if ( v.is("segments") ) + { + auto segments = v.get_uint32(); + + if(segments > 32) + segments = 32; + + if (segments == 0 || (segments & (segments - 1)) != 0) + { + uint8_t highestBitSet = 0; + while (segments >>= 1) + highestBitSet++; + segments = 1 << highestBitSet; + LogMessage("== WARNING: host attribute table segments count is not a power of 2. setting to %d\n", segments); + } + sc->segment_count_host = segments; + } else if ( v.is("max_services_per_host") ) sc->max_attribute_services_per_host = v.get_uint16(); diff --git a/src/main/snort_config.h b/src/main/snort_config.h index 16017df68..c348a2ffb 100644 --- a/src/main/snort_config.h +++ b/src/main/snort_config.h @@ -286,6 +286,7 @@ public: uint32_t max_attribute_hosts = 0; uint32_t max_attribute_services_per_host = 0; uint32_t max_metadata_services = 0; + uint32_t segment_count_host = 4; //------------------------------------------------------ // packet module stuff @@ -515,7 +516,10 @@ public: { return run_flags & RUN_FLAG__READ; } bool ips_inline_mode() const - { return get_ips_policy()->policy_mode == POLICY_MODE__INLINE; } + { + // cppcheck-suppress nullPointer + return get_ips_policy()->policy_mode == POLICY_MODE__INLINE; + } bool ips_inline_test_mode() const { return get_ips_policy()->policy_mode == POLICY_MODE__INLINE_TEST; } @@ -626,6 +630,9 @@ public: uint32_t get_max_attribute_hosts() const { return max_attribute_hosts; } + uint32_t get_segment_count_host() const + { return segment_count_host; } + uint32_t get_max_services_per_host() const { return max_attribute_services_per_host; } diff --git a/src/target_based/host_attributes.cc b/src/target_based/host_attributes.cc index ab22aafdf..6d2d4dcb7 100644 --- a/src/target_based/host_attributes.cc +++ b/src/target_based/host_attributes.cc @@ -25,7 +25,7 @@ #include "host_attributes.h" -#include "hash/lru_cache_shared.h" +#include "hash/lru_segmented_cache_shared.h" #include "main/reload_tuner.h" #include "main/shell.h" #include "main/snort.h" @@ -46,14 +46,16 @@ static const PegInfo host_attribute_pegs[] = }; template -class HostLruSharedCache : public LruCacheShared +class HostLruSegmentedCache : public SegmentedLruCache { public: - HostLruSharedCache(const size_t initial_size) : LruCacheShared(initial_size) - { } + + HostLruSegmentedCache(const size_t initial_size, std::size_t seg_count = DEFAULT_SEGMENT_COUNT) + : SegmentedLruCache(initial_size, seg_count) + { } }; -typedef HostLruSharedCache HostAttributesSharedCache; +typedef HostLruSegmentedCache HostAttributesSegmentedCache; class HostAttributesReloadTuner : public snort::ReloadResourceTuner { @@ -73,10 +75,10 @@ public: { return true; } }; -static THREAD_LOCAL HostAttributesSharedCache* active_cache = nullptr; -static HostAttributesSharedCache* swap_cache = nullptr; -static HostAttributesSharedCache* next_cache = nullptr; -static HostAttributesSharedCache* old_cache = nullptr; +static THREAD_LOCAL HostAttributesSegmentedCache* active_cache = nullptr; +static HostAttributesSegmentedCache* swap_cache = nullptr; +static HostAttributesSegmentedCache* next_cache = nullptr; +static HostAttributesSegmentedCache* old_cache = nullptr; static THREAD_LOCAL HostAttributeStats host_attribute_stats; bool HostAttributesDescriptor::update_service @@ -140,7 +142,7 @@ void HostAttributesDescriptor::get_host_attributes(uint16_t port,HostAttriInfo* bool HostAttributesManager::load_hosts_file(snort::SnortConfig* sc, const char* fname) { delete next_cache; - next_cache = new HostAttributesSharedCache(sc->max_attribute_hosts); + next_cache = new HostAttributesSegmentedCache(sc->max_attribute_hosts, sc->segment_count_host); Shell sh(fname); if ( sh.configure(sc, true) ) @@ -157,7 +159,7 @@ bool HostAttributesManager::load_hosts_file(snort::SnortConfig* sc, const char* bool HostAttributesManager::add_host(HostAttributesEntry host, snort::SnortConfig* sc) { if ( !next_cache ) - next_cache = new HostAttributesSharedCache(sc->max_attribute_hosts); + next_cache = new HostAttributesSegmentedCache(sc->max_attribute_hosts, sc->segment_count_host); return next_cache->find_else_insert(host->get_ip_addr(), host, true); }