git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3942: host_cache: segmented host cache
author Raza Shafiq (rshafiq) <rshafiq@cisco.com>
Fri, 1 Sep 2023 20:02:20 +0000 (20:02 +0000)
committer Steven Baigal (sbaigal) <sbaigal@cisco.com>
Fri, 1 Sep 2023 20:02:20 +0000 (20:02 +0000)
Merge in SNORT/snort3 from ~RSHAFIQ/snort3:host_cache_locking to master

Squashed commit of the following:

commit e642b5dcfbc6a48be841676c6a9e77f2a8788dd3
Author: rshafiq <rshafiq@cisco.com>
Date:   Thu Jul 27 08:43:35 2023 -0400

    host_cache: added segmented host cache

27 files changed:
src/host_tracker/CMakeLists.txt
src/host_tracker/cache_allocator.cc
src/host_tracker/cache_allocator.h
src/host_tracker/dev_notes.txt
src/host_tracker/host_cache.cc
src/host_tracker/host_cache.h
src/host_tracker/host_cache_module.cc
src/host_tracker/host_cache_module.h
src/host_tracker/host_cache_segmented.h [new file with mode: 0644]
src/host_tracker/host_tracker.cc
src/host_tracker/host_tracker.h
src/host_tracker/host_tracker_module.cc
src/host_tracker/host_tracker_module.h
src/host_tracker/test/CMakeLists.txt
src/host_tracker/test/cache_allocator_test.cc
src/host_tracker/test/host_cache_allocator_ht_test.cc
src/host_tracker/test/host_cache_module_test.cc
src/host_tracker/test/host_cache_segmented_test.cc [new file with mode: 0644]
src/host_tracker/test/host_tracker_module_test.cc
src/host_tracker/test/host_tracker_test.cc
src/main/snort.cc
src/main/snort_config.cc
src/network_inspectors/appid/appid_discovery.cc
src/network_inspectors/appid/appid_module.cc
src/network_inspectors/appid/test/appid_discovery_test.cc
src/network_inspectors/rna/rna_flow.cc
src/network_inspectors/rna/rna_module.cc

index 56dfa68067335dd265edd11e15e17202d18753fd..8625723b52015d0e073bd49380861ce61823894e 100644 (file)
@@ -2,6 +2,7 @@ set (HOST_TRACKER_INCLUDES
     cache_allocator.h
     cache_interface.h
     host_cache.h
+    host_cache_segmented.h
     host_tracker.h
 )
 
@@ -11,6 +12,7 @@ add_library( host_tracker OBJECT
     host_cache.cc
     host_cache_module.cc
     host_cache_module.h
+    host_cache_segmented.h
     host_tracker_module.cc
     host_tracker_module.h
     host_tracker.cc
index f3e31fee594660be75e70386bbf51875ea0246e5..2f046532bd35ecef65f2b5624f0df3e63ee45acb 100644 (file)
 #define CACHE_ALLOCATOR_CC
 
 #include "host_cache.h"
+#include "host_cache_segmented.h"
 
 template <class T>
 HostCacheAllocIp<T>::HostCacheAllocIp()
 {
-    lru = &host_cache;
+    lru = &default_host_cache;
 }
 
 #endif
index cc8f5410c5a143b80ea5c28f075273ce8a754ea4..ae392ad4d1477c6dbdb2a55bf74a3a7e9043be31 100644 (file)
@@ -38,6 +38,8 @@ public:
 
     T* allocate(std::size_t n);
     void deallocate(T* p, std::size_t n) noexcept;
+    void set_lru(CacheInterface* c) { lru = c; }
+    CacheInterface* get_lru() const { return lru; }
 
 protected:
 
@@ -73,6 +75,7 @@ class HostCacheAllocIp : public CacheAlloc<T>
 {
 public:
 
+    using Base = CacheAlloc<T>;
     // This needs to be in every derived class:
     template <class U>
     struct rebind
@@ -82,6 +85,21 @@ public:
 
     using CacheAlloc<T>::lru;
 
+    void set_cache(CacheInterface* hci) { Base::set_lru(hci); }
+    CacheInterface* get_cache_ptr() { return Base::get_lru(); }
+
+    template <class U>
+    HostCacheAllocIp(const HostCacheAllocIp<U>& other) 
+    {
+        this->lru = other.get_lru();
+    }
+
+    template <class U>
+    HostCacheAllocIp(HostCacheAllocIp<U>&& other)  noexcept 
+    {
+        this->lru = other.get_lru();
+    }
+
     HostCacheAllocIp();
 
 };
index 1a9a85361b48d9fd12a7686851ea525bf610da1c..5f659bf74e04ae408fb485dfa9a4a3022ecdf85d 100644 (file)
@@ -131,3 +131,53 @@ host_tracker.h and host_tracker.cc.
 Illustrative examples are test/cache_allocator_test.cc (standalone
 host cache / allocator example) and test/host_cache_allocator_ht_test.cc
 (host_cache / allocator with host tracker example).
+
+13/08/2023
+
+To address the issue of contention due to mutex locks when Snort is configured
+to run a large number (over 100) of threads with a single host_cache, 
+we introduced a new layer: "host_cache_segmented". This layer operates on 
+multiple cache segments, thus significantly reducing the locking contention 
+that was previously observed.
+
+The segmented host cache is not a replacement but rather an enhancement layer 
+above the existing host_cache. With this architecture, there can be more than 
+one host_cache, now referred to as a "segment". Each segment functions 
+as an LRU cache, just like the previous singular host_cache. Importantly, 
+there has been no change in the LRU cache design or its logic.
+
+Whenever a new key-data pair is added to a segment, the pair's allocator must be 
+updated to point to that segment, so that peg counts and visibility metrics 
+remain accurate for that specific segment. The find_else_create method of the 
+segmented cache takes care of this, ensuring that each key-data pair is 
+correctly associated with its segment.
+
+Each of these cache segments can operate independently, allowing for more 
+efficient parallel processing. This not only reduces the time threads spend 
+waiting for locks but also better utilizes multi-core systems by allowing 
+simultaneous read and write operations in different cache segments.
+
+The number of segments and the memcap are both configurable, providing flexibility 
+for tuning based on the specific requirements of the deployment environment 
+and the workload. Furthermore, this segmented approach scales well with the 
+increase in the number of threads, making it a robust solution for high-performance, 
+multi-threaded environments.
+
+In summary, the introduction of the "host_cache_segmented" layer represents 
+a significant step forward in the performance and scalability of Snort in 
+multi-threaded environments. This enhancement not only provides immediate benefits 
+in terms of improved throughput but also paves the way for further performance 
+optimizations in the future.
+                         +-----------------+
+                         |   Snort Threads |
+                         +-----------------+
+                                 |
+                                 v
+                    +-------------------------------+
+                    | Host Cache Segmented Layer    |
+                    +-------------------------------+
+                                 |
+                                 v
+            +-------------------------------------------------+
+            | Cache Segment 1 | Cache Segment 2 |   ...       |
+            +-------------------------------------------------+
\ No newline at end of file
index 61dcf509999f16eabd6a6f5352001210a2e10a97..e4e7f44773f2db7ad5e0dac31cf3956fbca00300 100644 (file)
 #endif
 
 #include "host_cache.h"
+#include "host_cache_segmented.h"
 
 using namespace snort;
 
-// Default host cache size in bytes.
-// Must agree with default memcap in host_cache_module.cc.
-#define LRU_CACHE_INITIAL_SIZE 16384 * 512
-
-HostCacheIp host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheIp default_host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheSegmentedIp host_cache;
index e68837c3ab6302d44abad94d6d6bd6fa82b25bb7..06f81f977a89c00aee92c1d3dc4aaa76b1131e3f 100644 (file)
@@ -36,6 +36,9 @@
 #include "cache_allocator.h"
 #include "cache_interface.h"
 
+// Default host cache size in bytes.
+#define LRU_CACHE_INITIAL_SIZE 8388608 // 8 MB
+
 // Used to create hash of key for indexing into cache.
 //
 // Note that both HashIp and IpEqualTo below ignore the IP family.
@@ -276,6 +279,5 @@ public:
     }
 };
 
-extern SO_PUBLIC HostCacheIp host_cache;
 
 #endif
index 1bcc2c2162b714e3695f726a4cb799eee5839968..f211fbd39327d128b737817d605a790c2f0b9df1 100644 (file)
@@ -32,6 +32,7 @@
 #include "log/messages.h"
 #include "managers/module_manager.h"
 #include "utils/util.h"
+#include "host_cache_segmented.h"
 
 using namespace snort;
 using namespace std;
@@ -62,6 +63,20 @@ static int host_cache_get_stats(lua_State* L)
     return 0;
 }
 
+static int host_cache_get_segment_stats(lua_State* L)
+{
+    HostCacheModule* mod = (HostCacheModule*) ModuleManager::get_module(HOST_CACHE_NAME);
+
+    if ( mod )
+    {
+        int seg_idx = luaL_optint(L, 1, -1);
+        ControlConn* ctrlcon = ControlConn::query_from_lua(L);
+        string outstr = mod->get_host_cache_segment_stats(seg_idx);
+        ctrlcon->respond("%s", outstr.c_str());
+    }
+    return 0;
+}
+
 static int host_cache_delete_host(lua_State* L)
 {
     HostCacheModule* mod = (HostCacheModule*) ModuleManager::get_module(HOST_CACHE_NAME);
@@ -274,6 +289,12 @@ static const Parameter host_cache_stats_params[] =
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
+static const Parameter host_cache_segment_stats_params[] =
+{
+    { "segment", Parameter::PT_INT, nullptr, nullptr, "segment number for stats" },
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
 static const Parameter host_cache_delete_host_params[] =
 {
     { "host_ip", Parameter::PT_STRING, nullptr, nullptr, "ip address to delete" },
@@ -324,6 +345,7 @@ static const Command host_cache_cmds[] =
     { "delete_client", host_cache_delete_client,
       host_cache_delete_client_params, "delete client from host"},
     { "get_stats", host_cache_get_stats, host_cache_stats_params, "get current host cache usage and pegs"},
+    { "get_segment_stats", host_cache_get_segment_stats, host_cache_segment_stats_params, "get usage and pegs for cache segment(s)"},
     { nullptr, nullptr, nullptr, nullptr }
 };
 
@@ -343,6 +365,9 @@ static const Parameter host_cache_params[] =
 
     { "memcap", Parameter::PT_INT, "512:maxSZ", "8388608",
       "maximum host cache size in bytes" },
+    
+    { "segments", Parameter::PT_INT, "1:32", "4",
+      "number of host cache segments. It must be power of 2."},
 
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
@@ -354,8 +379,25 @@ bool HostCacheModule::set(const char*, Value& v, SnortConfig*)
         dump_file = v.get_string();
     }
     else if ( v.is("memcap") )
+    {
         memcap = v.get_size();
+    }
+    else if ( v.is("segments"))
+    {
+        segments = v.get_uint8();
+        
+        if(segments > 32)
+            segments = 32;
 
+        if (segments == 0 || (segments & (segments - 1)) != 0)
+        {
+            uint8_t highestBitSet = 0;
+            while (segments >>= 1)
+                highestBitSet++;
+            segments = 1 << highestBitSet;
+            LogMessage("== WARNING: host_cache segments is not the power of 2. setting to %d\n", segments);
+        }
+    }
     return true;
 }
 
@@ -366,8 +408,8 @@ bool HostCacheModule::end(const char* fqn, int, SnortConfig* sc)
         if ( Snort::is_reloading() )
             sc->register_reload_handler(new HostCacheReloadTuner(memcap));
         else
-        {
-            host_cache.set_max_size(memcap);
+        {   
+            host_cache.setup(segments, memcap);
             ControlConn::log_command("host_cache.delete_host",false);
         }
     }
@@ -438,20 +480,93 @@ void HostCacheModule::log_host_cache(const char* file_name, bool verbose)
 }
 
 
+string HostCacheModule::get_host_cache_segment_stats(int seg_idx)
+{
+
+    if(seg_idx >= host_cache.get_segments())
+        return "Invalid segment index\nTry host_cache.get_segment_stats() to get all stats\n";
+    
+    string str;
+    const PegInfo* pegs = host_cache.get_pegs();
+
+    if(seg_idx == -1)
+    {
+        const auto&& lru_data = host_cache.get_all_data();
+        str = "Total host cache size: " + to_string(host_cache.mem_size()) + " bytes, "
+            + to_string(lru_data.size()) + " trackers, memcap: " + to_string(host_cache.get_max_size())
+            + " bytes\n";
+
+        for(auto cache : host_cache.seg_list)
+        {
+            cache->lock();
+            cache->stats.bytes_in_use = cache->current_size;
+            cache->stats.items_in_use = cache->list.size();
+            cache->unlock();
+        }
+        
+        PegCount* counts = (PegCount*) host_cache.get_counts();
+        for ( int i = 0; pegs[i].type != CountType::END; i++ )
+        {
+            if ( counts[i] )
+            {
+                str += pegs[i].name;
+                str += ": " + to_string(counts[i]) + "\n" ;
+            }
+        }
+    }
+
+
+    str += "\n";
+    str += "total cache segments: " + to_string(host_cache.seg_list.size()) + "\n";
+    int idx = -1;
+    for( auto cache : host_cache.seg_list)
+    {
+        idx++;
+        if(seg_idx != -1 && seg_idx != idx)
+            continue;
+
+        str += "Segment " + to_string(idx) + ":\n";
+        const auto&& lru_data = cache->get_all_data();
+        str += "Current host cache size: " + to_string(cache->mem_size()) + " bytes, "
+            + to_string(lru_data.size()) + " trackers, memcap: " + to_string(cache->get_max_size())
+            + " bytes\n";
+
+        cache->lock();
+        cache->stats.bytes_in_use = cache->current_size;
+        cache->stats.items_in_use = cache->list.size();
+        cache->unlock();
+
+        PegCount* count = (PegCount*) cache->get_counts();
+        for ( int i = 0; pegs[i].type != CountType::END; i++ )
+        {
+            if ( count[i] )
+            {
+                str += pegs[i].name;
+                str += ": " + to_string(count[i]) + "\n" ;
+            }
+        }
+        str += "\n";
+    }
+    return str;
+}
+
 string HostCacheModule::get_host_cache_stats()
 {
     string str;
 
     const auto&& lru_data = host_cache.get_all_data();
     str = "Current host cache size: " + to_string(host_cache.mem_size()) + " bytes, "
-        + to_string(lru_data.size()) + " trackers, memcap: " + to_string(host_cache.max_size)
+        + to_string(lru_data.size()) + " trackers, memcap: " + to_string(host_cache.get_max_size())
         + " bytes\n";
 
-    host_cache.lock();
-
-    host_cache.stats.bytes_in_use = host_cache.current_size;
-    host_cache.stats.items_in_use = host_cache.list.size();
-
+    for(auto cache : host_cache.seg_list)
+    {
+        cache->lock();
+        cache->stats.bytes_in_use = cache->current_size;
+        cache->stats.items_in_use = cache->list.size();
+        cache->unlock();
+    }
+    
     PegCount* counts = (PegCount*) host_cache.get_counts();
     const PegInfo* pegs = host_cache.get_pegs();
 
@@ -465,7 +580,6 @@ string HostCacheModule::get_host_cache_stats()
 
     }
 
-    host_cache.unlock();
 
     return str;
 }
@@ -478,14 +592,17 @@ PegCount* HostCacheModule::get_counts() const
 
 void HostCacheModule::sum_stats(bool dump_stats)
 {
-    host_cache.lock();
     // These could be set in prep_counts but we set them here
     // to save an extra cache lock.
-    host_cache.stats.bytes_in_use = host_cache.current_size;
-    host_cache.stats.items_in_use = host_cache.list.size();
+    for(auto cache : host_cache.seg_list)
+    {
+        cache->lock();
+        cache->stats.bytes_in_use = cache->current_size;
+        cache->stats.items_in_use = cache->list.size();
+        cache->unlock();
+    }
 
     Module::sum_stats(dump_stats);
-    host_cache.unlock();
 }
 
 void HostCacheModule::set_trace(const Trace* trace) const
index 258f6b8f09b4cc39b6db25a6f6e600bc7ec1b051..09fd1d1fc8024a9aabf3e68338fd3e2d4eff0a25 100644 (file)
@@ -31,6 +31,7 @@
 #include "trace/trace_api.h"
 
 #include "host_cache.h"
+#include "host_cache_segmented.h"
 
 #define HOST_CACHE_NAME "host_cache"
 #define HOST_CACHE_HELP "global LRU cache of host_tracker data about hosts"
@@ -75,6 +76,7 @@ public:
 
     void log_host_cache(const char* file_name, bool verbose = false);
     std::string get_host_cache_stats();
+    std::string get_host_cache_segment_stats(int seg_idx);
 
     void set_trace(const snort::Trace*) const override;
     const snort::TraceOption* get_trace_options() const override;
@@ -82,6 +84,7 @@ public:
 private:
     std::string dump_file;
     size_t memcap = 0;
+    uint8_t segments = 1;
 };
 extern THREAD_LOCAL const snort::Trace* host_cache_trace;
 
diff --git a/src/host_tracker/host_cache_segmented.h b/src/host_tracker/host_cache_segmented.h
new file mode 100644 (file)
index 0000000..e417435
--- /dev/null
@@ -0,0 +1,381 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2015-2023 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// host_cache_segmented.h author Raza Shafiq <rshafiq@cisco.com>
+
+#ifndef HOST_CACHE_SEGMENTED_H
+#define HOST_CACHE_SEGMENTED_H
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <atomic>
+#include <cassert>
+
+#include "host_cache.h"
+#include "log/messages.h"
+
+#define DEFAULT_HOST_CACHE_SEGMENTS 4
+
+extern SO_PUBLIC HostCacheIp default_host_cache;
+
+template<typename Key, typename Value>
+class HostCacheSegmented
+{
+public:
+    HostCacheSegmented() : 
+        segment_count(DEFAULT_HOST_CACHE_SEGMENTS), 
+        memcap_per_segment(LRU_CACHE_INITIAL_SIZE) { }
+    HostCacheSegmented(uint8_t segment_count, size_t memcap_per_segment);
+
+    void init();
+    void term();
+    void setup(uint8_t , size_t );
+
+    const PegInfo* get_pegs() { return lru_cache_shared_peg_names; }
+    size_t get_memcap_per_segment() { return memcap_per_segment.load(); }
+    size_t get_valid_id(uint8_t idx);
+    uint8_t get_segments() { return segment_count; }
+    size_t get_max_size();
+    size_t get_mem_chunk();
+    PegCount* get_counts();
+
+    void set_segments(uint8_t segments) { segment_count = segments; }
+    void print_config();
+    bool set_max_size(size_t max_size);
+    bool reload_resize(size_t memcap_per_segment);
+    bool reload_prune(size_t new_size, unsigned max_prune);
+    void invalidate();
+
+    std::shared_ptr<Value> operator[](const Key& key);
+
+    uint8_t get_segment_idx(Key val);
+    std::shared_ptr<Value> find(const Key& key);
+    std::shared_ptr<Value> find_else_create(const Key& key, bool* new_data);
+    std::vector<std::pair<Key, std::shared_ptr<Value>>> get_all_data();
+    bool find_else_insert(const Key& key, std::shared_ptr<Value>& value);
+    bool remove(const Key& key);
+    bool remove(const Key& key, typename LruCacheSharedMemcap
+        <snort::SfIp, snort::HostTracker, HashIp, IpEqualTo, HTPurgatory>::Data& data);
+    size_t mem_size();
+
+    std::vector<HostCacheIp*> seg_list;
+    HostCacheIp* default_cache = &default_host_cache; // Default cache used for host tracker
+
+private:
+    void update_counts();
+
+    uint8_t segment_count;
+    std::atomic<size_t> memcap_per_segment;
+    struct LruCacheSharedStats counts;
+    bool init_done = false;
+};
+
+
+template<typename Key, typename Value>
+HostCacheSegmented<Key, Value>::HostCacheSegmented(uint8_t segment_count, size_t memcap_per_segment) :
+    segment_count(segment_count),
+    memcap_per_segment(memcap_per_segment)
+{
+    assert(segment_count > 0);
+    
+    for (size_t i = 0; i < this->segment_count; ++i)
+    {
+        auto cache = new HostCacheIp(this->memcap_per_segment);
+        seg_list.emplace_back((HostCacheIp*)cache);
+    }
+    init_done = true;
+}
+
+template<typename Key, typename Value>
+void HostCacheSegmented<Key, Value>::init()
+{
+    if(init_done or seg_list.size() >= segment_count)
+        return;
+
+    assert(segment_count > 0);
+    
+    for (size_t i = 0; i < segment_count; ++i)
+    {
+        auto cache = new HostCacheIp(memcap_per_segment.load());
+        seg_list.emplace_back((HostCacheIp*)cache);
+    }
+    init_done = true;
+}
+
+template<typename Key, typename Value>
+void HostCacheSegmented<Key, Value>::term()
+{
+    for (auto cache : seg_list)
+    {
+        if (cache)
+            delete cache;
+    }
+}
+
+template<typename Key, typename Value>
+void HostCacheSegmented<Key, Value>::setup(uint8_t segs, size_t memcap )
+{
+    assert(segment_count > 0);
+
+    segment_count = segs;
+    memcap_per_segment = memcap/segs;
+    set_max_size(memcap);
+}
+
+template<typename Key, typename Value>
+size_t HostCacheSegmented<Key, Value>::get_valid_id(uint8_t idx) 
+{ 
+    if(idx < seg_list.size())
+        return seg_list[idx]->get_valid_id();
+    return 0;
+}
+
+template<typename Key, typename Value>
+void HostCacheSegmented<Key, Value>::print_config() 
+{ 
+    if ( snort::SnortConfig::log_verbose() )
+    {
+        snort::LogLabel("host_cache");
+        snort::LogMessage("    memcap: %zu bytes\n", get_max_size());
+    }
+}
+
+template<typename Key, typename Value>
+std::shared_ptr<Value> HostCacheSegmented<Key, Value>::operator[](const Key& key)
+{
+    return find_else_create(key, nullptr);
+}
+
+/**
+ * Sets the maximum size for the entire cache, which is distributed equally
+ * among all the segments.
+ */
+template<typename Key, typename Value>
+bool HostCacheSegmented<Key, Value>::set_max_size(size_t max_size)
+{
+    bool success = true;
+    memcap_per_segment = max_size/segment_count;
+    for (auto cache : seg_list)
+    {
+        if (!cache->set_max_size(memcap_per_segment))
+            success = false;
+    }
+    return success;
+}
+
+/**
+ * Resize the cache based on the provided memory capacity, distributing the 
+ * memory equally among all the segments. If any segment fails to resize,
+ * the operation is considered unsuccessful.
+ */
+template<typename Key, typename Value>
+bool HostCacheSegmented<Key, Value>::reload_resize(size_t memcap)
+{
+    bool success = true;
+    memcap_per_segment = memcap/segment_count;
+    for (auto cache : seg_list)
+    {
+        if (!cache->reload_resize(memcap_per_segment.load()))
+            success = false;
+    }
+    return success;
+}
+
+// Computes the index of the segment where a given key-value pair belongs.
+template<typename Key, typename Value>
+uint8_t HostCacheSegmented<Key, Value>::get_segment_idx(Key val) 
+{
+    const uint8_t* bytes = reinterpret_cast<const uint8_t*>(&val);
+    uint8_t result = 0;
+    for (size_t i = 0; i < sizeof(Key); ++i) 
+        result ^= bytes[i];
+    //Assumes segment_count is a power of 2 always
+    //This is a fast way to do a modulo operation
+    return result & (segment_count - 1);
+}
+
+//Retrieves all the data stored across all the segments of the cache.
+template<typename Key, typename Value>
+std::vector<std::pair<Key, std::shared_ptr<Value>>>  HostCacheSegmented<Key,Value>::get_all_data()
+{
+    std::vector<std::pair<Key, std::shared_ptr<Value>>> all_data;
+
+    for (auto cache : seg_list)
+    {
+        auto cache_data = cache->get_all_data();
+        all_data.insert(all_data.end(), cache_data.begin(), cache_data.end());
+    }
+    return all_data;
+}
+
+template<typename Key, typename Value>
+std::shared_ptr<Value> HostCacheSegmented<Key, Value>::find(const Key& key)
+{
+    uint8_t idx = get_segment_idx(key);
+    return seg_list[idx]->find(key);
+}
+
+/**
+ * Updates the internal counts of the host cache. This method aggregates the 
+ * counts from all segments and updates the overall counts for the cache.
+ */
+template<typename Key, typename Value>
+void HostCacheSegmented<Key, Value>::update_counts()
+{
+    PegCount* pcs = (PegCount*)&counts;
+    const PegInfo* pegs = get_pegs();
+
+    for ( int i = 0; pegs[i].type != CountType::END; i++ )
+    {
+        PegCount c = 0;
+        for(auto cache : seg_list)
+        {
+            c += cache->get_counts()[i];
+        }
+        pcs[i] = c;
+    }
+}
+
+template<typename Key, typename Value>
+std::shared_ptr<Value> HostCacheSegmented<Key, Value>:: find_else_create(const Key& key, bool* new_data)
+{
+    // Determine the segment index where the key-value pair resides or should reside
+    uint8_t idx = get_segment_idx(key);
+    bool new_data_local = false;
+    
+    // Retrieve or create the entry for the key in the determined segment
+    auto ht = seg_list[idx]->find_else_create(key, &new_data_local);
+    if(new_data_local)
+    {
+        // If a new entry was created, update its cache interface and visibility
+        ht->update_cache_interface(idx);
+        ht->init_visibility(seg_list[idx]->get_valid_id());
+    }
+    if(new_data)
+        *new_data = new_data_local;
+    return ht;
+}
+
+template<typename Key, typename Value>
+bool HostCacheSegmented<Key, Value>::find_else_insert(const Key& key, std::shared_ptr<Value>& value)
+{
+    uint8_t idx = get_segment_idx(key);
+    return seg_list[idx]->find_else_insert(key, value, false);
+}
+
+template<typename Key, typename Value>
+PegCount* HostCacheSegmented<Key, Value>::get_counts()
+{
+    if(init_done)
+        update_counts();
+    return (PegCount*)&counts;
+}
+
+template<typename Key, typename Value>
+void HostCacheSegmented<Key, Value>::invalidate()
+{
+    for( auto cache: seg_list)
+    {
+        cache->invalidate();
+    }
+}
+
+template<typename Key, typename Value>
+bool HostCacheSegmented<Key, Value>::reload_prune(size_t new_size, unsigned max_prune)
+{
+    bool success = true;
+    memcap_per_segment = new_size/segment_count;
+    for (auto cache : seg_list)
+    {
+        if (!cache->reload_prune(memcap_per_segment, max_prune))
+            success = false;
+    }
+    return success;
+}
+
+template<typename Key, typename Value>
+size_t HostCacheSegmented<Key, Value>::mem_size()
+{
+    size_t mem_size = 0;
+    for (auto cache : seg_list)
+    {
+        if(cache)
+            mem_size += cache->mem_size();
+    }
+    return mem_size;
+}
+
+template<typename Key, typename Value>
+size_t HostCacheSegmented<Key, Value>::get_max_size()
+{
+    size_t max_size = 0;
+    for (auto cache : seg_list)
+    {
+        max_size += cache->get_max_size();
+    }
+    return max_size;
+}
+
+template<typename Key, typename Value>
+size_t HostCacheSegmented<Key, Value>::get_mem_chunk()
+{
+    //Assumes all segments have the same mem_chunk
+    return seg_list[0]->mem_chunk;
+}
+
+template<typename Key, typename Value>
+bool HostCacheSegmented<Key, Value>::remove(const Key& key) 
+{
+    uint8_t idx = get_segment_idx(key);
+    return seg_list[idx]->remove(key);
+}
+
+template<typename Key, typename Value>
+bool HostCacheSegmented<Key, Value>::remove(const Key& key, typename LruCacheSharedMemcap<snort::SfIp, snort::HostTracker, HashIp, IpEqualTo, HTPurgatory>::Data& data)
+{
+    uint8_t idx = get_segment_idx(key);
+    return seg_list[idx]->remove(key, data);
+}
+
+/*
+Warning!!!: update_allocator and update_set_allocator don't copy data to old container
+but erase it for speed. Use with care!!!
+*/
+template <template <typename, typename...> class Container, typename T, typename Alloc>
+void update_allocator(Container<T, Alloc>& cont, CacheInterface* new_lru) 
+{
+    Alloc new_allocator;
+    new_allocator.set_cache(new_lru);
+    cont = std::move(Container<T, Alloc>(new_allocator));
+} 
+
+template <template <typename, typename, typename...> class Container, typename T, typename Comp, typename Alloc>
+void update_set_allocator(Container<T, Comp, Alloc>& cont, CacheInterface* new_lru) 
+{
+    Alloc new_allocator;
+    new_allocator.set_cache(new_lru);
+    cont = std::move(Container<T, Comp, Alloc> (new_allocator)); 
+}
+
+
+typedef HostCacheSegmented<snort::SfIp, snort::HostTracker> HostCacheSegmentedIp;
+extern SO_PUBLIC HostCacheSegmentedIp host_cache;
+
+#endif // HOST_CACHE_SEGMENTED_H
+
index afa93b3ca40b76d9efa6f51a9b2ea0ee99447e07..d890b0fa705b7642152553509fa8824a119f8f4f 100644 (file)
@@ -30,6 +30,7 @@
 
 #include "cache_allocator.cc"
 #include "host_cache.h"
+#include "host_cache_segmented.h"
 #include "host_tracker.h"
 
 using namespace snort;
@@ -43,11 +44,12 @@ THREAD_LOCAL struct HostTrackerStats host_tracker_stats;
 const uint8_t snort::zero_mac[MAC_SIZE] = {0, 0, 0, 0, 0, 0};
 
 
+
 HostTracker::HostTracker() : hops(-1)
 {
     last_seen = nat_count_start = (uint32_t) packet_time();
     last_event = -1;
-    visibility = host_cache.get_valid_id();
+    visibility = host_cache.get_valid_id(0);
 }
 
 void HostTracker::update_last_seen()
@@ -816,10 +818,8 @@ bool HostTracker::add_cpe_os_hash(uint32_t hash)
 
 bool HostTracker::set_visibility(bool v)
 {
-    // get_valid_id may use its own lock, so get this outside our lock
-    size_t container_id = host_cache.get_valid_id();
-
     std::lock_guard<std::mutex> lck(host_tracker_lock);
+    size_t container_id = host_cache.get_valid_id(cache_idx);
     size_t old_visibility = visibility;
 
     visibility = v ? container_id : HostCacheIp::invalid_id;
@@ -869,7 +869,7 @@ bool HostTracker::set_visibility(bool v)
 bool HostTracker::is_visible() const
 {
     std::lock_guard<std::mutex> lck(host_tracker_lock);
-    return visibility == host_cache.get_valid_id();
+    return visibility == host_cache.get_valid_id(cache_idx);
 }
 
 
@@ -1108,6 +1108,28 @@ void HostTracker::remove_flows()
     flows.clear();
 }
 
+void HostTracker::update_cache_interface(uint8_t idx)
+{
+
+    if (idx == cache_idx and cache_interface == host_cache.seg_list[idx])
+        return;
+
+    std::lock_guard<std::mutex> lock(host_tracker_lock);
+    cache_idx = idx;
+    cache_interface = host_cache.seg_list[idx];
+    
+    update_allocator(macs, cache_interface);
+    update_allocator(network_protos, cache_interface);
+    update_allocator(xport_protos, cache_interface);
+    update_allocator(services, cache_interface);
+    update_allocator(clients, cache_interface);
+    update_allocator(ua_fps, cache_interface);
+    update_set_allocator(tcp_fpids, cache_interface);
+    update_set_allocator(udp_fpids, cache_interface);
+    update_set_allocator(smb_fpids, cache_interface);
+    update_set_allocator(cpe_fpids, cache_interface);
+}
+
 HostApplicationInfo::HostApplicationInfo(const char *ver, const char *ven)
 {
     if ( ver )
index 7ed8dc2316a3f67c26165b43f5b03c8e7c25ca97..f9d5754764a621b4bdb0f2fd64bdc916ab7ef25e 100644 (file)
@@ -385,9 +385,28 @@ public:
         return ++nat_count;
     }
 
+    void set_cache_idx(uint8_t idx) 
+    { 
+        std::lock_guard<std::mutex> lck(host_tracker_lock);
+        cache_idx = idx; 
+    }
+
+    void init_visibility(size_t v) 
+    {
+        std::lock_guard<std::mutex> lck(host_tracker_lock);
+        visibility = v;
+    }
+
+    uint8_t get_cache_idx() const
+    {
+        return cache_idx;
+    }
+
     bool set_netbios_name(const char*);
 
     bool set_visibility(bool v = true);
+    size_t get_visibility() const {return visibility;}
+
 
     bool is_visible() const;
 
@@ -418,6 +437,9 @@ public:
     void remove_flows();
     void remove_flow(RNAFlow*);
 
+    void update_cache_interface( uint8_t idx );
+    CacheInterface * get_cache_interface() { return cache_interface; }
+
 private:
 
     mutable std::mutex host_tracker_lock; // ensure that updates to a shared object are safe
@@ -450,10 +472,13 @@ private:
     uint32_t nat_count_start;     // the time nat counting starts for this host
 
     size_t visibility;
+    uint8_t cache_idx = 0; 
 
     uint32_t num_visible_services = 0;
     uint32_t num_visible_clients = 0;
     uint32_t num_visible_macs = 0;
+    
+    CacheInterface * cache_interface = nullptr;
 
     // These three do not lock independently; they are used by payload discovery and called
     // from add_payload(HostApplication&, Port, IpProtocol, AppId, AppId, size_t); where the
index eaa03bf1fb59e58e72dae93f17f63985fc2bc7e6..64e51ae053555793ea5c40d93221e110461aa738 100644 (file)
@@ -23,6 +23,7 @@
 #endif
 
 #include "host_tracker_module.h"
+#include "host_cache_segmented.h"
 
 #include "log/messages.h"
 #include "main/snort_config.h"
@@ -31,6 +32,8 @@
 
 using namespace snort;
 
+static HostCacheIp initial_host_cache(LRU_CACHE_INITIAL_SIZE);
+
 const PegInfo host_tracker_pegs[] =
 {
     { CountType::SUM, "service_adds", "host service adds" },
@@ -92,10 +95,10 @@ bool HostTrackerModule::end(const char* fqn, int idx, SnortConfig*)
 
     else if ( idx && !strcmp(fqn, "host_tracker") && addr.is_set() )
     {
-        host_cache[addr];
+        initial_host_cache[addr];
 
         for ( auto& a : apps )
-            host_cache[addr]->add_service(a);
+            initial_host_cache[addr]->add_service(a);
 
         addr.clear();
         apps.clear();
@@ -104,6 +107,17 @@ bool HostTrackerModule::end(const char* fqn, int idx, SnortConfig*)
     return true;
 }
 
+void HostTrackerModule::init_data()
+{
+    auto host_data = initial_host_cache.get_all_data();
+    for ( auto& h : host_data )
+    {
+        host_cache.find_else_insert(h.first, h.second);
+        h.second->init_visibility(1);
+    }
+}
+
+
 const PegInfo* HostTrackerModule::get_pegs() const
 { return host_tracker_pegs; }
 
index 5cc0520314577468a5e94d2a390279b8b637eee0..f7b49cc21b1ff6d9d2cf07791438fe5e8b1de106 100644 (file)
 
 #define host_tracker_help \
     "configure hosts"
+#define HOST_TRACKER_NAME "host_tracker"
 
 class HostTrackerModule : public snort::Module
 {
 public:
     HostTrackerModule() :
-        snort::Module("host_tracker", host_tracker_help, host_tracker_params, true) { }
+        snort::Module(HOST_TRACKER_NAME, host_tracker_help, host_tracker_params, true) { }
 
     const PegInfo* get_pegs() const override;
     PegCount* get_counts() const override;
@@ -50,6 +51,8 @@ public:
     bool begin(const char*, int, snort::SnortConfig*) override;
     bool end(const char*, int, snort::SnortConfig*) override;
 
+    void init_data();
+
     Usage get_usage() const override
     { return GLOBAL; }
 
index 41bc7a93d5998603e7ae29bf66f91dcdab7401c2..6e77e52bdecddc0e692be664e0d11a843a6ea51a 100644 (file)
@@ -2,6 +2,7 @@
 add_cpputest( host_cache_test
     SOURCES
         ../host_cache.cc
+        ../host_cache_segmented.h
         ../host_tracker.cc
         ../../network_inspectors/rna/test/rna_flow_stubs.cc
         ../../sfip/sf_ip.cc
@@ -10,7 +11,7 @@ add_cpputest( host_cache_test
 add_cpputest( host_cache_module_test
     SOURCES
         ../host_cache_module.cc
-        ../host_cache.cc
+        ../host_cache_segmented.h
         ../host_tracker.cc
         ../../framework/module.cc
         ../../framework/value.cc
@@ -30,9 +31,20 @@ add_cpputest( host_tracker_test
         ../../sfip/sf_ip.cc
 )
 
+add_cpputest( host_cache_segmented_test
+    SOURCES
+        ../host_cache.cc
+        ../host_cache.h
+        ../host_tracker.cc
+        ../host_cache_segmented.h
+        ../../network_inspectors/rna/test/rna_flow_stubs.cc
+        ../../sfip/sf_ip.cc
+    )
+
 add_cpputest( host_tracker_module_test
     SOURCES
         ../host_cache.cc
+        ../host_cache_segmented.h
         ../host_tracker.cc
         ../host_tracker_module.cc
         ../../framework/module.cc
index d22164f2618cefef63dea8360928119bee45e74f..c43adfd31bb8dabcabbfec44be586cb1e9da8a7a 100644 (file)
@@ -23,6 +23,7 @@
 #endif
 
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "host_tracker/cache_allocator.cc"
 #include "network_inspectors/rna/rna_flow.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
 #include <CppUTest/TestHarness.h>
 
-HostCacheIp host_cache(100);
+HostCacheIp default_host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheSegmentedIp host_cache(4,100);
 
 using namespace std;
 using namespace snort;
 
+namespace snort
+{
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1); }
+}
 // Derive an allocator from CacheAlloc:
 template <class T>
 class Allocator : public CacheAlloc<T>
@@ -124,8 +130,8 @@ TEST(cache_allocator, allocate)
 
 int main(int argc, char** argv)
 {
-    // FIXIT-L There is currently no external way to fully release the memory from the global host
-    //   cache unordered_map in host_cache.cc
     MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
-    return CommandLineTestRunner::RunAllTests(argc, argv);
+    int ret =  CommandLineTestRunner::RunAllTests(argc, argv);
+    host_cache.term();
+    return ret;
 }
index 69dff2aadbae8e1a49c7be643cebeafe4e2b4b06..7bc9b5edb4ffd0b93beae14c118c79f4b59f236b 100644 (file)
 #include "network_inspectors/rna/rna_flow.h"
 
 #include <cstring>
-
 #include "main/snort_config.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
 #include <CppUTest/TestHarness.h>
 
 using namespace snort;
+using namespace std;
 
 namespace snort
 {
@@ -42,14 +42,75 @@ namespace snort
 char* snort_strdup(const char* str)
 { return strdup(str); }
 time_t packet_time() { return 0; }
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1);}
+}
+HostCacheIp default_host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheSegmentedIp host_cache(1,100);
+
+
+template <class T>
+class Allocator : public CacheAlloc<T>
+{
+public:
+    template <class U>
+    struct rebind
+    {
+        typedef Allocator<U> other;
+    };
+
+    using CacheAlloc<T>::lru;
+    using Base = CacheAlloc<T>;
+
+    void set_cache(CacheInterface* hci) { Base::set_lru(hci); }
+    CacheInterface* get_cache_ptr() { return Base::get_lru(); }
+
+    template <class U>
+    Allocator(const Allocator<U>& other) 
+    {
+        lru = other.lru;
+    }
+    template <class U>
+    Allocator(const Allocator<U>&& other) 
+    {
+        lru = other.lru;
+    }
+    Allocator();
+};
+
+
+class Item
+{
+public:
+    typedef int ValueType;
+    vector<ValueType, Allocator<ValueType>> data;
+};
+
+typedef LruCacheSharedMemcap<string, Item, hash<string>> CacheType;
+CacheType cache(100);
+CacheType cache2(100);
+
+template <class T>
+Allocator<T>::Allocator()
+{
+    lru = &cache;
 }
 
-HostCacheIp host_cache(100);
 
 TEST_GROUP(host_cache_allocator_ht)
 {
+
 };
 
+TEST(host_cache_allocator_ht, allocate_update)
+{   
+    //declare a list with allocator cache
+    std::list<string, Allocator<string>> test_list;
+    CHECK(test_list.get_allocator().get_lru() == &cache);
+    //update cache interface of test_list to cache_2
+    update_allocator(test_list, &cache2);
+    CHECK(test_list.get_allocator().get_lru() == &cache2);
+}
+
 // Test allocation / deallocation, pruning and remove.
 TEST(host_cache_allocator_ht, allocate)
 {
@@ -61,12 +122,7 @@ TEST(host_cache_allocator_ht, allocate)
 
     // room for n host trackers in the cache and 2^floor(log2(3))+2^ceil(log2(3))-1 host
     // applications in ht
-    // FIXIT-L this makes a questionable assumption about the STL vector implementation
-    // that it will double the allocation each time it needs to increase its size, so
-    // going from 2 to 3 will allocate 4 and then release 2, meaning in order to exactly
-    // induce pruning, the max size should be just one <ht_item_sz> short of holding 6
-    const size_t max_size = n * hc_item_sz + 5 * ht_item_sz;
-
+    const size_t max_size = n * hc_item_sz + m * ht_item_sz;
     host_cache.set_max_size(max_size);
 
     // insert n empty host trackers:
@@ -146,8 +202,8 @@ TEST(host_cache_allocator_ht, allocate)
 
 int main(int argc, char** argv)
 {
-    // FIXIT-L There is currently no external way to fully release the memory from the global host
-    //   cache unordered_map in host_cache.cc
     MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
-    return CommandLineTestRunner::RunAllTests(argc, argv);
+    int ret = CommandLineTestRunner::RunAllTests(argc, argv);
+    host_cache.term();
+    return ret;
 }
index 557226d5f998acb132ae6fe176c3b8460781eca5..06d8591ed91adbbd49c7e14dfe1adc351449a173 100644 (file)
@@ -29,6 +29,7 @@
 #include "control/control.h"
 #include "host_tracker/host_cache_module.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "main/snort_config.h"
 #include "managers/module_manager.h"
 
@@ -39,6 +40,9 @@
 
 using namespace snort;
 
+HostCacheIp default_host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheSegmentedIp host_cache(4,LRU_CACHE_INITIAL_SIZE);
+
 // All tests here use the same module since host_cache is global. Creating a local module for each
 // test will cause host_cache PegCount testing to be dependent on the order of running these tests.
 static HostCacheModule module;
@@ -78,6 +82,7 @@ void LogMessage(const char* format,...)
 time_t packet_time() { return 0; }
 bool Snort::is_reloading() { return false; }
 void SnortConfig::register_reload_handler(ReloadResourceTuner* rrt) { delete rrt; }
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1); }
 } // end of namespace snort
 
 void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
@@ -86,7 +91,7 @@ void show_stats(PegCount*, const PegInfo*, const IndexVec&, const char*, FILE*)
 template <class T>
 HostCacheAllocIp<T>::HostCacheAllocIp()
 {
-    lru = &host_cache;
+    lru = host_cache.seg_list[0];
 }
 
 TEST_GROUP(host_cache_module)
@@ -95,16 +100,52 @@ TEST_GROUP(host_cache_module)
 
 static void try_reload_prune(bool is_not_locked)
 {
+    auto segs = host_cache.seg_list.size();
+    auto prune_size = host_cache.seg_list[0]->mem_chunk * 1.5 * segs;
     if ( is_not_locked )
     {
-        CHECK(host_cache.reload_prune(host_cache.mem_chunk * 1.5, 2) == true);
+        CHECK(host_cache.reload_prune(prune_size, 2) == true);
+        for ( auto& seg : host_cache.seg_list )
+        {
+            CHECK(seg->get_max_size() == prune_size/segs);
+        }
     }
     else
     {
-        CHECK(host_cache.reload_prune(host_cache.mem_chunk * 1.5, 2) == false);
+        CHECK(host_cache.reload_prune(prune_size, 2) == false);
     }
 }
 
+TEST(host_cache_module, cache_segments)
+{
+    SfIp ip0, ip1, ip2, ip3;
+    ip0.set("1.2.3.2");
+    ip1.set("11.22.2.0");
+    ip2.set("192.168.1.1");
+    ip3.set("10.20.33.10");
+
+    uint8_t segment0 = host_cache.get_segment_idx(ip0);
+    uint8_t segment1 = host_cache.get_segment_idx(ip1);
+    uint8_t segment2 = host_cache.get_segment_idx(ip2);
+    uint8_t segment3 = host_cache.get_segment_idx(ip3);
+
+    CHECK(segment0 == 0);
+    CHECK(segment1 == 1);
+    CHECK(segment2 == 2);
+    CHECK(segment3 == 3);
+
+    auto h0 = host_cache.find_else_create(ip0, nullptr);
+    auto h1 = host_cache.find_else_create(ip1, nullptr);
+    auto h2 = host_cache.find_else_create(ip2, nullptr);
+    auto h3 = host_cache.find_else_create(ip3, nullptr);
+
+    CHECK(segment0 == h0->get_cache_idx());
+    CHECK(segment1 == h1->get_cache_idx());
+    CHECK(segment2 == h2->get_cache_idx());
+    CHECK(segment3 == h3->get_cache_idx());
+}
+
+
 // Test stats when HostCacheModule sets/changes host_cache size.
 // This method is a friend of LruCacheSharedMemcap class.
 TEST(host_cache_module, misc)
@@ -130,11 +171,12 @@ TEST(host_cache_module, misc)
     // cache, because sum_stats resets the pegs.
     module.sum_stats(true);
 
-    // add 3 entries
+    // add 3 entries to segment 3 
     SfIp ip1, ip2, ip3;
     ip1.set("1.1.1.1");
     ip2.set("2.2.2.2");
     ip3.set("3.3.3.3");
+    
     host_cache.find_else_create(ip1, nullptr);
     host_cache.find_else_create(ip2, nullptr);
     host_cache.find_else_create(ip3, nullptr);
@@ -143,23 +185,28 @@ TEST(host_cache_module, misc)
     CHECK(ht_stats[2] == 3*mc);  // bytes_in_use
     CHECK(ht_stats[3] == 3);     // items_in_use
 
-    // no pruning needed for resizing higher than current size
-    CHECK(host_cache.reload_resize(host_cache.mem_chunk * 10) == false);
+    // no pruning needed for resizing higher than current size in segment 3
+    CHECK(host_cache.seg_list[2]->reload_resize(host_cache.get_mem_chunk() * 10 ) == false);
     module.sum_stats(true);
     CHECK(ht_stats[2] == 3*mc);  // bytes_in_use unchanged
     CHECK(ht_stats[3] == 3);     // items_in_use unchanged
 
-    // pruning needed for resizing lower than current size
-    CHECK(host_cache.reload_resize(host_cache.mem_chunk * 1.5) == true);
+    // pruning needed for resizing lower than current size in segment 3
+    CHECK(host_cache.seg_list[2]->reload_resize(host_cache.get_mem_chunk() * 1.5) == true);
     module.sum_stats(true);
     CHECK(ht_stats[2] == 3*mc);  // bytes_in_use still unchanged
     CHECK(ht_stats[3] == 3);     // items_in_use still unchanged
 
     // pruning in thread is not done when reload_mutex is already locked
-    host_cache.reload_mutex.lock();
+    for(auto cache : host_cache.seg_list)
+        cache->reload_mutex.lock();
+        
     std::thread test_negative(try_reload_prune, false);
     test_negative.join();
-    host_cache.reload_mutex.unlock();
+
+    for(auto cache : host_cache.seg_list)
+        cache->reload_mutex.unlock();
+
     module.sum_stats(true);
     CHECK(ht_stats[2] == 3*mc);   // no pruning yet
     CHECK(ht_stats[3] == 3);      // no pruning_yet
@@ -193,6 +240,32 @@ TEST(host_cache_module, misc)
     CHECK(ht_stats[0] == 4);
 }
 
+
+// Test host_cache.get_segment_stats()
+TEST(host_cache_module, get_segment_stats)
+{
+    host_cache.init();
+    std::string str;
+    str = module.get_host_cache_segment_stats(0);
+
+    bool contain = str.find("Segment 0:") != std::string::npos;
+    CHECK_TRUE(contain);
+
+    str = module.get_host_cache_segment_stats(1);
+    contain = str.find("Segment 1:") != std::string::npos;
+    CHECK_TRUE(contain);
+
+    str = module.get_host_cache_segment_stats(2);
+    contain = str.find("Segment 2:") != std::string::npos;
+    CHECK_TRUE(contain);
+
+    str = module.get_host_cache_segment_stats(-1);
+    contain = str.find("total cache segments: 4") != std::string::npos;
+    CHECK_TRUE(contain);
+    
+
+}
+
 TEST(host_cache_module, log_host_cache_messages)
 {
     module.log_host_cache(nullptr, true);
@@ -211,8 +284,8 @@ TEST(host_cache_module, log_host_cache_messages)
 
 int main(int argc, char** argv)
 {
-    // FIXIT-L There is currently no external way to fully release the memory from the global host
-    //   cache unordered_map in host_cache.cc
     MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
-    return CommandLineTestRunner::RunAllTests(argc, argv);
+    int ret = CommandLineTestRunner::RunAllTests(argc, argv);
+    host_cache.term();
+    return ret;
 }
diff --git a/src/host_tracker/test/host_cache_segmented_test.cc b/src/host_tracker/test/host_cache_segmented_test.cc
new file mode 100644 (file)
index 0000000..95e81ae
--- /dev/null
@@ -0,0 +1,100 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2016-2023 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// host_cache_segmented_test.cc author Raza Shafiq <rshafiq@cisco.com>
+
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <cstring>
+
+#include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+#include "sfip/sf_ip.h"
+
+using namespace std;
+using namespace snort;
+
+namespace snort
+{
+char* snort_strdup(const char* s)
+{ return strdup(s); }
+time_t packet_time() { return 0; }
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1); }
+
+}
+
+TEST_GROUP(host_cache_segmented)
+{
+};
+
+TEST(host_cache_segmented, get_segments_test)
+{
+    HostCacheSegmentedIp hc(4,4000);
+    hc.init();
+    CHECK(hc.get_segments() == 4);
+    CHECK(hc.get_memcap_per_segment() == 4000);
+    CHECK(hc.get_max_size() == 16000);
+    hc.term();
+}
+
+
+TEST(host_cache_segmented, cache_setup)
+{
+    HostCacheSegmentedIp hc;
+    hc.setup(2,2000);
+    hc.init();
+    CHECK(hc.get_segments() == 2);
+    CHECK(hc.get_memcap_per_segment() == 1000);
+    CHECK(hc.get_max_size() == 2000);
+    hc.term();
+}
+
+TEST(host_cache_segmented, one_segment)
+{
+    HostCacheSegmentedIp hc3(1,4000);
+    hc3.init();
+    CHECK(hc3.get_segments() == 1);
+    CHECK(hc3.get_memcap_per_segment() == 4000);
+    CHECK(hc3.get_max_size() == 4000);
+    hc3.term();
+}
+
+TEST(host_cache_segmented, set_max_size_test)
+{
+    HostCacheSegmentedIp hc4(16,1000);
+    hc4.init();
+    CHECK(hc4.get_segments() == 16);
+    CHECK(hc4.get_memcap_per_segment() == 1000);
+    hc4.set_max_size(40000);
+    CHECK(hc4.get_segments() == 16);
+    CHECK(hc4.get_memcap_per_segment() == 2500);
+    CHECK(hc4.get_max_size() == 40000);
+    hc4.term();
+}
+
+int main(int argc, char** argv)
+{
+    return CommandLineTestRunner::RunAllTests(argc, argv);
+}
index 03fad2e2efc6aa26878a119c2d6914fb80adb8a1..5052abee09c94243f027a252b142d9758661e16f 100644 (file)
@@ -26,6 +26,7 @@
 #include <cstring>
 
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "host_tracker/host_tracker_module.h"
 #include "main/snort_config.h"
 #include "target_based/snort_protocols.h"
@@ -40,6 +41,7 @@ namespace snort
 char* snort_strdup(const char* s)
 { return strdup(s); }
 time_t packet_time() { return 0; }
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1); }
 }
 
 //  Fake show_stats to avoid bringing in a ton of dependencies.
@@ -100,9 +102,10 @@ TEST(host_tracker_module, host_tracker_module_test_basic)
 
 int main(int argc, char** argv)
 {
-    // FIXIT-L There is currently no external way to fully release the memory from the global host
-    //   cache unordered_map in host_cache.cc
+    host_cache.init();
     MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
-    return CommandLineTestRunner::RunAllTests(argc, argv);
+    int ret = CommandLineTestRunner::RunAllTests(argc, argv);
+    host_cache.term();
+    return ret;
 }
 
index ca980b09db557cd889b22884c5a00b7f0f6e4729..6edf458ac8c162bf14e3e57021e3ce887f745290 100644 (file)
@@ -42,12 +42,14 @@ namespace snort
 char* snort_strdup(const char* str)
 { return strdup(str); }
 time_t packet_time() { return test_time; }
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1); }
 }
 
 // There always needs to be a HostCacheIp associated with HostTracker,
 // because any allocation / deallocation into the HostTracker will take up
 // memory managed by the cache.
-HostCacheIp host_cache(1024);
+HostCacheIp default_host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheSegmentedIp host_cache(4,1024);
 
 TEST_GROUP(host_tracker)
 {
@@ -429,5 +431,7 @@ TEST(host_tracker, rediscover_host)
 
 int main(int argc, char** argv)
 {
-    return CommandLineTestRunner::RunAllTests(argc, argv);
+    int ret = CommandLineTestRunner::RunAllTests(argc, argv);
+    host_cache.term();
+    return ret;
 }
index 34875664c0567bf02c4b7977fbdf9baac8b95e7a..070f5a3bd76cd09495810462061c092196033fe3 100644 (file)
@@ -40,6 +40,8 @@
 #include "framework/mpse.h"
 #include "helpers/process.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
+#include "host_tracker/host_tracker_module.h"
 #include "ips_options/ips_options.h"
 #include "log/log.h"
 #include "log/messages.h"
@@ -352,12 +354,14 @@ void Snort::term()
     HighAvailabilityManager::term();
     SideChannelManager::term();
     ModuleManager::term();
+    host_cache.term();
     PluginManager::release_plugins();
     ScriptManager::release_scripts();
     memory::MemoryCap::term();
     detection_filter_term();
 
     term_signals();
+    
 }
 
 void Snort::clean_exit(int)
@@ -404,6 +408,8 @@ void Snort::setup(int argc, char* argv[])
     memory::MemoryCap::start(*sc->memory, Stream::prune_flows);
     memory::MemoryCap::print(SnortConfig::log_verbose(), true);
 
+    host_cache.init();
+    ((HostTrackerModule*)ModuleManager::get_module(HOST_TRACKER_NAME))->init_data();
     host_cache.print_config();
 
     TimeStart();
index 0ae4fd0a391aec0734821fbe3d5ec38ce6943d12..06fa37c82f2f893b61fb087aa07707aabcfeab41 100644 (file)
@@ -45,6 +45,7 @@
 #include "framework/policy_selector.h"
 #include "hash/xhash.h"
 #include "helpers/process.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "latency/latency_config.h"
 #include "log/messages.h"
 #include "managers/action_manager.h"
@@ -1061,6 +1062,7 @@ void SnortConfig::cleanup_fatal_error()
         EventManager::release_plugins();
         IpsManager::release_plugins();
         InspectorManager::release_plugins();
+        host_cache.term();
     }
 #endif
 }
index dd14d28dc9eaf43eab08ae4bba44f134f83e9666..453d92ec26e4795db962c644e05a757a68619c08 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "appid_discovery.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 
 #include "log/messages.h"
 #include "packet_tracer/packet_tracer.h"
index 08d898118c6f9d88908c1ea7aafb3f1462d756cc..55d654f77fa5960d9660acc1714db924d086c2f3 100644 (file)
@@ -31,6 +31,7 @@
 
 #include "control/control.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "log/messages.h"
 #include "main/analyzer.h"
 #include "main/analyzer_command.h"
index 33442be8608a8af1ec4732887917982eef80e2f8..3080e4d60b445f1f515479e65fc965ae01f6330b 100644 (file)
@@ -54,6 +54,7 @@ const char* AppIdApi::get_application_name(AppId, OdpContext&) { return NULL; }
 THREAD_LOCAL PacketTracer* s_pkt_trace = nullptr;
 THREAD_LOCAL Stopwatch<SnortClock>* pt_timer = nullptr;
 void PacketTracer::daq_log(const char*, ...) { }
+void FatalError(const char* fmt, ...) { (void)fmt; exit(1); }
 
 // Stubs for packet
 Packet::Packet(bool) {}
@@ -261,7 +262,8 @@ static AppIdModule* s_app_module = nullptr;
 static AppIdInspector* s_ins = nullptr;
 static ServiceDiscovery* s_discovery_manager = nullptr;
 
-HostCacheIp host_cache(50);
+HostCacheIp default_host_cache(LRU_CACHE_INITIAL_SIZE);
+HostCacheSegmentedIp host_cache(1,50);
 AppId HostTracker::get_appid(Port, IpProtocol, bool, bool)
 {
     return APP_ID_NONE;
@@ -545,5 +547,6 @@ TEST(appid_discovery_tests, change_bits_to_string)
 int main(int argc, char** argv)
 {
     int rc = CommandLineTestRunner::RunAllTests(argc, argv);
+    host_cache.term();
     return rc;
 }
index ef44fdd088c27c4c565c708b4bcf6ae7de7d2807..21e95bf0642600c43d9653d008d14f62a7b98500 100644 (file)
@@ -26,6 +26,7 @@
 
 #include "helpers/discovery_filter.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "protocols/packet.h"
 
 #ifdef UNIT_TEST
index f55a6a5d9714b94702a07ee6456e1edbd84e6efd..62f383b2bdc3884a2bc5cc204f84f60eaa816af1 100644 (file)
@@ -33,6 +33,7 @@
 
 #include "control/control.h"
 #include "host_tracker/host_cache.h"
+#include "host_tracker/host_cache_segmented.h"
 #include "log/messages.h"
 #include "lua/lua.h"
 #include "main/snort_config.h"