]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3439: netflow: Enforcing memcap for record and template LRU caches
authorMasud Hasan (mashasan) <mashasan@cisco.com>
Fri, 27 May 2022 19:36:44 +0000 (19:36 +0000)
committerMasud Hasan (mashasan) <mashasan@cisco.com>
Fri, 27 May 2022 19:36:44 +0000 (19:36 +0000)
Merge in SNORT/snort3 from ~MASHASAN/snort3:netflow_memcap to master

Squashed commit of the following:

commit bc2f0391d2011a359c8c1b238e222b305cd60db3
Author: Masud Hasan <mashasan@cisco.com>
Date:   Thu May 26 23:51:59 2022 -0400

    host_tracker: Renaming generic files and classes

commit bf7c31fd580de06f7c8311cd7e1fc3c91b7c5f4e
Author: Masud Hasan <mashasan@cisco.com>
Date:   Wed May 18 14:50:13 2022 -0400

    netflow: Enforcing memcap for session record and template LRU caches

26 files changed:
src/hash/CMakeLists.txt
src/hash/lru_cache_local.h [new file with mode: 0644]
src/hash/test/CMakeLists.txt
src/hash/test/lru_cache_local_test.cc [new file with mode: 0644]
src/host_tracker/CMakeLists.txt
src/host_tracker/cache_allocator.cc [moved from src/host_tracker/host_cache_allocator.cc with 90% similarity]
src/host_tracker/cache_allocator.h [moved from src/host_tracker/host_cache_allocator.h with 75% similarity]
src/host_tracker/cache_interface.h [moved from src/host_tracker/host_cache_interface.h with 93% similarity]
src/host_tracker/dev_notes.txt
src/host_tracker/host_cache.h
src/host_tracker/host_tracker.cc
src/host_tracker/host_tracker.h
src/host_tracker/host_tracker_module.cc
src/host_tracker/host_tracker_module.h
src/host_tracker/test/CMakeLists.txt
src/host_tracker/test/cache_allocator_test.cc [moved from src/host_tracker/test/host_cache_allocator_test.cc with 91% similarity]
src/host_tracker/test/host_cache_allocator_ht_test.cc
src/host_tracker/test/host_tracker_test.cc
src/network_inspectors/appid/lua_detector_api.cc
src/network_inspectors/rna/rna_mac_cache.h
src/service_inspectors/netflow/CMakeLists.txt
src/service_inspectors/netflow/netflow.cc
src/service_inspectors/netflow/netflow_cache.cc [new file with mode: 0644]
src/service_inspectors/netflow/netflow_cache.h [new file with mode: 0644]
src/service_inspectors/netflow/netflow_module.cc
src/service_inspectors/netflow/netflow_module.h

index 029a77ef9a9e5e7e4ea5be2ea6c3b3c0d8aefbe8..72605f00514a153199f71354e74b15c16e3935ba 100644 (file)
@@ -4,6 +4,7 @@ set (HASH_INCLUDES
     hashes.h
     hash_defs.h
     hash_key_operations.h
+    lru_cache_local.h
     lru_cache_shared.h
     xhash.h
 )
diff --git a/src/hash/lru_cache_local.h b/src/hash/lru_cache_local.h
new file mode 100644 (file)
index 0000000..52efbdf
--- /dev/null
@@ -0,0 +1,169 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2022-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// lru_cache_local.h author Masud Hasan <mashasan@cisco.com>
+
+#ifndef LRU_CACHE_LOCAL_H
+#define LRU_CACHE_LOCAL_H
+
+// LruCacheLocal - A simple thread-unsafe memcap-enforced least-recently-used cache.
+
+#include <list>
+#include <unordered_map>
+#include <vector>
+
+#include "framework/counts.h"
+
+#define LRU_CACHE_LOCAL_PEGS(module) \
+    { CountType::SUM, "cache_adds", module " cache added new entry" }, \
+    { CountType::SUM, "cache_hits", module " cache found existing entry" }, \
+    { CountType::SUM, "cache_misses", module " cache did not find entry" }, \
+    { CountType::SUM, "cache_replaces", module " cache found entry and replaced its value" }, \
+    { CountType::SUM, "cache_max", module " cache's maximum byte usage"}, \
+    { CountType::SUM, "cache_prunes", module " cache pruned entry to make space for new entry" }
+
+struct LruCacheLocalStats
+{
+    PegCount cache_adds;
+    PegCount cache_hits;
+    PegCount cache_misses;
+    PegCount cache_replaces;
+    PegCount cache_max;
+    PegCount cache_prunes;
+};
+
+template<typename Key, typename Value, typename Hash>
+class LruCacheLocal
+{
+public:
+    LruCacheLocal(const size_t sz, struct LruCacheLocalStats& st)
+        : max_size(sz), current_size(0), stats(st) { }
+
+    virtual ~LruCacheLocal() = default;
+
+    // Return 1 if an entry associated with the key exists, else return 0
+    int count(const Key&);
+
+    // Return the entry associated with the key; insert new entry if absent
+    Value& find_else_create(const Key&, bool* is_new = nullptr);
+
+    // If key does not exist, insert the key-value pair and return true;
+    // else return false replacing the existing value if asked
+    bool add(const Key&, const Value&, bool replace = false);
+
+    // Copy all key-value pairs from the cache
+    void get_all_values(std::vector<std::pair<Key, Value>>&);
+
+protected:
+    using LruList = std::list<std::pair<Key, Value>>;
+    using LruListIter = typename LruList::iterator;
+    using LruMap = std::unordered_map<Key, LruListIter, Hash>;
+    using LruMapIter = typename LruMap::iterator;
+
+    void prune();
+    void add_entry(const Key&, const Value&);
+
+    static constexpr size_t entry_size = 2 * sizeof(Key) + sizeof(Value) + sizeof(LruListIter);
+    const size_t max_size;
+    size_t current_size;
+    LruList list;
+    LruMap map;
+    struct LruCacheLocalStats& stats;
+};
+
+template<typename Key, typename Value, typename Hash>
+void LruCacheLocal<Key, Value, Hash>::prune()
+{
+    if ( !max_size )
+        return;
+
+    while ( current_size > max_size and !list.empty() )
+    {
+        auto it = --list.end();
+        map.erase(it->first);
+        list.erase(it);
+        current_size -= entry_size;
+        ++stats.cache_prunes;
+    }
+}
+
+template<typename Key, typename Value, typename Hash>
+int LruCacheLocal<Key, Value, Hash>::count(const Key& key)
+{
+    return map.count(key);
+}
+
+template<typename Key, typename Value, typename Hash>
+void LruCacheLocal<Key, Value, Hash>::add_entry(const Key& key, const Value& value)
+{
+    stats.cache_adds++;
+    list.emplace_front(std::make_pair(key, value));
+    map[key] = list.begin();
+    current_size += entry_size;
+    prune();
+    if ( stats.cache_max < current_size )
+        stats.cache_max = current_size;
+}
+
+template<typename Key, typename Value, typename Hash>
+Value& LruCacheLocal<Key, Value, Hash>::find_else_create(const Key& key, bool* is_new)
+{
+    LruMapIter it = map.find(key);
+    if (it == map.end())
+    {
+        stats.cache_misses++;
+        add_entry(key, Value());
+        if ( is_new )
+            *is_new = true;
+        return list.begin()->second;
+    }
+
+    stats.cache_hits++;
+    list.splice(list.begin(), list, it->second);
+    return list.begin()->second;
+}
+
+template<typename Key, typename Value, typename Hash>
+bool LruCacheLocal<Key, Value, Hash>::add(const Key& key, const Value& value, bool replace)
+{
+    LruMapIter it = map.find(key);
+    if (it == map.end())
+    {
+        stats.cache_misses++;
+        add_entry(key, value);
+        return true;
+    }
+
+    stats.cache_hits++;
+    list.splice(list.begin(), list, it->second);
+    if ( replace )
+    {
+        it->second->second = value;
+        stats.cache_replaces++;
+    }
+    return false;
+}
+
+template<typename Key, typename Value, typename Hash>
+void LruCacheLocal<Key, Value, Hash>::get_all_values(std::vector<std::pair<Key, Value>>& kv)
+{
+    for (auto& entry : list )
+        kv.emplace_back(entry);
+}
+
+#endif
index 20e176d8b16743c6e7cdb320f62d58181d0c1482..d140fc3d98ec7c723756e5e1f40720113b5e956f 100644 (file)
@@ -1,3 +1,7 @@
+add_cpputest( lru_cache_local_test
+    SOURCES ../lru_cache_local.h
+)
+
 add_cpputest( lru_cache_shared_test
     SOURCES ../lru_cache_shared.cc
 )
diff --git a/src/hash/test/lru_cache_local_test.cc b/src/hash/test/lru_cache_local_test.cc
new file mode 100644 (file)
index 0000000..b66ed9c
--- /dev/null
@@ -0,0 +1,78 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2022-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// lru_cache_local.h author Masud Hasan <mashasan@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "hash/lru_cache_local.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+TEST_GROUP(lru_cache_local)
+{
+};
+
+TEST(lru_cache_local, basic)
+{
+    LruCacheLocalStats stats;
+    LruCacheLocal<int, int, std::hash<int>> lru_cache(40, stats);
+
+    // Check pruning in LRU order; memcap = 40 bytes would hold 2 entries
+    int key = 1;
+    int val = 100;
+    lru_cache.add(key, val);
+
+    key = 2;
+    val = 200;
+    lru_cache.add(key, val);
+
+    key = 3;
+    val = 300;
+    lru_cache.add(key, val);
+
+    // Check entries are stored in LRU order indeed
+    std::vector<std::pair<int, int>> vec;
+    lru_cache.get_all_values(vec);
+    CHECK(vec.size() == 2);
+    CHECK(vec[0].first == 3 and vec[0].second == 300);
+    CHECK(vec[1].first == 2 and vec[1].second == 200);
+
+    // Check non-existent entry; find_else_create() would return a new entry
+    auto& entry = lru_cache.find_else_create(4);
+    entry = 400;
+    // Subsequent calls would return the same one
+    CHECK(lru_cache.find_else_create(4) == 400);
+    entry = 4000;
+    CHECK(lru_cache.find_else_create(4) == 4000);
+
+    // The cache is changed after the above calls
+    vec.clear();
+    lru_cache.get_all_values(vec);
+    CHECK(vec.size() == 2);
+    CHECK(vec[0].first == 4 and vec[0].second == 4000);
+    CHECK(vec[1].first == 3 and vec[1].second == 300);
+}
+
+int main(int argc, char** argv)
+{
+    return CommandLineTestRunner::RunAllTests(argc, argv);
+}
index 94fe835805ef5bc71464d0679d83017532e85bf7..56dfa68067335dd265edd11e15e17202d18753fd 100644 (file)
@@ -1,14 +1,14 @@
 set (HOST_TRACKER_INCLUDES
+    cache_allocator.h
+    cache_interface.h
     host_cache.h
-    host_cache_allocator.h
-    host_cache_interface.h
     host_tracker.h
 )
 
 add_library( host_tracker OBJECT
     ${HOST_TRACKER_INCLUDES}
+    cache_allocator.cc
     host_cache.cc
-    host_cache_allocator.cc
     host_cache_module.cc
     host_cache_module.h
     host_tracker_module.cc
similarity index 90%
rename from src/host_tracker/host_cache_allocator.cc
rename to src/host_tracker/cache_allocator.cc
index 6f752981aa5db8fcb1e3d03b9f534d816112f5d9..ca50c5c1b007fc546b91d9ff339b34f202b5237f 100644 (file)
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
 
-// host_cache_allocator.cc author Silviu Minut <sminut@cisco.com>
+// cache_allocator.cc author Silviu Minut <sminut@cisco.com>
 
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
 
-#ifndef HOST_CACHE_ALLOCATOR_CC
-#define HOST_CACHE_ALLOCATOR_CC
+#ifndef CACHE_ALLOCATOR_CC
+#define CACHE_ALLOCATOR_CC
 
 #include "host_cache.h"
 
similarity index 75%
rename from src/host_tracker/host_cache_allocator.h
rename to src/host_tracker/cache_allocator.h
index a14219c3ce7b270b9406f05c8c2002d21ac30915..d1f5be4c03cd82f0a22cfbb0f5ee299f07ab8716 100644 (file)
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
 
-// host_cache_allocator.h author Silviu Minut <sminut@cisco.com>
+// cache_allocator.h author Silviu Minut <sminut@cisco.com>
 
-#ifndef HOST_CACHE_ALLOCATOR_H
-#define HOST_CACHE_ALLOCATOR_H
+#ifndef CACHE_ALLOCATOR_H
+#define CACHE_ALLOCATOR_H
 
 #include <cassert>
 
-#include "host_cache_interface.h"
+#include "cache_interface.h"
 
 template <class T>
-class HostCacheAlloc : public std::allocator<T>
+class CacheAlloc : public std::allocator<T>
 {
 public:
 
     template <class U>
     struct rebind
     {
-        typedef HostCacheAlloc<U> other;
+        typedef CacheAlloc<U> other;
     };
 
     T* allocate(std::size_t n);
@@ -41,11 +41,11 @@ public:
 
 protected:
 
-    HostCacheInterface* lru = nullptr;
+    CacheInterface* lru = nullptr;
 };
 
 template <class T>
-T* HostCacheAlloc<T>::allocate(std::size_t n)
+T* CacheAlloc<T>::allocate(std::size_t n)
 {
     size_t sz = n * sizeof(T);
     T* out = std::allocator<T>::allocate(n);
@@ -54,7 +54,7 @@ T* HostCacheAlloc<T>::allocate(std::size_t n)
 }
 
 template <class T>
-void HostCacheAlloc<T>::deallocate(T* p, std::size_t n) noexcept
+void CacheAlloc<T>::deallocate(T* p, std::size_t n) noexcept
 {
     size_t sz = n * sizeof(T);
     std::allocator<T>::deallocate(p, n);
@@ -63,13 +63,13 @@ void HostCacheAlloc<T>::deallocate(T* p, std::size_t n) noexcept
 
 
 // Trivial derived allocator, pointing to their own host cache.
-// HostCacheAllocIp has a HostCacheInterface* pointing to an lru cache
+// HostCacheAllocIp has a CacheInterface* pointing to an lru cache
 // instantiated using snort::SfIp as the key. See host_cache.h.
 // We can create different cache types by instantiating the lru cache using
-// different keys and derive here allocators with HostCacheInterface*
+// different keys and derive here allocators with CacheInterface*
 // pointing to the appropriate lru cache object.
 template <class T>
-class HostCacheAllocIp : public HostCacheAlloc<T>
+class HostCacheAllocIp : public CacheAlloc<T>
 {
 public:
 
@@ -80,7 +80,7 @@ public:
         typedef HostCacheAllocIp<U> other;
     };
 
-    using HostCacheAlloc<T>::lru;
+    using CacheAlloc<T>::lru;
 
     HostCacheAllocIp();
 
similarity index 93%
rename from src/host_tracker/host_cache_interface.h
rename to src/host_tracker/cache_interface.h
index b57136efb2cd12de0fffa272203699eb9d3b2d24..43df9277036a16f6e823ca6f85ad4f664e62104f 100644 (file)
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
 
-// host_cache_interface.h author Silviu Minut <sminut@cisco.com>
+// cache_interface.h author Silviu Minut <sminut@cisco.com>
 
-#ifndef HOST_CACHE_INTERFACE_H
-#define HOST_CACHE_INTERFACE_H
+#ifndef CACHE_INTERFACE_H
+#define CACHE_INTERFACE_H
 
-class HostCacheInterface
+class CacheInterface
 {
 public:
 
index 3e438462bcfa90e48e5730d892e487048c5f0534..1a9a85361b48d9fd12a7686851ea525bf610da1c 100644 (file)
@@ -125,9 +125,9 @@ a single .h file. However, to break this dependency, we have to split the
 Allocator into a .h and a .cc file. We include the .h file in HostTracker
 and declare HostCache extern only in the .cc file. Then, we have to include
 the .cc file also in the HostTracker implementation file because Allocator
-is templated. See host_cache.h, host_cache_allocator.h, host_cache_allocator.cc,
+is templated. See host_cache.h, cache_allocator.h, cache_allocator.cc,
 host_tracker.h and host_tracker.cc.
 
-Illustrative examples are test/host_cache_allocator_test.cc (standalone
+Illustrative examples are test/cache_allocator_test.cc (standalone
 host cache / allocator example) and test/host_cache_allocator_ht_test.cc
 (host_cache / allocator with host tracker example).
index 0d3b51a8c10a479c6cb971621fa7d7410cfd9847..2ecc8ac709e5f46acb30fb0697ac07e708e93d8b 100644 (file)
 #include <cassert>
 
 #include "hash/lru_cache_shared.h"
-#include "host_cache_interface.h"
-#include "host_cache_allocator.h"
 #include "host_tracker.h"
 #include "log/messages.h"
 #include "main/snort_config.h"
 #include "sfip/sf_ip.h"
 #include "utils/stats.h"
 
+#include "cache_allocator.h"
+#include "cache_interface.h"
+
 // Used to create hash of key for indexing into cache.
 //
 // Note that both HashIp and IpEqualTo below ignore the IP family.
@@ -61,7 +62,7 @@ struct IpEqualTo
 template<typename Key, typename Value, typename Hash, typename Eq = std::equal_to<Key>,
     typename Purgatory = std::vector<std::shared_ptr<Value>>>
 class LruCacheSharedMemcap : public LruCacheShared<Key, Value, Hash, Eq, Purgatory>,
-    public HostCacheInterface
+    public CacheInterface
 {
 public:
     using LruBase = LruCacheShared<Key, Value, Hash, Eq, Purgatory>;
index cc48bf6cf5d63acff9b65d839e72009f2b28448a..4bd4ac6fe41baf280c39e3b76fbb7e3b8db65366 100644 (file)
@@ -28,8 +28,8 @@
 #include "network_inspectors/rna/rna_flow.h"
 #include "utils/util.h"
 
+#include "cache_allocator.cc"
 #include "host_cache.h"
-#include "host_cache_allocator.cc"
 #include "host_tracker.h"
 
 using namespace snort;
index 9c7f22d681ec9997b404641c81d38f2edce46fcd..b58e9fccf306590e0bff86b051131d9b4b784dc6 100644 (file)
@@ -33,7 +33,6 @@
 #include <vector>
 
 #include "framework/counts.h"
-#include "host_cache_allocator.h"
 #include "main/snort_types.h"
 #include "main/thread.h"
 #include "network_inspectors/appid/application_ids.h"
@@ -41,6 +40,8 @@
 #include "protocols/vlan.h"
 #include "time/packet_time.h"
 
+#include "cache_allocator.h"
+
 struct HostTrackerStats
 {
     PegCount service_adds;
index 7a1bbef6f5aec2ce9e7872acfc930cb22a04937c..1de162e3cc5330157a80eec20dd0f5a0381402e0 100644 (file)
 #endif
 
 #include "host_tracker_module.h"
-#include "host_cache_allocator.cc"
 
 #include "log/messages.h"
 #include "main/snort_config.h"
 
+#include "cache_allocator.cc"
+
 using namespace snort;
 
 const PegInfo host_tracker_pegs[] =
index 992eea50bbed2d500bbb57ced59accef22b99d06..9e338d6a45d017b9e28a150aa1c40738bbd8f63f 100644 (file)
@@ -30,8 +30,8 @@
 #include <cassert>
 
 #include "framework/module.h"
+#include "host_tracker/cache_allocator.cc"
 #include "host_tracker/host_cache.h"
-#include "host_tracker/host_cache_allocator.cc"
 
 #define host_tracker_help \
     "configure hosts"
index 8e91de0b45178ac7f377bd0ca0906c8e1098f360..41bc7a93d5998603e7ae29bf66f91dcdab7401c2 100644 (file)
@@ -52,7 +52,7 @@ add_cpputest( host_cache_allocator_ht_test
         ../../sfip/sf_ip.cc
 )
 
-add_cpputest( host_cache_allocator_test
+add_cpputest( cache_allocator_test
     SOURCES
         ../host_tracker.cc
         ../../network_inspectors/rna/test/rna_flow_stubs.cc
similarity index 91%
rename from src/host_tracker/test/host_cache_allocator_test.cc
rename to src/host_tracker/test/cache_allocator_test.cc
index 312f73891a3653a4ed5ccbde84bcbbb19b1d47e3..9ad6872a19e49d5d8cc97c2e54104c9b20db2e8b 100644 (file)
 // 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
 //--------------------------------------------------------------------------
 
-// host_cache_allocator_test.cc author Silviu Minut <sminut@cisco.com>
+// cache_allocator_test.cc author Silviu Minut <sminut@cisco.com>
 
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
 
 #include "host_tracker/host_cache.h"
-#include "host_tracker/host_cache_allocator.cc"
+#include "host_tracker/cache_allocator.cc"
 #include "network_inspectors/rna/rna_flow.h"
 
 #include <string>
@@ -38,9 +38,9 @@ HostCacheIp host_cache(100);
 using namespace std;
 using namespace snort;
 
-// Derive an allocator from HostCacheAlloc:
+// Derive an allocator from CacheAlloc:
 template <class T>
-class Allocator : public HostCacheAlloc<T>
+class Allocator : public CacheAlloc<T>
 {
 public:
 
@@ -51,7 +51,7 @@ public:
         typedef Allocator<U> other;
     };
 
-    using HostCacheAlloc<T>::lru;
+    using CacheAlloc<T>::lru;
 
     Allocator();
 };
@@ -70,7 +70,7 @@ typedef LruCacheSharedMemcap<string, Item, hash<string>> CacheType;
 CacheType cache(100);
 
 // Implement the allocator constructor AFTER we have a cache object
-// to point to and the implementation of our base HostCacheAlloc:
+// to point to and the implementation of our base CacheAlloc:
 template <class T>
 Allocator<T>::Allocator()
 {
index 8462831696a11ec3c11f747cbe804d1e901d0b1f..f9df18b889d4fa222160b32750d666f2d256a226 100644 (file)
@@ -23,8 +23,8 @@
 #include "config.h"
 #endif
 
+#include "host_tracker/cache_allocator.cc"
 #include "host_tracker/host_cache.h"
-#include "host_tracker/host_cache_allocator.cc"
 #include "network_inspectors/rna/rna_flow.h"
 
 #include <cstring>
index 7afe66229abaf4522dfb9a72ed1aab756aee159d..2a6f31b5a52d4b7f8004bb80821bc6416712566c 100644 (file)
@@ -25,8 +25,8 @@
 
 #include <cstring>
 
+#include "host_tracker/cache_allocator.cc"
 #include "host_tracker/host_cache.h"
-#include "host_tracker/host_cache_allocator.cc"
 #include "network_inspectors/rna/rna_flow.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
index 7780d71368c20b4ac224ed4505055d7b5d3b50fd..d8dd5d6859767cbd794ce223b402e5c272c36b39 100644 (file)
@@ -28,6 +28,8 @@
 #include <pcre.h>
 #include <unordered_map>
 
+#include "host_tracker/cache_allocator.cc"
+#include "host_tracker/host_cache.h"
 #include "log/messages.h"
 #include "main/snort_debug.h"
 #include "main/snort_types.h"
@@ -50,8 +52,6 @@
 #include "lua_detector_util.h"
 #include "service_plugins/service_discovery.h"
 #include "service_plugins/service_ssl.h"
-#include "host_tracker/host_cache.h"
-#include "host_tracker/host_cache_allocator.cc"
 
 using namespace snort;
 using namespace std;
index 55ae54eb2d1f68e8fa67ee6aa13d7b5ef82aa488..ef14720c3a2345219da44aee9bf270ffd21b87c4 100644 (file)
 
 #include "hash/hash_key_operations.h"
 // Non-standard, required to include template definition across compiler translation units
-#include "host_tracker/host_cache_allocator.cc"
-#include "host_tracker/host_cache_allocator.h"
+#include "host_tracker/cache_allocator.cc"
+#include "host_tracker/cache_allocator.h"
 
 #define MAC_CACHE_INITIAL_SIZE 1024 * 1024 // Default to 1 MB
 
 template <class T>
-class HostCacheAllocMac : public HostCacheAlloc<T>
+class HostCacheAllocMac : public CacheAlloc<T>
 {
 public:
     template <class U>
@@ -38,7 +38,7 @@ public:
         typedef HostCacheAllocMac<U> other;
     };
 
-    using HostCacheAlloc<T>::lru;
+    using CacheAlloc<T>::lru;
 
     HostCacheAllocMac();
 };
index 11e6c5fb2dcea48ad0da46ee0ac0b085d73af5a5..a4e0531f3fc510dab489315975708303c864c7db 100644 (file)
@@ -1,5 +1,6 @@
 
 set ( FILE_LIST
+    netflow_cache.h
     netflow_headers.h
     netflow_module.cc
     netflow_module.h
index de7e1cfb1f31259dc8eda658f73cda1ff514388c..7ab33dd41f43ebc6b3e9960d36fe87e52ad70ffb 100644 (file)
@@ -23,9 +23,6 @@
 #include "config.h"
 #endif
 
-#include "netflow_headers.h"
-#include "netflow_module.h"
-
 #include <fstream>
 #include <mutex>
 #include <sys/stat.h>
 #include "profiler/profiler.h"
 #include "protocols/packet.h"
 #include "pub_sub/netflow_event.h"
-#include "sfip/sf_ip.h"
 #include "src/utils/endian.h"
 #include "utils/util.h"
 
+#include "netflow_cache.cc"
+
 using namespace snort;
 
 THREAD_LOCAL NetflowStats netflow_stats;
 THREAD_LOCAL ProfileStats netflow_perf_stats;
 
-// compare struct to use with ip sort
-struct IpCompare
-{
-    bool operator()(const snort::SfIp& a, const snort::SfIp& b)
-    { return a.less_than(b); }
-};
-
 // Used to ensure we fully populate the record; can't rely on the actual values being zero
 struct RecordStatus
 {
@@ -69,22 +60,22 @@ struct RecordStatus
 // static variables
 // -----------------------------------------------------------------------------
 
-// Used to avoid creating multiple events for the same initiator IP.
-// Cache can be thread local since Netflow packets coming from a Netflow
-// device will go to the same thread.
-typedef std::unordered_map<snort::SfIp, NetflowSessionRecord, NetflowHash> NetflowCache;
-static THREAD_LOCAL NetflowCache* netflow_cache = nullptr;
+// temporary cache required to dump the output
+typedef std::pair<snort::SfIp, NetflowSessionRecord> IpRecord;
+typedef std::vector<IpRecord> DumpCache;
+static DumpCache* dump_cache = nullptr;
 
-// cache required to dump the output
-static NetflowCache* dump_cache = nullptr;
-
-// Netflow version 9 Template fields cache.
-typedef std::unordered_map<std::pair<uint16_t, snort::SfIp>, std::vector<Netflow9TemplateField>, TemplateIpHash> TemplateFieldCache;
-static THREAD_LOCAL TemplateFieldCache* template_cache = nullptr;
+// compare struct to use with ip sort
+struct IpCompare
+{
+    bool operator()(const IpRecord& a, const IpRecord& b)
+    { return a.first.less_than(b.first); }
+};
 
 // -----------------------------------------------------------------------------
 // static functions
 // -----------------------------------------------------------------------------
+
 static bool filter_record(const NetflowRules* rules, const int zone,
     const SfIp* src, const SfIp* dst)
 {
@@ -111,16 +102,16 @@ static bool filter_record(const NetflowRules* rules, const int zone,
 }
 
 static bool version_9_record_update(const unsigned char* data, uint32_t unix_secs,
-    std::vector<Netflow9TemplateField>::iterator field, NetflowSessionRecord &record,
+    uint16_t field_type, uint16_t field_length, NetflowSessionRecord &record,
     RecordStatus& record_status)
 {
 
-    switch ( field->field_type )
+    switch ( field_type )
     {
         case NETFLOW_PROTOCOL:
 
             // invalid protocol
-            if( field->field_length != sizeof(record.proto) )
+            if( field_length != sizeof(record.proto) )
                 return false;
 
             record.proto = (uint8_t)*data;
@@ -129,7 +120,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_TCP_FLAGS:
 
             // invalid tcp flags
-            if( field->field_length != sizeof(record.tcp_flags ) )
+            if( field_length != sizeof(record.tcp_flags ) )
                 return false;
 
             record.tcp_flags = (uint8_t)*data;
@@ -138,7 +129,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_SRC_PORT:
 
             // invalid src port
-            if( field->field_length != sizeof(record.initiator_port) )
+            if( field_length != sizeof(record.initiator_port) )
                 return false;
 
             record.initiator_port = ntohs(*(const uint16_t*) data);
@@ -147,7 +138,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_SRC_IP:
 
             // invalid source ip
-            if( field->field_length != sizeof(uint32_t) )
+            if( field_length != sizeof(uint32_t) )
                 return false;
 
             // Invalid source IP address provided
@@ -169,7 +160,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_DST_PORT:
 
             // invalid destination port
-            if( field->field_length != sizeof(record.responder_port) )
+            if( field_length != sizeof(record.responder_port) )
                 return false;
 
             record.responder_port = ntohs(*(const uint16_t*) data);
@@ -178,7 +169,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_DST_IP:
 
             // invalid length
-            if( field->field_length != sizeof(uint32_t) )
+            if( field_length != sizeof(uint32_t) )
                 return false;
 
             // Invalid destination IP address
@@ -200,7 +191,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_IPV4_NEXT_HOP:
 
             // invalid length
-            if( field->field_length != sizeof(uint32_t) )
+            if( field_length != sizeof(uint32_t) )
                 return false;
 
             // Invalid next-hop IP address
@@ -210,7 +201,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_LAST_PKT:
 
-            if( field->field_length != sizeof(record.last_pkt_second) )
+            if( field_length != sizeof(record.last_pkt_second) )
                 return false;
 
             record.last_pkt_second = unix_secs + ntohl(*(const time_t*)data)/1000;
@@ -224,7 +215,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_FIRST_PKT:
 
-            if( field->field_length != sizeof(record.first_pkt_second) )
+            if( field_length != sizeof(record.first_pkt_second) )
                 return false;
 
             record.first_pkt_second = unix_secs + ntohl(*(const time_t*)data)/1000;
@@ -238,11 +229,11 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_IN_BYTES:
 
-            if ( field->field_length == sizeof(uint64_t) )
+            if ( field_length == sizeof(uint64_t) )
                 record.initiator_bytes = ntohll(*(const uint64_t*)data);
-            else if ( field->field_length == sizeof(uint32_t) )
+            else if ( field_length == sizeof(uint32_t) )
                 record.initiator_bytes = (uint64_t)ntohl(*(const uint32_t*)data);
-            else if ( field->field_length == sizeof(uint16_t) )
+            else if ( field_length == sizeof(uint16_t) )
                 record.initiator_bytes = (uint64_t)ntohs(*(const uint16_t*) data);
             else
                 return false;
@@ -252,11 +243,11 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_IN_PKTS:
 
-            if ( field->field_length == sizeof(uint64_t) )
+            if ( field_length == sizeof(uint64_t) )
                 record.initiator_pkts = ntohll(*(const uint64_t*)data);
-            else if ( field->field_length == sizeof(uint32_t) )
+            else if ( field_length == sizeof(uint32_t) )
                 record.initiator_pkts = (uint64_t)ntohl(*(const uint32_t*)data);
-            else if ( field->field_length == sizeof(uint16_t) )
+            else if ( field_length == sizeof(uint16_t) )
                 record.initiator_pkts = (uint64_t)ntohs(*(const uint16_t*) data);
             else
                 return false;
@@ -266,7 +257,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_SRC_TOS:
 
-            if( field->field_length != sizeof(record.nf_src_tos) )
+            if( field_length != sizeof(record.nf_src_tos) )
                 return false;
 
             record.nf_src_tos = (uint8_t)*data;
@@ -275,7 +266,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_DST_TOS:
 
-            if( field->field_length != sizeof(record.nf_dst_tos))
+            if( field_length != sizeof(record.nf_dst_tos))
                 return false;
 
             record.nf_dst_tos = (uint8_t)*data;
@@ -284,9 +275,9 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_SNMP_IN:
 
-            if ( field->field_length == sizeof(uint32_t) )
+            if ( field_length == sizeof(uint32_t) )
                 record.nf_snmp_in = ntohl(*(const uint32_t*)data);
-            else if ( field->field_length == sizeof(uint16_t) )
+            else if ( field_length == sizeof(uint16_t) )
                 record.nf_snmp_in = (uint32_t)ntohs(*(const uint16_t*) data);
             else
                 return false;
@@ -295,9 +286,9 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_SNMP_OUT:
 
-            if ( field->field_length == sizeof(uint32_t) )
+            if ( field_length == sizeof(uint32_t) )
                 record.nf_snmp_out = ntohl(*(const uint32_t*)data);
-            else if ( field->field_length == sizeof(uint16_t) )
+            else if ( field_length == sizeof(uint16_t) )
                 record.nf_snmp_out = (uint32_t)ntohs(*(const uint16_t*) data);
             else
                 return false;
@@ -306,9 +297,9 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_SRC_AS:
 
-            if( field->field_length == sizeof(uint16_t) )
+            if( field_length == sizeof(uint16_t) )
                 record.nf_src_as = (uint32_t)ntohs(*(const uint16_t*) data);
-            else if( field->field_length == sizeof(uint32_t) )
+            else if( field_length == sizeof(uint32_t) )
                 record.nf_src_as = ntohl(*(const uint32_t*)data);
             else
                 return false;
@@ -316,9 +307,9 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
 
         case NETFLOW_DST_AS:
 
-            if( field->field_length == sizeof(uint16_t) )
+            if( field_length == sizeof(uint16_t) )
                 record.nf_dst_as = (uint32_t)ntohs(*(const uint16_t*) data);
-            else if( field->field_length == sizeof(uint32_t) )
+            else if( field_length == sizeof(uint32_t) )
                 record.nf_dst_as = ntohl(*(const uint32_t*)data);
             else
                 return false;
@@ -327,7 +318,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_SRC_MASK:
         case NETFLOW_SRC_MASK_IPV6:
 
-            if( field->field_length != sizeof(record.nf_src_mask) )
+            if( field_length != sizeof(record.nf_src_mask) )
                 return false;
 
             record.nf_src_mask = (uint8_t)*data;
@@ -336,7 +327,7 @@ static bool version_9_record_update(const unsigned char* data, uint32_t unix_sec
         case NETFLOW_DST_MASK:
         case NETFLOW_DST_MASK_IPV6:
 
-            if( field->field_length != sizeof(record.nf_dst_mask) )
+            if( field_length != sizeof(record.nf_dst_mask) )
                 return false;
 
             record.nf_dst_mask = (uint8_t)*data;
@@ -413,8 +404,7 @@ static bool decode_netflow_v9(const unsigned char* data, uint16_t size,
         // It's a data flowset
         if ( f_id > 255 && template_cache->count(ti_key) > 0 )
         {
-            std::vector<Netflow9TemplateField> tf;
-            tf = template_cache->at(ti_key);
+            auto& tf = template_cache->find_else_create(ti_key);
 
             while( data < flowset_end && records )
             {
@@ -432,7 +422,7 @@ static bool decode_netflow_v9(const unsigned char* data, uint16_t size,
                     if ( !bad_field )
                     {
                         bool status = version_9_record_update(data, header.unix_secs,
-                            t_field, record, record_status);
+                            t_field->field_type, t_field->field_length, record, record_status);
 
                         if ( !status )
                             bad_field = true;
@@ -474,19 +464,8 @@ static bool decode_netflow_v9(const unsigned char* data, uint16_t size,
                     DataBus::publish(NETFLOW_EVENT, event);
                 }
 
-                // check if record exists
-                auto result = netflow_cache->find(record.initiator_ip);
-
-                if ( result != netflow_cache->end() )
-                {
-                    // record exists and hence first remove the element
-                    netflow_cache->erase(record.initiator_ip);
-                    --netflow_stats.unique_flows;
-                }
-
-                // emplace doesn't replace element if exist, hence removing it first
-                netflow_cache->emplace(record.initiator_ip, record);
-                ++netflow_stats.unique_flows;
+                if ( netflow_cache->add(record.initiator_ip, record, true) )
+                    ++netflow_stats.unique_flows;
 
                 records--;
             }
@@ -527,20 +506,8 @@ static bool decode_netflow_v9(const unsigned char* data, uint16_t size,
 
                 if ( field_count > 0 )
                 {
-                    auto t_key = std::make_pair(t_id, device_ip);
-
-                    // remove if there any entry exists for this template
-                    auto is_erased = template_cache->erase(t_key);
-
-                    // count only unique templates
-                    if ( is_erased == 1 )
-                        --netflow_stats.v9_templates;
-
-                    // add template to cache
-                    template_cache->emplace(t_key, tf);
-
-                    // update the total templates count
-                    ++netflow_stats.v9_templates;
+                    if ( template_cache->insert(std::make_pair(t_id, device_ip), tf) )
+                        ++netflow_stats.v9_templates;
 
                     // don't count template as record
                     netflow_stats.records--;
@@ -660,11 +627,7 @@ static bool decode_netflow_v5(const unsigned char* data, uint16_t size,
         record.nf_src_mask = precord->src_mask;
         record.nf_dst_mask = precord->dst_mask;
 
-        // insert record
-        auto result = netflow_cache->emplace(record.initiator_ip, record);
-
-        // new unique record
-        if ( result.second )
+        if ( netflow_cache->add(record.initiator_ip, record, false) )
             ++netflow_stats.unique_flows;
 
         // send create_host and create_service flags too so that rna event handler can log those
@@ -789,6 +752,8 @@ static void show_device(const NetflowRule& d, bool is_exclude)
 
 void NetflowInspector::show(const SnortConfig*) const
 {
+    ConfigLogger::log_value("flow_memcap", config->flow_memcap);
+    ConfigLogger::log_value("template_memcap", config->template_memcap);
     ConfigLogger::log_value("dump_file", config->dump_file);
     ConfigLogger::log_value("update_timeout", config->update_timeout);
     bool log_header = true;
@@ -816,29 +781,21 @@ void NetflowInspector::show(const SnortConfig*) const
 
 void NetflowInspector::stringify(std::ofstream& file_stream)
 {
-    std::vector<snort::SfIp> keys;
-    keys.reserve(dump_cache->size());
-
-    for (const auto& elem : *dump_cache)
-        keys.push_back(elem.first);
-
-    std::sort(keys.begin(),keys.end(), IpCompare());
+    std::sort(dump_cache->begin(), dump_cache->end(), IpCompare());
 
     std::string str;
     SfIpString ip_str;
     uint32_t i = 0;
 
-    auto& cache = *dump_cache;
-
-    for (auto elem : keys)
+    for (auto& elem : *dump_cache)
     {
-        NetflowSessionRecord& record = cache[elem];
+        NetflowSessionRecord& record = elem.second;
         str = "Netflow Record #";
         str += std::to_string(++i);
         str += "\n";
 
         str += "    Initiator IP (Port): ";
-        str += elem.ntop(ip_str);
+        str += elem.first.ntop(ip_str);
         str += " (" + std::to_string(record.initiator_port) + ")";
 
         str += " -> Responder IP (Port): ";
@@ -911,7 +868,7 @@ NetflowInspector::NetflowInspector(const NetflowConfig* pc)
     {
         // create dump cache
         if ( ! dump_cache )
-            dump_cache = new NetflowCache;
+            dump_cache = new DumpCache;
     }
 }
 
@@ -958,11 +915,10 @@ void NetflowInspector::eval(Packet* p)
 void NetflowInspector::tinit()
 {
     if ( !netflow_cache )
-        netflow_cache = new NetflowCache;
+        netflow_cache = new NetflowCache(config->flow_memcap, netflow_stats);
 
     if ( !template_cache )
-        template_cache = new TemplateFieldCache;
-
+        template_cache = new TemplateFieldCache(config->template_memcap, netflow_stats);
 }
 
 void NetflowInspector::tterm()
@@ -971,10 +927,7 @@ void NetflowInspector::tterm()
     {
         static std::mutex stats_mutex;
         std::lock_guard<std::mutex> lock(stats_mutex);
-        {
-            // insert each cache
-            dump_cache->insert(netflow_cache->begin(), netflow_cache->end());
-        }
+        netflow_cache->get_all_values(*dump_cache);
     }
     delete netflow_cache;
     delete template_cache;
diff --git a/src/service_inspectors/netflow/netflow_cache.cc b/src/service_inspectors/netflow/netflow_cache.cc
new file mode 100644 (file)
index 0000000..3a21fc0
--- /dev/null
@@ -0,0 +1,46 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2022-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// netflow_cache.cc author Masud Hasan <mashasan@cisco.com>
+
+#ifndef NETFLOW_CACHE_CC
+#define NETFLOW_CACHE_CC
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "netflow_cache.h"
+
+THREAD_LOCAL NetflowCache* netflow_cache = nullptr;
+
+template <class T>
+LruCacheAllocNetflow<T>::LruCacheAllocNetflow()
+{
+    lru = netflow_cache;
+}
+
+THREAD_LOCAL TemplateFieldCache* template_cache = nullptr;
+
+template <class T>
+LruCacheAllocTemplate<T>::LruCacheAllocTemplate()
+{
+    lru = template_cache;
+}
+
+#endif
diff --git a/src/service_inspectors/netflow/netflow_cache.h b/src/service_inspectors/netflow/netflow_cache.h
new file mode 100644 (file)
index 0000000..02f3eba
--- /dev/null
@@ -0,0 +1,147 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2022-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// netflow_cache.h author Masud Hasan <mashasan@cisco.com>
+
+#ifndef NETFLOW_CACHE_H
+#define NETFLOW_CACHE_H
+
+#include <cassert>
+
+#include "hash/lru_cache_local.h"
+#include "host_tracker/cache_allocator.h"
+#include "sfip/sf_ip.h"
+
+#include "netflow_headers.h"
+#include "netflow_module.h"
+
+// Trivial derived allocator, pointing to their own cache. LruCacheAllocNetflow has a
+// CacheInterface* pointing to an lru cache. We can create different cache types by
+// instantiating the lru cache using different keys and derive here allocators with
+// CacheInterface* pointing to the appropriate lru cache object.
+template <class T>
+class LruCacheAllocNetflow : public CacheAlloc<T>
+{
+public:
+    // This needs to be in every derived class:
+    template <class U>
+    struct rebind
+    {
+        typedef LruCacheAllocNetflow<U> other;
+    };
+
+    using CacheAlloc<T>::lru;
+    LruCacheAllocNetflow();
+};
+
+template<typename Key, typename Value, typename Hash>
+class LruCacheLocalNetflow : public LruCacheLocal<Key, Value, Hash>, public CacheInterface
+{
+public:
+    using LruLocal = LruCacheLocal<Key, Value, Hash>;
+    using LruLocal::current_size;
+    using LruLocal::max_size;
+    using LruLocal::list;
+
+    LruCacheLocalNetflow(const size_t sz, struct LruCacheLocalStats& st) : LruLocal(sz, st) {}
+
+    template <class T>
+    friend class LruCacheAllocNetflow;
+
+private:
+    // Only the allocator calls this
+    void update(int size) override
+    {
+        if ( size < 0 )
+            assert( current_size >= (size_t) -size);
+
+        // Checking 1+ size prevents crash if max_size is too low to hold even a single entry
+        if ( (current_size += size) > max_size and list.size() > 1 )
+            LruLocal::prune();
+    }
+};
+
+template <class T>
+class LruCacheAllocTemplate : public CacheAlloc<T>
+{
+public:
+    template <class U>
+    struct rebind
+    {
+        typedef LruCacheAllocTemplate<U> other;
+    };
+
+    using CacheAlloc<T>::lru;
+    LruCacheAllocTemplate();
+};
+
+template<typename Key, typename Value, typename Hash>
+class LruCacheLocalTemplate : public LruCacheLocal<Key, Value, Hash>, public CacheInterface
+{
+public:
+    using LruLocal = LruCacheLocal<Key, Value, Hash>;
+    using LruLocal::current_size;
+    using LruLocal::max_size;
+    using LruLocal::stats;
+    using LruLocal::list;
+
+    LruCacheLocalTemplate(const size_t sz, struct LruCacheLocalStats& st) : LruLocal(sz, st)
+    {}
+
+    bool insert(const Key& key, std::vector<Netflow9TemplateField>& tf)
+    {
+        bool is_new = false;
+        Value& entry = LruLocal::find_else_create(key, &is_new);
+
+        if ( !is_new )
+        {
+            stats.cache_replaces++;
+            entry.clear();
+        }
+
+        for ( auto& elem : tf )
+            entry.emplace_back(elem.field_type, elem.field_length);
+
+        return is_new;
+    }
+
+    template <class T>
+    friend class LruCacheAllocTemplate;
+
+private:
+    void update(int size) override
+    {
+        if ( size < 0 )
+            assert( current_size >= (size_t) -size);
+
+        if ( (current_size += size) > max_size and list.size() > 1 )
+            LruLocal::prune();
+    }
+};
+
+// Used to track record for unique IP; we assume Netflow packets coming from
+// a given Netflow device will go to the same thread
+typedef LruCacheLocalNetflow<snort::SfIp, NetflowSessionRecord, NetflowHash> NetflowCache;
+
+// Used to track Netflow version 9 Template fields
+typedef std::pair<uint16_t, snort::SfIp> TemplateFieldKey;
+typedef LruCacheAllocTemplate<Netflow9TemplateField> TemplateAllocator;
+typedef std::vector<Netflow9TemplateField, TemplateAllocator> TemplateFieldValue;
+typedef LruCacheLocalTemplate<TemplateFieldKey, TemplateFieldValue, TemplateIpHash> TemplateFieldCache;
+
+#endif
index e01412dfdeabc9353f50411ce786975fa351d229..06908db3bcbfc8e297820ae45fb0a9409ee3606f 100644 (file)
@@ -65,11 +65,18 @@ static const Parameter netflow_params[] =
     { "rules", Parameter::PT_LIST, device_rule_params, nullptr,
       "list of NetFlow device rules" },
 
+    { "flow_memcap", Parameter::PT_INT, "0:maxSZ", "0",
+      "maximum memory for flow record cache in bytes, 0 = unlimited" },
+
+    { "template_memcap", Parameter::PT_INT, "0:maxSZ", "0",
+      "maximum memory for template cache in bytes, 0 = unlimited" },
+
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
 static const PegInfo netflow_pegs[] =
 {
+    LRU_CACHE_LOCAL_PEGS("netflow"),
     { CountType::SUM, "invalid_netflow_record", "count of invalid netflow records" },
     { CountType::SUM, "packets", "total packets processed" },
     { CountType::SUM, "records", "total records found in netflow data" },
@@ -137,7 +144,11 @@ bool NetflowModule::end(const char* fqn, int idx, SnortConfig*)
 }
 bool NetflowModule::set(const char*, Value& v, SnortConfig*)
 {
-    if ( v.is("dump_file") )
+    if ( v.is("flow_memcap") )
+        conf->flow_memcap = v.get_size();
+    else if ( v.is("template_memcap") )
+        conf->template_memcap = v.get_size();
+    else if ( v.is("dump_file") )
     {
         if ( conf->dump_file )
             snort_free((void*)conf->dump_file);
index e8aa38662f7cf3f8e452cc743e72da85b368e958..15011d65ea4ea1eef46e67255911a13d3a366366 100644 (file)
@@ -25,6 +25,7 @@
 #include <unordered_map>
 
 #include "framework/module.h"
+#include "hash/lru_cache_local.h"
 #include "sfip/sf_cidr.h"
 #include "utils/util.h"
 
@@ -120,9 +121,11 @@ struct NetflowConfig
     const char* dump_file = nullptr;
     std::unordered_map <snort::SfIp, NetflowRules, NetflowHash> device_rule_map;
     uint32_t update_timeout = 0;
+    size_t flow_memcap = 0;
+    size_t template_memcap = 0;
 };
 
-struct NetflowStats
+struct NetflowStats : public LruCacheLocalStats
 {
     PegCount invalid_netflow_record;
     PegCount packets;