]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4054: host_cache: added segmented cache
authorRaza Shafiq (rshafiq) <rshafiq@cisco.com>
Mon, 16 Oct 2023 22:12:20 +0000 (22:12 +0000)
committerSteve Chew (stechew) <stechew@cisco.com>
Mon, 16 Oct 2023 22:12:20 +0000 (22:12 +0000)
Merge in SNORT/snort3 from ~RSHAFIQ/snort3:segmented_atr_cache to master

Squashed commit of the following:

commit d5e597e210b8c9a8c1d8e3dad6d675ecd9c5bcda
Author: rshafiq <rshafiq@cisco.com>
Date:   Wed Oct 11 19:15:09 2023 -0400

    host_cache: added segmented cache

src/hash/CMakeLists.txt
src/hash/dev_notes.txt
src/hash/lru_segmented_cache_shared.h [new file with mode: 0644]
src/hash/test/CMakeLists.txt
src/hash/test/lru_seg_cache_shared_test.cc [new file with mode: 0644]
src/hash/xhash.h
src/main/modules.cc
src/main/snort_config.h
src/target_based/host_attributes.cc

index 72605f00514a153199f71354e74b15c16e3935ba..72e3139d7462e24f4545c21755c744c2ae13ee31 100644 (file)
@@ -6,6 +6,7 @@ set (HASH_INCLUDES
     hash_key_operations.h
     lru_cache_local.h
     lru_cache_shared.h
+    lru_segmented_cache_shared.h
     xhash.h
 )
 
index e4cff8b75bf661aec3cbb0798b857c3fdbee43e8..ef57c4ed4f13c561cf0d6666ab5ac1e2ef3d23a1 100644 (file)
@@ -31,4 +31,17 @@ other data. The utilization of this feature is optional.
 During initialization, the number of LRUs to be created can be specified. 
 If not specified, a single LRU will be created by default.
 
-
+Segmented Shared LRU Cache
+The SegmentedLruCache class is a layer built atop the existing 
+LruCacheShared class, designed to mitigate bottlenecks in 
+multi-threaded environments, thereby bolstering scalability. 
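+Without altering the core caching logic, it divides the cache 
+into multiple segments, defaulting to four; the segment count must be 
+a power of two. Each key maps to exactly one segment, so threads working 
+on keys in different segments operate on separate LruCacheShared instances, 
+which reduces contention among threads and improves performance. 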
+The segmented approach is generic and configurable, enabling easy 
+adaptation for different modules while preserving the fundamental 
+LRU cache behavior. Through this strategic modification, 
+the pathway for enhanced scalability and future advancements is 
+significantly broadened, making the caching mechanism more robust 
+and adaptable to evolving computational demands.
+See host_attributes.cc for example usage.
\ No newline at end of file
diff --git a/src/hash/lru_segmented_cache_shared.h b/src/hash/lru_segmented_cache_shared.h
new file mode 100644 (file)
index 0000000..c6d4f96
--- /dev/null
@@ -0,0 +1,197 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2016-2023 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// lru_segmented_cache_shared.h author Raza Shafiq <rshafiq@cisco.com>
+
+#ifndef LRU_SEGMENTED_CACHE_SHARED_H
+#define LRU_SEGMENTED_CACHE_SHARED_H
+
+#include <cassert>
+#include <vector>
+
+#include "lru_cache_shared.h"
+
+#define DEFAULT_SEGMENT_COUNT 4
+
+template<typename Key, typename Value, typename Hash = std::hash<Key>, typename Eq = std::equal_to<Key>>
+class SegmentedLruCache 
+{
+public:
+
+    using LruCacheType = LruCacheShared<Key, Value, Hash, Eq>;
+    using Data = typename LruCacheType::Data;
+
+    SegmentedLruCache(const size_t initial_size, std::size_t segment_count = DEFAULT_SEGMENT_COUNT)
+        :segment_count(segment_count)
+    {
+        assert( segment_count > 0 && ( segment_count & (segment_count - 1)) == 0 );
+
+        segments.resize(segment_count);
+        for( auto& segment : segments )
+            segment = std::make_unique<LruCacheType>(initial_size/segment_count);
+
+        assert( segment_count == segments.size() );
+    }
+
+    virtual ~SegmentedLruCache() = default;
+
+    Data find(const Key& key)
+    {
+        std::size_t segment_idx = get_segment_idx(key);
+        return segments[segment_idx]->find(key);
+    }
+
+    Data operator[](const Key& key)
+    {
+        std::size_t segment_idx = get_segment_idx(key);
+        return (*segments[segment_idx])[key];
+    }
+
+    bool remove(const Key& key)
+    {
+        std::size_t segment_idx = get_segment_idx(key);
+        return segments[segment_idx]->remove(key);
+    }
+    
+    bool remove(const Key& key, Data& data)
+    {
+        std::size_t idx = get_segment_idx(key); 
+        return segments[idx]->remove(key, data);
+    }
+
+    Data find_else_create(const Key& key, bool* new_data)
+    {
+        std::size_t segment_idx = get_segment_idx(key);
+        return segments[segment_idx]->find_else_create(key, new_data);
+    }
+
+    bool find_else_insert(const Key& key, std::shared_ptr<Value>& data, bool replace = false)
+    {
+        std::size_t segment_idx = get_segment_idx(key);
+        return segments[segment_idx]->find_else_insert(key, data, replace);
+    }
+
+    std::shared_ptr<Value> find_else_insert(const Key& key, std::shared_ptr<Value>& data, LcsInsertStatus* status, bool replace = false)
+    {
+        std::size_t segment_idx = get_segment_idx(key);
+        return segments[segment_idx]->find_else_insert(key, data, status, replace);
+    }
+    
+    bool set_max_size(size_t max_size)
+    {
+        bool success = true;
+        size_t memcap_per_segment = max_size / segment_count;
+        for ( const auto& segment : segments )
+        {
+            if ( !segment->set_max_size(memcap_per_segment) )
+                success = false;
+        }
+        return success;
+    }
+
+    std::vector<std::pair<Key, std::shared_ptr<Value>>> get_all_data()
+    {
+        std::vector<std::pair<Key, std::shared_ptr<Value>>> all_data;
+
+        for ( const auto& cache : segments )
+        {
+            auto cache_data = cache->get_all_data();
+            all_data.insert(all_data.end(), cache_data.begin(), cache_data.end());
+        }
+        return all_data;
+    }
+
+    size_t mem_size()
+    {
+        size_t mem_size = 0;
+        for ( const auto& cache : segments )
+        {
+            mem_size += cache->mem_size();
+        }
+        return mem_size;
+    }
+
+    const PegInfo* get_pegs() 
+    { 
+        return lru_cache_shared_peg_names; 
+    }
+
+    PegCount* get_counts() 
+    {
+        PegCount* pcs = (PegCount*)&counts;
+        const PegInfo* pegs = get_pegs();
+
+        for ( int i = 0; pegs[i].type != CountType::END; i++ )
+        {
+            PegCount c = 0;
+            for ( const auto& cache : segments )
+            {
+                c += cache->get_counts()[i];
+            }
+            pcs[i] = c;
+        }
+        return (PegCount*)&counts;
+    }
+
+    size_t size() 
+    {
+        size_t total_size = 0;
+        for ( const auto& cache : segments ) 
+        {
+            total_size += cache->size();
+        }
+        return total_size;
+    }
+
+    size_t get_max_size()
+    {
+        size_t max_size = 0;
+        for ( const auto& cache : segments )
+        {
+            max_size += cache->get_max_size();
+        }
+        return max_size;
+    }
+
+    size_t get_segment_count() const
+    {
+        return segment_count;
+    }
+
+protected:
+    std::size_t segment_count = DEFAULT_SEGMENT_COUNT;
+
+private:
+    std::vector<std::unique_ptr<LruCacheType>> segments;
+    struct LruCacheSharedStats counts;
+
+    //derived class can implement their own get_segment_idx if needed
+    virtual std::size_t get_segment_idx(Key val)
+    {
+        if ( segment_count == 1 )
+            return 0;
+        const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&val);
+        uint8_t result = 0;
+        for ( size_t i = 0; i < sizeof(Key); ++i )
+            result ^= bytes[i];
+        // Assumes segment_count is a power of 2 always
+        // This is a fast way to do a modulo operation
+        return result & (segment_count - 1);
+    }
+};
+
+#endif // LRU_SEGMENTED_CACHE_SHARED_H
index d140fc3d98ec7c723756e5e1f40720113b5e956f..caf17394438b07eb8cd34ce40ea4369e573dbe09 100644 (file)
@@ -6,6 +6,11 @@ add_cpputest( lru_cache_shared_test
     SOURCES ../lru_cache_shared.cc
 )
 
+add_cpputest( lru_seg_cache_shared_test
+    SOURCES ../lru_segmented_cache_shared.h
+            ../lru_cache_shared.cc
+)
+
 add_cpputest( hash_lru_cache_test
     SOURCES ../hash_lru_cache.cc
 )
diff --git a/src/hash/test/lru_seg_cache_shared_test.cc b/src/hash/test/lru_seg_cache_shared_test.cc
new file mode 100644 (file)
index 0000000..9779e9e
--- /dev/null
@@ -0,0 +1,223 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2016-2023 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// lru_seg_cache_shared_test.cc author Raza Shafiq <rshafiq@cisco.com>
+
+#include "hash/lru_segmented_cache_shared.h"
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+// Test group for the SegmentedLruCache tests below; it defines no shared
+// fixtures, setup or teardown.
+TEST_GROUP(segmented_lru_cache)
+{
+};
+
+// A freshly constructed cache reports the requested maximum size and holds
+// no entries.
+TEST(segmented_lru_cache, constructor_test)
+{
+    SegmentedLruCache<int, std::string> cache(8);
+
+    CHECK(8 == cache.get_max_size());
+    CHECK(0 == cache.size());
+}
+
+// Test SegmentedLruCache insert and find functions.
+TEST(segmented_lru_cache, insert_test)
+{
+    SegmentedLruCache<int, std::string> lru_cache(4);
+
+    auto data = lru_cache[0];
+    CHECK(data == lru_cache.find(0));
+    data->assign("zero");
+
+    data = lru_cache[1];
+    CHECK(data == lru_cache.find(1));
+    data->assign("one");
+
+    data = lru_cache[2];
+    CHECK(data == lru_cache.find(2));
+    data->assign("two");
+
+    data = lru_cache[0];
+    data->assign("new_zero");
+
+    CHECK(nullptr == lru_cache.find(3));
+
+    data = lru_cache[3];
+    CHECK(data == lru_cache.find(3));
+    data->assign("three");
+
+    const auto&& vec = lru_cache.get_all_data();
+    CHECK(vec.size() == 4);
+}
+
+// After one insert the reported memory footprint must be non-zero.
+TEST(segmented_lru_cache, mem_size_test)
+{
+    SegmentedLruCache<int, std::string> cache(5);
+    auto value = std::make_shared<std::string>("hello");
+    cache.find_else_insert(1, value);
+
+    CHECK(cache.mem_size() != 0);
+}
+
+// get_counts() must hand back a usable (non-null) counter array.
+TEST(segmented_lru_cache, get_counts_test)
+{
+    SegmentedLruCache<int, std::string> cache(5);
+    CHECK(cache.get_counts() != nullptr);
+}
+
+// size() counts the entries stored across all segments.
+TEST(segmented_lru_cache, size_test)
+{
+    SegmentedLruCache<int, std::string> cache(5);
+    auto value = std::make_shared<std::string>("hello");
+    cache.find_else_insert(1, value);
+
+    CHECK(cache.size() == 1);
+}
+
+// get_max_size() reflects the constructor argument and later set_max_size()
+// calls.
+TEST(segmented_lru_cache, max_size)
+{
+    SegmentedLruCache<int, std::string> cache(16);
+
+    const size_t initial_max = cache.get_max_size();
+    CHECK(16 == initial_max);
+
+    CHECK(true == cache.set_max_size(8));
+    CHECK(8 == cache.get_max_size());
+}
+
+// Test the remove functions.
+TEST(segmented_lru_cache, remove_test)
+{
+    SegmentedLruCache<int, std::string> lru_cache(4);
+
+    for ( int i = 0; i < 4; i++ )
+    {
+        std::shared_ptr<std::string> data(new std::string(std::to_string(i)));
+        CHECK(false == lru_cache.find_else_insert(i, data));
+        CHECK(true == lru_cache.find_else_insert(i, data));
+        CHECK(true == lru_cache.remove(i));
+        CHECK(lru_cache.find(i) == nullptr);
+    }
+
+    CHECK(0 == lru_cache.size());
+
+    std::shared_ptr<std::string> data_ptr;
+    std::shared_ptr<std::string> data(new std::string("one"));
+    lru_cache.find_else_insert(1,data);
+    CHECK(1 == lru_cache.size());
+    CHECK(true == lru_cache.remove(1, data_ptr));
+    CHECK(*data_ptr == "one");
+    CHECK(0 == lru_cache.size());
+}
+
+// find_else_insert() returns false when it inserts and true when the key is
+// already present; the entry count stays at one either way.
+TEST(segmented_lru_cache, find_else_insert)
+{
+    auto value = std::make_shared<std::string>("12345");
+    SegmentedLruCache<int, std::string> cache(8);
+
+    const bool found_first = cache.find_else_insert(1, value);
+    CHECK(false == found_first);
+    CHECK(1 == cache.size());
+
+    const bool found_second = cache.find_else_insert(1, value);
+    CHECK(true == found_second);
+    CHECK(1 == cache.size());
+}
+
+// Same insert/find checks as above, with an explicit segment count of 8.
+TEST (segmented_lru_cache, segements_8)
+{
+    SegmentedLruCache<int, std::string> cache(1024, 8);
+    auto value = std::make_shared<std::string>("12345");
+
+    const bool found_first = cache.find_else_insert(1, value);
+    CHECK(false == found_first);
+    CHECK(1 == cache.size());
+
+    const bool found_second = cache.find_else_insert(1, value);
+    CHECK(true == found_second);
+    CHECK(1 == cache.size());
+
+    CHECK(8 == cache.get_segment_count());
+}
+
+// Test statistics counters.
+TEST(segmented_lru_cache, stats_test)
+{
+    SegmentedLruCache<int, std::string> lru_cache(8);
+
+    for (int i = 0; i < 10; i++)
+        lru_cache[i];
+
+    lru_cache.find(7);
+    lru_cache.find(8);
+    lru_cache.find(9);
+    lru_cache.find(10);
+    lru_cache.find(11);
+
+    CHECK(lru_cache.set_max_size(16) == true);
+
+    lru_cache.remove(7);
+    lru_cache.remove(8);
+    lru_cache.remove(11); // not in cache
+
+    PegCount* stats = lru_cache.get_counts();
+
+    CHECK(stats[0] == 10);  // adds
+    CHECK(stats[1] == 2);   // alloc_prunes
+    CHECK(stats[2] == 0);   // bytes_in_use
+    CHECK(stats[3] == 0);   // items_in_use
+    CHECK(stats[4] == 3);   // find_hits
+    CHECK(stats[5] == 12);   // find_misses
+    CHECK(stats[6] == 0);   // reload_prunes
+    CHECK(stats[7] == 2);   // removes
+    CHECK(stats[8] == 0);   // replaced
+
+}
+
+// Test the find_else_insert method for item replacement
+TEST(segmented_lru_cache, find_else_insert_replace)
+{
+    SegmentedLruCache<int, std::string> lru_cache(8);
+    std::shared_ptr<std::string> data(new std::string("hello"));
+    LcsInsertStatus status;
+    
+    lru_cache.find_else_insert(1, data, &status, false);  // initial insert
+    CHECK(status == LcsInsertStatus::LCS_ITEM_INSERTED);  // Check status for initial insert
+    
+    std::shared_ptr<std::string> newData(new std::string("world"));
+    std::shared_ptr<std::string> returnedData = lru_cache.find_else_insert(1, newData, &status, true);  // replace existing item
+    CHECK(returnedData != nullptr);
+    CHECK(status == LcsInsertStatus::LCS_ITEM_REPLACED);  // Check status for item replacement
+    CHECK(*returnedData == "world");  // Check data for item replacement
+
+    returnedData = lru_cache.find_else_insert(1, newData, &status, false);  // attempt insert without replace flag
+    CHECK(returnedData != nullptr);
+    CHECK(status == LcsInsertStatus::LCS_ITEM_PRESENT);  // Check status for existing item
+    CHECK(*returnedData == "world");  // Data should remain unchanged
+}
+
+
+
+// Entry point: hands the command line to the CppUTest runner and returns its
+// result as the process exit code.
+int main(int argc, char** argv)
+{
+    const int exit_code = CommandLineTestRunner::RunAllTests(argc, argv);
+    return exit_code;
+}
index f77059b4ddc2b18ac6a04838e99562bb939c1017..7e057e23729bdc16cbd661551f0d2155e5fbc3e6 100644 (file)
@@ -60,7 +60,7 @@ public:
     void* get_user_data();
     void* get_user_data(const void* key, uint8_t type = 0);
     void release(uint8_t type = 0);
-    int release_node(const void* key, u_int8_t type = 0);
+    int release_node(const void* key, uint8_t type = 0);
     int release_node(HashNode* node, uint8_t type = 0);
     void* get_mru_user_data(uint8_t type = 0);
     void* get_lru_user_data(uint8_t type = 0);
index abb5633b2557de866ee58fbdb05560ffc499fdb2..9917e76f7eca125a0921c17129eefe45725c6436 100644 (file)
@@ -799,6 +799,9 @@ static const Parameter attribute_table_params[] =
     { "max_hosts", Parameter::PT_INT, "32:max53", "1024",
       "maximum number of hosts in attribute table" },
 
+    { "segments", Parameter::PT_INT, "1:32", "4",
+      "number of segments of hosts attribute table. It must be power of 2." },
+
     { "max_services_per_host", Parameter::PT_INT, "1:65535", "8",
       "maximum number of services per host entry in attribute table" },
 
@@ -830,6 +833,23 @@ bool AttributeTableModule::set(const char*, Value& v, SnortConfig* sc)
     else if ( v.is("max_hosts") )
         sc->max_attribute_hosts = v.get_uint32();
 
+    else if ( v.is("segments") )
+    {
+        auto segments = v.get_uint32();
+
+        if(segments > 32)
+            segments = 32;
+
+        if (segments == 0 || (segments & (segments - 1)) != 0)
+        {
+            uint8_t highestBitSet = 0;
+            while (segments >>= 1)
+                highestBitSet++;
+            segments = 1 << highestBitSet;
+            LogMessage("== WARNING: host attribute table segments count is not a power of 2. setting to %d\n", segments);
+        }
+        sc->segment_count_host = segments;
+    }
     else if ( v.is("max_services_per_host") )
         sc->max_attribute_services_per_host = v.get_uint16();
 
index 16017df686683bb23d0ca83b2bf43b4181cc8ae5..c348a2ffb32f5de018537b2a5a22faf0869d7b01 100644 (file)
@@ -286,6 +286,7 @@ public:
     uint32_t max_attribute_hosts = 0;
     uint32_t max_attribute_services_per_host = 0;
     uint32_t max_metadata_services = 0;
+    uint32_t segment_count_host = 4;
 
     //------------------------------------------------------
     // packet module stuff
@@ -515,7 +516,10 @@ public:
     { return run_flags & RUN_FLAG__READ; }
 
     bool ips_inline_mode() const
-    { return get_ips_policy()->policy_mode == POLICY_MODE__INLINE; }
+    {   
+        // cppcheck-suppress nullPointer
+        return get_ips_policy()->policy_mode == POLICY_MODE__INLINE; 
+    }
 
     bool ips_inline_test_mode() const
     { return get_ips_policy()->policy_mode == POLICY_MODE__INLINE_TEST; }
@@ -626,6 +630,9 @@ public:
     uint32_t get_max_attribute_hosts() const
     { return max_attribute_hosts; }
 
+    // Number of segments configured for the host attribute table cache
+    // (set from the attribute table "segments" parameter; defaults to 4).
+    uint32_t get_segment_count_host() const
+    { return segment_count_host; }
+
     uint32_t get_max_services_per_host() const
     { return max_attribute_services_per_host; }
 
index ab22aafdf781398bb24f831dcf370b856bc58318..6d2d4dcb7007254a12a44cf7e53e1b970327fc39 100644 (file)
@@ -25,7 +25,7 @@
 
 #include "host_attributes.h"
 
-#include "hash/lru_cache_shared.h"
+#include "hash/lru_segmented_cache_shared.h"
 #include "main/reload_tuner.h"
 #include "main/shell.h"
 #include "main/snort.h"
@@ -46,14 +46,16 @@ static const PegInfo host_attribute_pegs[] =
 };
 
 template<typename Key, typename Value, typename Hash>
-class HostLruSharedCache : public LruCacheShared<Key, Value, Hash>
+class HostLruSegmentedCache : public SegmentedLruCache<Key, Value, Hash> 
 {
 public:
-    HostLruSharedCache(const size_t initial_size) : LruCacheShared<Key, Value, Hash>(initial_size)
-    { }
+
+    HostLruSegmentedCache(const size_t initial_size, std::size_t seg_count = DEFAULT_SEGMENT_COUNT)
+        : SegmentedLruCache<Key, Value, Hash>(initial_size, seg_count)
+      { } 
 };
 
-typedef HostLruSharedCache<snort::SfIp, HostAttributesDescriptor, HostAttributesCacheKey> HostAttributesSharedCache;
+typedef HostLruSegmentedCache<snort::SfIp, HostAttributesDescriptor, HostAttributesCacheKey> HostAttributesSegmentedCache;
 
 class HostAttributesReloadTuner : public snort::ReloadResourceTuner
 {
@@ -73,10 +75,10 @@ public:
     { return true; }
 };
 
-static THREAD_LOCAL HostAttributesSharedCache* active_cache = nullptr;
-static HostAttributesSharedCache* swap_cache = nullptr;
-static HostAttributesSharedCache* next_cache = nullptr;
-static HostAttributesSharedCache* old_cache = nullptr;
+static THREAD_LOCAL HostAttributesSegmentedCache* active_cache = nullptr;
+static HostAttributesSegmentedCache* swap_cache = nullptr;
+static HostAttributesSegmentedCache* next_cache = nullptr;
+static HostAttributesSegmentedCache* old_cache = nullptr;
 static THREAD_LOCAL HostAttributeStats host_attribute_stats;
 
 bool HostAttributesDescriptor::update_service
@@ -140,7 +142,7 @@ void HostAttributesDescriptor::get_host_attributes(uint16_t port,HostAttriInfo*
 bool HostAttributesManager::load_hosts_file(snort::SnortConfig* sc, const char* fname)
 {
     delete next_cache;
-    next_cache = new HostAttributesSharedCache(sc->max_attribute_hosts);
+    next_cache = new HostAttributesSegmentedCache(sc->max_attribute_hosts, sc->segment_count_host);
 
     Shell sh(fname);
     if ( sh.configure(sc, true) )
@@ -157,7 +159,7 @@ bool HostAttributesManager::load_hosts_file(snort::SnortConfig* sc, const char*
 bool HostAttributesManager::add_host(HostAttributesEntry host, snort::SnortConfig* sc)
 {
     if ( !next_cache )
-        next_cache = new HostAttributesSharedCache(sc->max_attribute_hosts);
+        next_cache = new HostAttributesSegmentedCache(sc->max_attribute_hosts, sc->segment_count_host);
 
     return next_cache->find_else_insert(host->get_ip_addr(), host, true);
 }