git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4483: flow: new allowlist LRU
author: Raza Shafiq (rshafiq) <rshafiq@cisco.com>
Tue, 29 Oct 2024 15:22:31 +0000 (15:22 +0000)
committer: Steven Baigal (sbaigal) <sbaigal@cisco.com>
Tue, 29 Oct 2024 15:22:31 +0000 (15:22 +0000)
Merge in SNORT/snort3 from ~RSHAFIQ/snort3:whitelist_cache to master

Squashed commit of the following:

commit a1647130533346a651396d00c1d251c294094416
Author: rshafiq <rshafiq@cisco.com>
Date:   Wed Oct 2 19:08:52 2024 -0400

    flow: new allowlist LRU

25 files changed:
src/flow/dev_notes.txt
src/flow/flow.cc
src/flow/flow.h
src/flow/flow_cache.cc
src/flow/flow_cache.h
src/flow/flow_config.h
src/flow/flow_control.cc
src/flow/flow_control.h
src/flow/prune_stats.h
src/flow/test/flow_cache_test.cc
src/flow/test/flow_control_test.cc
src/flow/test/flow_test.cc
src/framework/decode_data.h
src/hash/hash_lru_cache.cc
src/hash/hash_lru_cache.h
src/hash/xhash.cc
src/hash/xhash.h
src/hash/zhash.cc
src/hash/zhash.h
src/main/analyzer.cc
src/main/test/distill_verdict_stubs.h
src/main/test/distill_verdict_test.cc
src/stream/base/stream_base.cc
src/stream/base/stream_module.cc
src/stream/base/stream_module.h

index 0d121440ef65daee40ce110fbfc430592ff3a724..9c8725719facb513fecd15919846d801cd7e06d6 100644 (file)
@@ -86,3 +86,21 @@ bias during pruning sessions, especially concerning UDP flows,
 but lays down a framework for more advanced and controlled data management 
 within the cache moving forward.
 
+10/25/2024
+Allowlist LRU
+To address the need for preserving flows that have been allowlisted and are 
+at risk of timing out, we've introduced a configurable Allowlist LRU cache 
+within the flow_cache. This enhancement enables the retention of flows marked 
+with a whitelist verdict, preventing them from being prematurely pruned due 
+to inactivity timeouts. This is particularly beneficial in scenarios where 
+Snort ceases to observe traffic for a flow after the whitelist decision, 
+especially if that flow is long-lived. Without this adjustment, 
+such flows may be pruned by the cache upon timeout, potentially impacting 
+event logging at the flow’s end-of-life (EOF) due to missing 
+pruned flow information.
+
+The Allowlist LRU cache is disabled by default but can be enabled by adding
+allowlist_cache = { enable = true } in the stream configuration. 
+Like the protocol-based LRUs, this allowlist functionality is an 
+additional LRU rather than a new hash_table, thereby maintaining 
+consistent performance with previous configurations.
\ No newline at end of file
index de7f2b078627f13ce89df8525ef1474f888e905a..f9ec73a286f0ed3d555df623629ee214908763ce 100644 (file)
 #include "detection/context_switcher.h"
 #include "detection/detection_continuation.h"
 #include "detection/detection_engine.h"
+#include "flow/flow_control.h"
 #include "flow/flow_key.h"
 #include "flow/ha.h"
 #include "flow/session.h"
 #include "framework/data_bus.h"
 #include "helpers/bitop.h"
 #include "main/analyzer.h"
+#include "packet_io/packet_tracer.h"
 #include "protocols/packet.h"
 #include "protocols/tcp.h"
 #include "pub_sub/intrinsic_event_ids.h"
@@ -41,6 +43,7 @@
 #include "utils/util.h"
 
 using namespace snort;
+extern THREAD_LOCAL class FlowControl* flow_con;
 
 Flow::~Flow()
 {
@@ -538,3 +541,16 @@ void Flow::swap_roles()
     std::swap(outer_client_ttl, outer_server_ttl);
     flags.client_initiated = !flags.client_initiated;
 }
+
+bool Flow::handle_allowlist()
+{
+    if ( flow_con->get_flow_cache_config().allowlist_cache and !flags.in_allowlist )
+    {
+        if ( flow_con->move_to_allowlist(this) )
+        {
+            PacketTracer::log("Flow: flow has been moved to allowlist cache\n");
+            return true;
+        }
+    }
+    return false;
+}
index 6e4bb03a69ec8b56ad01d9449a96f1d1b8a37568..fd17d645474237ef48353ea4cc4017a129b1b3ba 100644 (file)
@@ -428,6 +428,8 @@ public:
 
     uint64_t fetch_add_inspection_duration();
 
+    bool handle_allowlist();
+
 public:  // FIXIT-M privatize if possible
     // fields are organized by initialization and size to minimize
     // void space
@@ -523,6 +525,7 @@ public:  // FIXIT-M privatize if possible
         bool ips_pblock_event_suppressed : 1; // set if event filters have suppressed a partial block ips event
         bool binder_action_allow : 1;
         bool binder_action_block : 1;
+        bool in_allowlist : 1; // Set if the flow is in the allowlist
     } flags = {};
 
     FlowState flow_state = FlowState::SETUP;
index 54b0eca22bc72edc79f31e2dfa24e5e9b3bb3274..9e9278990561460d98e7d6e97815669325758166 100644 (file)
@@ -60,9 +60,6 @@ static const unsigned OFFLOADED_FLOWS_TOO = 2;
 static const unsigned ALL_FLOWS = 3;
 static const unsigned WDT_MASK = 7; // kick watchdog once for every 8 flows deleted
 
-constexpr uint8_t MAX_PROTOCOLS = (uint8_t)to_utype(PktType::MAX) - 1; //removing PktType::NONE from count
-constexpr uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1;
-constexpr uint8_t first_proto = to_utype(PktType::NONE) + 1;
 
 uint8_t DumpFlows::dump_code = 0;
 
@@ -202,11 +199,11 @@ bool DumpFlows::execute(Analyzer&, void**)
 
 FlowCache::FlowCache(const FlowCacheConfig& cfg) : config(cfg)
 {
-    hash_table = new ZHash(config.max_flows, sizeof(FlowKey), MAX_PROTOCOLS, false);
+    hash_table = new ZHash(config.max_flows, sizeof(FlowKey), total_lru_count, false);
     uni_flows = new FlowUniList;
     uni_ip_flows = new FlowUniList;
     flags = 0x0;
-    empty_proto = ( 1 << MAX_PROTOCOLS ) - 1;
+    empty_lru_mask = ( 1 << max_protocols ) - 1;
     timeout_idx = first_proto;
 
     assert(prune_stats.get_total() == 0);
@@ -245,9 +242,14 @@ unsigned FlowCache::get_count()
 
 Flow* FlowCache::find(const FlowKey* key)
 {
-    Flow* flow = (Flow*)hash_table->get_user_data(key,to_utype(key->pkt_type));
+    Flow* flow = (Flow*)hash_table->get_user_data(key,to_utype(key->pkt_type), false);
     if ( flow )
     {
+        if ( flow->flags.in_allowlist )
+            hash_table->touch_last_found(allowlist_lru_index);
+        else
+            hash_table->touch_last_found(to_utype(key->pkt_type));
+
         time_t t = packet_time();
 
         if ( flow->last_data_seen < t )
@@ -321,7 +323,7 @@ Flow* FlowCache::allocate(const FlowKey* key)
     link_uni(flow);
     flow->last_data_seen = timestamp;
     flow->set_idle_timeout(config.proto[to_utype(flow->key->pkt_type)].nominal_timeout);
-    empty_proto &= ~(1ULL << to_utype(key->pkt_type)); // clear the bit for this protocol
+    empty_lru_mask &= ~(1ULL << to_utype(key->pkt_type)); // clear the bit for this protocol
 
     return flow;
 }
@@ -330,9 +332,13 @@ void FlowCache::remove(Flow* flow)
 {
     unlink_uni(flow);
     const snort::FlowKey* key = flow->key;
+    uint8_t in_allowlist = flow->flags.in_allowlist;
     // Delete before releasing the node, so that the key is valid until the flow is completely freed
     delete flow;
-    hash_table->release_node(key, to_utype(key->pkt_type));
+    if ( in_allowlist )
+        hash_table->release_node(key, allowlist_lru_index);
+    else
+        hash_table->release_node(key, to_utype(key->pkt_type));
 }
 
 bool FlowCache::release(Flow* flow, PruneReason reason, bool do_cleanup)
@@ -347,8 +353,9 @@ bool FlowCache::release(Flow* flow, PruneReason reason, bool do_cleanup)
         }
     }
 
+    uint8_t in_allowlist = flow->flags.in_allowlist;
     flow->reset(do_cleanup);
-    prune_stats.update(reason, flow->key->pkt_type);
+    prune_stats.update(reason, ( in_allowlist ? static_cast<PktType>(allowlist_lru_index) : flow->key->pkt_type ));
     remove(flow);
     return true;
 }
@@ -365,31 +372,30 @@ unsigned FlowCache::prune_idle(uint32_t thetime, const Flow* save_me)
     ActiveSuspendContext act_susp(Active::ASP_PRUNE);
 
     unsigned pruned = 0;
-    uint64_t skip_protos = empty_proto;
+    uint64_t checked_lrus_mask = empty_lru_mask;
 
-    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+    assert(max_protocols < 8 * sizeof(checked_lrus_mask));
 
     {
         PacketTracerSuspend pt_susp;
         while ( pruned <= cleanup_flows and 
-                skip_protos != max_skip_protos )
+                !all_lrus_checked(checked_lrus_mask) )
         {
-            // Round-robin through the proto types
-            for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
+            // Round-robin through the LRU types
+            for( uint8_t lru_idx = first_proto; lru_idx < max_protocols; ++lru_idx )
             {
                 if( pruned > cleanup_flows )
                     break;
 
-                const uint64_t proto_mask = 1ULL << proto_idx;
+                const uint64_t lru_mask = get_lru_mask(lru_idx);
 
-                if ( skip_protos & proto_mask )
+                if( is_lru_checked(checked_lrus_mask, lru_mask) )
                     continue;
 
-                auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+                auto flow = static_cast<Flow*>(hash_table->lru_first(lru_idx));
                 if ( !flow )
                 {
-                    skip_protos |= proto_mask;
-                    empty_proto |= proto_mask;
+                    mark_lru_checked(checked_lrus_mask, empty_lru_mask, lru_mask);
                     continue;
                 }
                 
@@ -397,7 +403,7 @@ unsigned FlowCache::prune_idle(uint32_t thetime, const Flow* save_me)
                     or flow->is_suspended()
                     or flow->last_data_seen + config.pruning_timeout >= thetime )
                 {
-                    skip_protos |= proto_mask;
+                    mark_lru_checked(checked_lrus_mask, lru_mask);
                     continue;
                 }
 
@@ -461,52 +467,53 @@ unsigned FlowCache::prune_excess(const Flow* save_me)
 
     unsigned pruned = 0;
 
-    // initially skip offloads but if that doesn't work the hash table is iterated from the
-    // beginning again. prune offloads at that point.
+    // Initially skip offloads but if that doesn't work, the hash table is iterated from the
+    // beginning again. Prune offloads at that point.
     unsigned ignore_offloads = hash_table->get_num_nodes();
-    uint64_t skip_protos = 0;
+    uint64_t checked_lrus_mask = 0;
 
-    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+    assert(total_lru_count < 8 * sizeof(checked_lrus_mask));
 
     {
         PacketTracerSuspend pt_susp;
         unsigned blocks = 0;
+        // EXCESS pruning will start from the allowlist LRU
+        uint8_t lru_idx = allowlist_lru_index;
 
         while ( true )
         {
             auto num_nodes = hash_table->get_num_nodes();
-            if ( num_nodes <= max_cap or num_nodes <= blocks or 
-                    ignore_offloads == 0 or skip_protos == max_skip_protos )
-                    break;
-            
-            for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx )  
+            if ( num_nodes <= max_cap or num_nodes <= blocks or
+                ignore_offloads == 0 or all_lrus_checked(checked_lrus_mask) )
+                break;
+
+            for (; lru_idx < total_lru_count; ++lru_idx)
             {
                 num_nodes = hash_table->get_num_nodes();
                 if ( num_nodes <= max_cap or num_nodes <= blocks )
                     break;
 
-                const uint64_t proto_mask = 1ULL << proto_idx;
+                const uint64_t lru_mask = get_lru_mask(lru_idx);
 
-                if ( skip_protos & proto_mask ) 
+                if ( is_lru_checked(checked_lrus_mask, lru_mask) )
                     continue;
 
-                auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+                auto flow = static_cast<Flow*>(hash_table->lru_first(lru_idx));
                 if ( !flow )
                 {
-                    skip_protos |= proto_mask;
+                    mark_lru_checked(checked_lrus_mask, lru_mask);
                     continue;
                 }
 
                 if ( (save_me and flow == save_me) or flow->was_blocked() or 
                         (flow->is_suspended() and ignore_offloads) )
                 {
-                    // check for non-null save_me above to silence analyzer
-                    // "called C++ object pointer is null" here
+                    // Avoid pruning the current flow (save_me) or blocked/suspended flows
                     if ( flow->was_blocked() )
                         ++blocks;
-                    // FIXIT-M we should update last_data_seen upon touch to ensure
-                    // the hash_table LRU list remains sorted by time
-                    hash_table->lru_touch(proto_idx);
+
+                    // Ensure LRU list remains sorted by time on touch
+                    hash_table->lru_touch(lru_idx);
                 }
                 else
                 {
@@ -517,6 +524,9 @@ unsigned FlowCache::prune_excess(const Flow* save_me)
                 if ( ignore_offloads > 0 )
                     --ignore_offloads;
             }
+
+            if ( lru_idx >= total_lru_count )
+                lru_idx = first_proto;
         }
 
         if ( !pruned and hash_table->get_num_nodes() > max_cap )
@@ -533,19 +543,22 @@ unsigned FlowCache::prune_excess(const Flow* save_me)
 
 bool FlowCache::prune_one(PruneReason reason, bool do_cleanup, uint8_t type)
 {
-    // so we don't prune the current flow (assume current == MRU)
-    if ( hash_table->get_num_nodes() <= 1 )
+    // Avoid pruning the current flow (assume current == MRU)
+    if (hash_table->get_num_nodes() <= 1)
         return false;
 
-    // ZHash returns in LRU order, which is updated per packet via find --> move_to_front call
     auto flow = static_cast<Flow*>(hash_table->lru_first(type));
-    if( !flow )
+    if ( !flow )
         return false;
 
     flow->ssn_state.session_flags |= SSNFLAG_PRUNED;
-    release(flow, reason, do_cleanup);
 
-    return true;
+    if ( type != allowlist_lru_index )
+        return release(flow, reason, do_cleanup);
+    else if ( reason == PruneReason::MEMCAP or reason == PruneReason::EXCESS )
+        return release(flow, reason, do_cleanup);
+
+    return false;
 }
 
 unsigned FlowCache::prune_multiple(PruneReason reason, bool do_cleanup)
@@ -554,28 +567,38 @@ unsigned FlowCache::prune_multiple(PruneReason reason, bool do_cleanup)
     // so we don't prune the current flow (assume current == MRU)
     if ( hash_table->get_num_nodes() <= 1 )
         return 0;
-    
-    uint8_t proto = 0;
-    uint64_t skip_protos = 0;
 
-    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+    uint8_t lru_idx = 0;
+    uint64_t checked_lrus_mask = 0;
+
+    assert(max_protocols < 8 * sizeof(checked_lrus_mask));
+
+    if( reason == PruneReason::MEMCAP or reason == PruneReason::EXCESS )
+    {
+        // if MEMCAP or EXCESS, prune the allowlist first
+        while ( pruned < config.prune_flows )
+        {
+            if ( !prune_one(reason, do_cleanup, allowlist_lru_index) )
+                break;
+            pruned++;
+        }
+    }
 
-    
     while ( pruned < config.prune_flows )
     {
-        const uint64_t proto_mask = 1ULL << proto;
-        if ( (skip_protos & proto_mask) or !prune_one(reason, do_cleanup, proto) )
+        const uint64_t lru_mask = get_lru_mask(lru_idx);
+        if ( is_lru_checked(checked_lrus_mask, lru_mask) or !prune_one(reason, do_cleanup, lru_idx) )
         {
+            mark_lru_checked(checked_lrus_mask, lru_mask);
 
-            skip_protos |= proto_mask;
-            if ( skip_protos == max_skip_protos )
+            if ( all_lrus_checked(checked_lrus_mask) )
                 break;
         }
         else
             pruned++;
-       
-        if ( ++proto >= MAX_PROTOCOLS )
-            proto = 0;
+
+        if ( ++lru_idx >= max_protocols )
+            lru_idx = 0;
     }
 
     if ( PacketTracer::is_active() and pruned )
@@ -589,21 +612,37 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime)
     ActiveSuspendContext act_susp(Active::ASP_TIMEOUT);
 
     unsigned retired = 0;
-    uint64_t skip_protos = empty_proto;
+    uint64_t checked_lrus_mask = empty_lru_mask;  // Start by skipping any protocols that have no flows.
 
-    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+#ifdef REG_TEST
+    if ( hash_table->get_node_count(allowlist_lru_index) > 0 )
+    {
+        uint64_t allowlist_timeout_count = 0;
+        auto flow = static_cast<Flow*>(hash_table->lru_first(allowlist_lru_index));
+        while ( flow )
+        {
+            if ( flow->last_data_seen + flow->idle_timeout > thetime )
+                allowlist_timeout_count++;
+            flow = static_cast<Flow*>(hash_table->lru_next(allowlist_lru_index));
+        }
+        if ( PacketTracer::is_active() and allowlist_timeout_count )
+            PacketTracer::log("Flow: %lu allowlist flow(s) timed out but not pruned \n", allowlist_timeout_count);
+    }
+#endif
+
+    assert(max_protocols < 8 * sizeof(checked_lrus_mask));
 
     {
         PacketTracerSuspend pt_susp;
 
-        while ( retired < num_flows and skip_protos != max_skip_protos )
+        while ( retired < num_flows and !all_lrus_checked(checked_lrus_mask) )
         {
-            for( ; timeout_idx < MAX_PROTOCOLS; ++timeout_idx ) 
+            for( ; timeout_idx < max_protocols; ++timeout_idx ) 
             {
 
-                const uint64_t proto_mask = 1ULL << timeout_idx;
+                const uint64_t lru_mask = get_lru_mask(timeout_idx);
 
-                if ( skip_protos & proto_mask ) 
+                if ( is_lru_checked(checked_lrus_mask, lru_mask) )
                     continue;
 
                 auto flow = static_cast<Flow*>(hash_table->lru_current(timeout_idx));
@@ -612,8 +651,7 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime)
                     flow = static_cast<Flow*>(hash_table->lru_first(timeout_idx));
                     if ( !flow )
                     {
-                        skip_protos |= proto_mask;
-                        empty_proto |= proto_mask;
+                        mark_lru_checked(checked_lrus_mask, empty_lru_mask, lru_mask);
                         continue;
                     }
                 }
@@ -622,13 +660,13 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime)
                 {
                     if ( flow->expire_time > static_cast<uint64_t>(thetime) )
                     {
-                        skip_protos |= proto_mask;
+                        mark_lru_checked(checked_lrus_mask, lru_mask);
                         continue;
                     }
                 }
                 else if ( flow->last_data_seen + flow->idle_timeout > thetime )
                 {
-                    skip_protos |= proto_mask;
+                    mark_lru_checked(checked_lrus_mask, lru_mask);
                     continue;
                 }
 
@@ -655,30 +693,29 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime)
 
 unsigned FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, unsigned &deleted)
 {
-    uint64_t skip_protos = empty_proto;
+    uint64_t checked_lrus_mask = empty_lru_mask;
     uint64_t undeletable = 0;
 
-    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+    assert(max_protocols < 8 * sizeof(checked_lrus_mask));
 
 
-    while ( num_to_delete and skip_protos != max_skip_protos and
+    while ( num_to_delete and !all_lrus_checked(checked_lrus_mask) and
             undeletable < hash_table->get_num_nodes() )
     {
-        for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
+        for ( uint8_t lru_idx = first_proto; lru_idx < max_protocols; ++lru_idx )
         {
-            if( num_to_delete == 0)
+            if ( num_to_delete == 0 )
                 break;
-            
-            const uint64_t proto_mask = 1ULL << proto_idx;
 
-            if ( skip_protos & proto_mask )
+            const uint64_t lru_mask = get_lru_mask(lru_idx);
+
+            if ( is_lru_checked(checked_lrus_mask, lru_mask) )
                 continue;
-            
-            auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+
+            auto flow = static_cast<Flow*>(hash_table->lru_first(lru_idx));
             if ( !flow )
             {
-                skip_protos |= proto_mask;
-                empty_proto |= proto_mask;
+                mark_lru_checked(checked_lrus_mask, empty_lru_mask, lru_mask);
                 continue;
             }
 
@@ -686,7 +723,7 @@ unsigned FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, u
                 or (mode == OFFLOADED_FLOWS_TOO and flow->was_blocked()) )
             {
                 undeletable++;
-                hash_table->lru_touch(proto_idx);
+                hash_table->lru_touch(lru_idx);
                 continue;
             }
 
@@ -706,7 +743,7 @@ unsigned FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, u
             // Delete before removing the node, so that the key is valid until the flow is completely freed
             delete flow;
             // The flow should not be removed from the hash before reset
-            hash_table->remove(proto_idx);
+            hash_table->remove(lru_idx);
             ++deleted;
             --num_to_delete;
         }
@@ -741,7 +778,7 @@ unsigned FlowCache::purge()
 
     unsigned retired = 0;
 
-    for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
+    for( uint8_t proto_idx = first_proto; proto_idx < total_lru_count; ++proto_idx ) 
     {
         while ( auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx)) )
         {
@@ -840,7 +877,7 @@ void FlowCache::output_flow(std::fstream& stream, const Flow& flow, const struct
     }
     std::stringstream out;
     std::stringstream proto;
-    switch ( flow.pkt_type )
+    switch ( flow.key->pkt_type )
     {
         case PktType::IP:
             out << "Instance-ID: " << get_relative_instance_number() << " IP " << flow.key->addressSpaceId << ": " << src_ip << " " << dst_ip;
@@ -878,67 +915,51 @@ void FlowCache::output_flow(std::fstream& stream, const Flow& flow, const struct
         timeout_to_str(flow.expire_time - now.tv_sec) :
         timeout_to_str((flow.last_data_seen + config.proto[to_utype(flow.key->pkt_type)].nominal_timeout) - now.tv_sec);
     out << t;
-    stream << out.str() << proto.str() << std::endl;
+    stream << out.str() << proto.str() << (flow.flags.in_allowlist ? " (allowlist)" : "") << std::endl;
 }
 
 bool FlowCache::dump_flows(std::fstream& stream, unsigned count, const FilterFlowCriteria& ffc, bool first, uint8_t code) const
 {
     struct timeval now;
     packet_gettimeofday(&now);
-    unsigned i;
-    bool has_more_flows = false;
     Flow* walk_flow = nullptr;
+    bool has_more_flows = false;
 
-    for(uint8_t proto_id = to_utype(PktType::NONE)+1; proto_id <= to_utype(PktType::ICMP); proto_id++)
+    for (uint8_t proto_id = to_utype(PktType::NONE) + 1; proto_id < total_lru_count; ++proto_id)
     {
-        if (first)
-        {
+        if ( proto_id == to_utype(PktType::USER) or
+             proto_id == to_utype(PktType::FILE) or 
+             proto_id == to_utype(PktType::PDU) )
+            continue;
+
+        unsigned i = 0;
+
+        if ( first )
             walk_flow = static_cast<Flow*>(hash_table->get_walk_user_data(proto_id));
-            if (!walk_flow)
-            {
-                //Return only if all the protocol caches are processed.
-                if (proto_id < to_utype(PktType::ICMP))
-                    continue;
-                return !has_more_flows;
-            }
-            walk_flow->dump_code = code;
-            bool matched_filter = filter_flows(*walk_flow, ffc);
-            if (matched_filter)
-                output_flow(stream, *walk_flow, now);
-            i = 1;
-        }
         else
-            i = 0;
-        while (i < count)
-        {
             walk_flow = static_cast<Flow*>(hash_table->get_next_walk_user_data(proto_id));
 
-            if (!walk_flow )
-            {
-                //Return only if all the protocol caches are processed.
-                if (proto_id < to_utype(PktType::ICMP))
-                    break;
-                return !has_more_flows;
-            }
-            if (walk_flow->dump_code != code)
+        while ( walk_flow && i < count )
+        {
+            if  ( walk_flow->dump_code != code )
             {
                 walk_flow->dump_code = code;
-                bool matched_filter = filter_flows(*walk_flow, ffc);
-                if (matched_filter)
+                if( filter_flows(*walk_flow, ffc) )
                     output_flow(stream, *walk_flow, now);
                 ++i;
             }
-#ifdef REG_TEST
-            else
-                LogMessage("dump_flows skipping already dumped flow\n");
-#endif
+            if (i < count)
+                walk_flow = static_cast<Flow*>(hash_table->get_next_walk_user_data(proto_id));
         }
-        if(walk_flow) // we have output 'count' flows, but the protocol cache still has more flows
+
+        if ( walk_flow )
             has_more_flows = true;
     }
-    return false;
+
+    return !has_more_flows;
 }
 
+
 size_t FlowCache::uni_flows_size() const
 {
     return uni_flows ? uni_flows->get_count() : 0;
@@ -953,3 +974,32 @@ size_t FlowCache::flows_size() const
 {
     return hash_table->get_num_nodes();
 }
+
+PegCount FlowCache::get_lru_flow_count(uint8_t lru_idx) const
+{ 
+    return hash_table->get_node_count(lru_idx); 
+}
+
+bool FlowCache::move_to_allowlist(snort::Flow* f)
+{
+    if( hash_table->switch_lru_cache(f->key, to_utype(f->key->pkt_type), allowlist_lru_index) )
+    {
+        f->flags.in_allowlist = 1;
+        return true;
+    }
+    return false;
+}
+
+#ifdef UNIT_TEST
+size_t FlowCache::count_flows_in_lru(uint8_t lru_index) const
+{
+    size_t count = 0;
+    Flow* flow = static_cast<Flow*>(hash_table->get_walk_user_data(lru_index));
+    while (flow)
+    {
+        ++count;
+        flow = static_cast<Flow*>(hash_table->get_next_walk_user_data(lru_index));
+    }
+    return count;
+}
+#endif
index 7756a16afcb1717e50ffdf457d7e0fbe0dafdd8b..dd85c60c1f46424e6498c5d0b86b2e3640d1651f 100644 (file)
 #include "prune_stats.h"
 #include "filter_flow_critera.h"
 
+constexpr uint8_t max_protocols = static_cast<uint8_t>(to_utype(PktType::MAX));
+constexpr uint8_t allowlist_lru_index = max_protocols;
+constexpr uint8_t total_lru_count = max_protocols + 1;
+constexpr uint64_t all_lru_mask = (1ULL << max_protocols) - 1;
+constexpr uint8_t first_proto = to_utype(PktType::NONE) + 1;
+
 namespace snort
 {
 class Flow;
@@ -134,6 +140,8 @@ public:
     const FlowCacheConfig& get_flow_cache_config() const
     { return config; }
 
+    bool move_to_allowlist(snort::Flow* f);
+
     virtual bool filter_flows(const snort::Flow&, const FilterFlowCriteria&) const;
     virtual void output_flow(std::fstream&, const snort::Flow&, const struct timeval&) const;
 
@@ -142,6 +150,10 @@ public:
     size_t uni_flows_size() const;
     size_t uni_ip_flows_size() const;
     size_t flows_size() const;
+    PegCount get_lru_flow_count(uint8_t lru_idx) const;
+#ifdef UNIT_TEST
+    size_t count_flows_in_lru(uint8_t lru_index) const;
+#endif
 
 private:
     void delete_uni();
@@ -154,6 +166,24 @@ private:
     static std::string timeout_to_str(time_t);
     bool is_ip_match(const snort::SfIp& flow_ip, const snort::SfIp& filter_ip, const snort::SfIp& subnet) const;
 
+    inline bool is_lru_checked(uint64_t checked_lrus_mask, uint64_t lru_mask)
+    { return (checked_lrus_mask & lru_mask) != 0; }
+
+    inline bool all_lrus_checked(uint64_t checked_lrus_mask)
+    { return checked_lrus_mask == all_lru_mask; }
+
+    inline void mark_lru_checked(uint64_t& checked_lrus_mask, uint64_t lru_mask)
+    { checked_lrus_mask |= lru_mask; }
+
+    inline uint64_t get_lru_mask(uint8_t lru_idx)
+    { return 1ULL << lru_idx; }
+
+    inline void mark_lru_checked(uint64_t& checked_lrus_mask, uint64_t& empty_lru_masks, uint64_t lru_mask)
+    {
+        checked_lrus_mask |= lru_mask;
+        empty_lru_masks |= lru_mask;
+    }
+
 private:
     uint8_t timeout_idx;
     static const unsigned cleanup_flows = 1;
@@ -166,7 +196,8 @@ private:
 
     PruneStats prune_stats;
     FlowDeleteStats delete_stats;
-    uint64_t empty_proto;
+    uint64_t empty_lru_mask;
+
 };
 #endif
 
index 5b42ffc9db54bc130e20c02aa3917f67d53878b4..1638f4ff744dab2f361efb21f4d4e409a6c74dfb 100644 (file)
@@ -35,6 +35,7 @@ struct FlowCacheConfig
     unsigned pruning_timeout = 0;
     FlowTypeConfig proto[to_utype(PktType::MAX)];
     unsigned prune_flows = 0;
+    bool allowlist_cache = false;
 };
 
 #endif
index 16373c7749a046850584f7793f0d0e0e7ccd87b2..ea6b7fad1f7cace41a966493f785fa52481ef56e 100644 (file)
@@ -116,6 +116,18 @@ void FlowControl::release_flow(const FlowKey* key)
         cache->release(flow, PruneReason::HA);
 }
 
+bool FlowControl::move_to_allowlist(Flow* f)
+{
+    // Preserve the flow only if it is a TCP or UDP flow,
+    // as only these flow types contain appid-related info needed at the EOF event.
+    if ( f->key->pkt_type != PktType::TCP and f->key->pkt_type != PktType::UDP )
+        return false;
+    return cache->move_to_allowlist(f);
+}
+
+PegCount FlowControl::get_allowlist_flow_count() const
+{ return cache->get_lru_flow_count(allowlist_lru_index); }
+
 void FlowControl::release_flow(Flow* flow, PruneReason reason)
 { cache->release(flow, reason); }
 
index b84c1abca76770abe0a50eeff3e119025ba71332..2f3d62383269ffac4d28d965327923d704ee404b 100644 (file)
@@ -72,6 +72,7 @@ public:
     void timeout_flows(unsigned int, time_t cur_time);
     void check_expected_flow(snort::Flow*, snort::Packet*);
     unsigned prune_multiple(PruneReason, bool do_cleanup);
+    bool move_to_allowlist(snort::Flow*);
 
     bool dump_flows(std::fstream&, unsigned count, const FilterFlowCriteria& ffc, bool first, uint8_t code) const;
 
@@ -92,6 +93,7 @@ public:
     PegCount get_flows()
     { return num_flows; }
 
+    PegCount get_allowlist_flow_count() const;
     PegCount get_total_prunes() const;
     PegCount get_prunes(PruneReason) const;
     PegCount get_proto_prune_count(PruneReason, PktType) const;
index 53a22e5cd78b1eb2f547387e33e9fc81e9fff76b..3367ddfae414bf27eed42b4df3acb80162ff035c 100644 (file)
@@ -39,25 +39,25 @@ enum class PruneReason : uint8_t
     MAX
 };
 
-struct ProtoPruneStats
+struct LRUPruneStats
 {
-    using proto_t = std::underlying_type_t<PktType>;
-    PegCount proto_counts[static_cast<proto_t>(PktType::MAX)] { };
+    using lru_t = std::underlying_type_t<LRUType>;
+    PegCount lru_counts[static_cast<lru_t>(LRUType::MAX)] { };
 
     PegCount get_total() const
     {
         PegCount total = 0;
-        for ( proto_t i = 0; i < static_cast<proto_t>(PktType::MAX); ++i )
-            total += proto_counts[i];
+        for ( lru_t i = 0; i < static_cast<lru_t>(LRUType::MAX); ++i )
+            total += lru_counts[i];
 
         return total;
     }
 
     PegCount& get(PktType type)
-    { return proto_counts[static_cast<proto_t>(type)]; }
+    { return lru_counts[static_cast<lru_t>(type)]; }
 
     const PegCount& get(PktType type) const
-    { return proto_counts[static_cast<proto_t>(type)]; }
+    { return lru_counts[static_cast<lru_t>(type)]; }
 
     void update(PktType type)
     { ++get(type); }
@@ -68,7 +68,7 @@ struct PruneStats
     using reason_t = std::underlying_type<PruneReason>::type;
 
     PegCount prunes[static_cast<reason_t>(PruneReason::MAX)] { };
-    ProtoPruneStats protoPruneStats[static_cast<reason_t>(PruneReason::MAX)] { };
+    LRUPruneStats lruPruneStats[static_cast<reason_t>(PruneReason::MAX)] { };
 
     PegCount get_total() const
     {
@@ -88,20 +88,20 @@ struct PruneStats
     void update(PruneReason reason, PktType type = PktType::NONE)
     { 
         ++get(reason); 
-        protoPruneStats[static_cast<reason_t>(reason)].update(type);
+        lruPruneStats[static_cast<reason_t>(reason)].update(type);
     }
 
     PegCount& get_proto_prune_count(PruneReason reason, PktType type)
-    { return protoPruneStats[static_cast<reason_t>(reason)].get(type); }
+    { return lruPruneStats[static_cast<reason_t>(reason)].get(type); }
 
     const PegCount& get_proto_prune_count(PruneReason reason, PktType type) const
-    { return protoPruneStats[static_cast<reason_t>(reason)].get(type); }
+    { return lruPruneStats[static_cast<reason_t>(reason)].get(type); }
 
     PegCount get_proto_prune_count(PktType type) const
     {
         PegCount total = 0;
         for ( reason_t i = 0; i < static_cast<reason_t>(PruneReason::NONE); ++i )
-            total += protoPruneStats[i].get(type);
+            total += lruPruneStats[i].get(type);
 
         return total;
     }
index dcd28e0c3529fc38c943626fdffc072e94b6e39b..936b1b118dfa6cbdcdb8509ffeb634b59f5eb444 100644 (file)
@@ -396,6 +396,243 @@ TEST(flow_prune, prune_counts)
     CHECK_EQUAL(3, stats.get_proto_prune_count(PruneReason::IDLE_PROTOCOL_TIMEOUT, PktType::IP));
 }
 
+TEST_GROUP(allowlist_test) { };
+
+TEST(allowlist_test, move_to_allowlist)
+{
+    FlowCacheConfig fcg;
+    fcg.max_flows = 5;
+    DummyCache* cache = new DummyCache(fcg);
+    int port = 1;
+
+    // Adding two UDP flows and moving them to allow list
+    for (unsigned i = 0; i < 2; ++i) {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::UDP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        CHECK(cache->move_to_allowlist(flow) == true);  // Move flow to allow list
+
+        Flow* found_flow = cache->find(&flow_key);
+        CHECK(found_flow == flow);  // Verify flow is found
+        CHECK(found_flow->flags.in_allowlist == 1);  // Verify it's in allowlist
+    }
+
+    CHECK_EQUAL(2, cache->get_count());  // Check two flows in cache
+    CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index));  // Check 2 allow list flows
+
+    cache->purge();
+    delete cache;
+}
+
+
+TEST(allowlist_test, allowlist_timeout_prune_fail)
+{
+    FlowCacheConfig fcg;
+    fcg.max_flows = 5;
+    DummyCache* cache = new DummyCache(fcg);
+    int port = 1;
+
+    for (unsigned i = 0; i < 2; ++i)
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::TCP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        CHECK(cache->move_to_allowlist(flow) == true);
+    }
+
+    CHECK_EQUAL(2, cache->get_count());
+    CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index));
+
+    // Ensure pruning doesn't occur because all flows are allow listed
+    for (uint8_t i = 0; i < total_lru_count; ++i)
+        CHECK(cache->prune_one(PruneReason::IDLE_PROTOCOL_TIMEOUT, true, i) == false);
+    
+    CHECK_EQUAL(2, cache->get_count());
+    CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index));
+
+    cache->purge();
+    delete cache;
+}
+
+TEST(allowlist_test, allowlist_memcap_prune_pass)
+{
+    PruneStats stats;
+    FlowCacheConfig fcg;
+    fcg.max_flows = 10;
+    fcg.prune_flows = 5;
+    DummyCache* cache = new DummyCache(fcg);
+    int port = 1;
+
+    for (unsigned i = 0; i < 10; ++i)
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::TCP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        CHECK(cache->move_to_allowlist(flow) == true);
+    }
+
+    CHECK_EQUAL(10, cache->get_count());  // Check 10 flows in cache
+    CHECK_EQUAL(10, cache->get_lru_flow_count(allowlist_lru_index));  // Check 10 allow listed flows
+
+    CHECK_EQUAL(5, cache->prune_multiple(PruneReason::MEMCAP, true));
+    CHECK_EQUAL(5, cache->get_count());
+    CHECK_EQUAL(5, cache->get_proto_prune_count(PruneReason::MEMCAP, (PktType)allowlist_lru_index));
+    CHECK_EQUAL(5, cache->get_lru_flow_count(allowlist_lru_index));
+
+    cache->purge();
+    delete cache;
+}
+
+
+TEST(allowlist_test, allowlist_timeout_with_other_protos)
+{
+    FlowCacheConfig fcg;
+    fcg.max_flows = 10;
+    fcg.prune_flows = 10;
+
+    for (uint8_t i = to_utype(PktType::NONE); i <= to_utype(PktType::MAX); ++i) 
+        fcg.proto[i].nominal_timeout = 5;
+    
+    FlowCache* cache = new FlowCache(fcg);
+    int port = 1;
+
+    for (unsigned i = 0; i < 2; ++i) 
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::UDP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        CHECK(cache->move_to_allowlist(flow) == true);  // Move flow to allow list
+
+        Flow* found_flow = cache->find(&flow_key);
+        CHECK(found_flow == flow);
+        CHECK(found_flow->flags.in_allowlist == 1);
+    }
+
+    CHECK_EQUAL(2, cache->get_count());
+
+    // Ensure pruning doesn't occur because all flows are allow listed
+    for (uint8_t i = 0; i < to_utype(PktType::MAX) - 1; ++i) 
+        CHECK(cache->prune_one(PruneReason::NONE, true, i) == false);
+    
+    CHECK_EQUAL(2, cache->get_count());  // Ensure no flows were pruned
+
+    // Add a new ICMP flow
+    FlowKey flow_key_icmp;
+    flow_key_icmp.port_l = port++;
+    flow_key_icmp.pkt_type = PktType::ICMP;
+    cache->allocate(&flow_key_icmp);
+
+    CHECK_EQUAL(3, cache->get_count());
+    CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index));
+
+    // PruneReason::NONE will not be able to prune allow listed flows, only 1 UDP
+    CHECK_EQUAL(1, cache->prune_multiple(PruneReason::NONE, true));
+
+    // we can't prune to 0, so 1 flow will be pruned
+    CHECK_EQUAL(1, cache->prune_multiple(PruneReason::MEMCAP, true));
+
+    CHECK_EQUAL(1, cache->get_count()); 
+    CHECK_EQUAL(1, cache->get_lru_flow_count(allowlist_lru_index));
+
+    // Adding five UDP flows, these will become the LRU flows
+    for (unsigned i = 0; i < 5; ++i) 
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::UDP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        flow->last_data_seen = 2 + i;
+    }
+
+    CHECK_EQUAL(6, cache->get_count());
+
+    // Adding three TCP flows, move two to allow list, making them MRU
+    for (unsigned i = 0; i < 3; ++i) 
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::TCP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        flow->last_data_seen = 4 + i;  // Set TCP flows to have later timeout
+
+        if (i > 0) 
+        {
+            CHECK(cache->move_to_allowlist(flow) == true);
+
+            Flow* found_flow = cache->find(&flow_key);
+            CHECK(found_flow == flow);
+            CHECK(found_flow->flags.in_allowlist == 1);
+        }
+    }
+
+    CHECK_EQUAL(5, cache->get_lru_flow_count(to_utype(PktType::UDP)));
+    CHECK_EQUAL(1, cache->get_lru_flow_count(to_utype(PktType::TCP)));
+    CHECK_EQUAL(3, cache->get_lru_flow_count(allowlist_lru_index));
+    CHECK_EQUAL(9, cache->get_count());  // Verify total flows (5 UDP + 1 TCP + 3 allow list)
+    CHECK_EQUAL(3, cache->get_lru_flow_count(allowlist_lru_index));  // Verify 3 allow listed flows
+
+    // Timeout 4 flows, 3 UDP and 1 TCP
+    CHECK_EQUAL(4, cache->timeout(5, 9));
+    CHECK_EQUAL(5, cache->get_count());  // Ensure 5 flows remain (2 UDP + 3 allow listed)
+    CHECK_EQUAL(3, cache->count_flows_in_lru(allowlist_lru_index));
+    CHECK_EQUAL(0, cache->count_flows_in_lru(to_utype(PktType::TCP)));
+    CHECK_EQUAL(2, cache->count_flows_in_lru(to_utype(PktType::UDP)));
+
+    // try multiple prune: 2 UDP flows should be pruned, as the other flows are allow listed
+    CHECK_EQUAL(2, cache->prune_multiple(PruneReason::NONE, true));
+
+    //memcap prune can prune all the flows
+    CHECK_EQUAL(2, cache->prune_multiple(PruneReason::MEMCAP, true));
+
+    CHECK_EQUAL(1, cache->get_count());
+    CHECK_EQUAL(1, cache->get_lru_flow_count(allowlist_lru_index));
+
+    // Clean up
+    cache->purge();
+    delete cache;
+}
+TEST(allowlist_test, excess_prune)
+{
+    FlowCacheConfig fcg;
+    fcg.max_flows = 5;
+    fcg.prune_flows = 2;
+    DummyCache* cache = new DummyCache(fcg);
+    int port = 1;
+
+    for (unsigned i = 0; i < 6; ++i)
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::TCP;
+        
+        Flow* flow = cache->allocate(&flow_key);
+        CHECK(cache->move_to_allowlist(flow) == true);
+    }
+
+    // allocating 6 flows and moving all to allowlist
+    // max_flows is 5, so one flow should be pruned
+    CHECK_EQUAL(5, cache->get_count());
+    CHECK_EQUAL(5, cache->get_lru_flow_count(allowlist_lru_index));
+
+    // Request pruning; prune_flows is 2, so expect 2 flows pruned
+    CHECK_EQUAL(2, cache->prune_multiple(PruneReason::EXCESS, true));
+    CHECK_EQUAL(3, cache->get_count());
+    CHECK_EQUAL(3, cache->get_lru_flow_count(allowlist_lru_index));
+
+    cache->purge();
+    delete cache;
+}
+
 TEST_GROUP(dump_flows) { };
 
 TEST(dump_flows, dump_flows_with_all_empty_caches)
@@ -495,6 +732,139 @@ TEST(dump_flows, dump_flows_with_102_tcp_flows_and_202_udp_flows)
     delete cache;
 }
 
+TEST(dump_flows, dump_flows_with_allowlist)
+{
+    FlowCacheConfig fcg;
+    fcg.max_flows = 500;
+    FilterFlowCriteria ffc;
+    std::fstream dump_stream;
+    DummyCache* cache = new DummyCache(fcg);
+    int port = 1;
+    FlowKey flow_key[10];
+
+    // Add TCP flows and mark some as allow listed
+    for (unsigned i = 0; i < 10; ++i)
+    {
+        flow_key[i].port_l = port++;
+        flow_key[i].pkt_type = PktType::TCP;
+        Flow* flow = cache->allocate(&flow_key[i]);
+        // Mark the first 5 flows as allow listed
+        if (i < 5)
+        {
+            CHECK(cache->move_to_allowlist(flow) == true);
+        }
+    }
+
+    CHECK(cache->get_count() == 10);
+
+    //check flows are properly moved to allow list
+    CHECK(cache->count_flows_in_lru(to_utype(PktType::TCP)) == 5);  // Check 5 TCP flows
+    CHECK(cache->count_flows_in_lru(allowlist_lru_index) == 5);  // Check 5 allow listed flows
+
+    // Check that the first dump call works (with allow listed and non-allow listed flows)
+    CHECK(cache->dump_flows(dump_stream, 10, ffc, true, 1) == true);
+
+
+    // Verify that allow listed flows exist and are correctly handled
+    for (unsigned i = 0; i < 5; ++i)
+    {
+        flow_key[i].port_l = i + 1;  // allow listed flow ports
+        flow_key[i].pkt_type = PktType::TCP;
+        Flow* flow = cache->find(&flow_key[i]);
+        CHECK(flow != nullptr);  // Ensure the flow is found
+        CHECK(flow->flags.in_allowlist == 1);  // Ensure the flow is allow listed
+    }
+
+    // Ensure cache cleanup and correct flow counts
+    cache->purge();
+    CHECK(cache->get_flows_allocated() == 0);
+    CHECK(cache->get_count() == 0);
+    delete cache;
+}
+
+TEST(dump_flows, dump_flows_no_flows_to_dump)
+{
+    FlowCacheConfig fcg;
+    FilterFlowCriteria ffc;
+    fcg.max_flows = 10;
+    std::fstream dump_stream;
+
+    DummyCache* cache = new DummyCache(fcg);
+    CHECK(cache->dump_flows(dump_stream, 100, ffc, true, 1) == true);
+
+    delete cache;   
+}
+
+TEST_GROUP(flow_cache_lrus) 
+{ 
+    FlowCacheConfig fcg;
+    DummyCache* cache;
+
+    void setup()
+    {
+        fcg.max_flows = 20;
+        cache = new DummyCache(fcg);
+    }
+
+    void teardown()
+    {
+        cache->purge();
+        delete cache;
+    }
+};
+
+TEST(flow_cache_lrus, count_flows_in_lru_test)
+{
+    FlowKey flow_keys[10];
+    memset(flow_keys, 0, sizeof(flow_keys));
+
+    flow_keys[0].pkt_type = PktType::TCP;
+    flow_keys[1].pkt_type = PktType::UDP;
+    flow_keys[2].pkt_type = PktType::USER;
+    flow_keys[3].pkt_type = PktType::FILE;
+    flow_keys[4].pkt_type = PktType::TCP;
+    flow_keys[5].pkt_type = PktType::TCP;
+    flow_keys[6].pkt_type = PktType::PDU;
+    flow_keys[7].pkt_type = PktType::ICMP;
+    flow_keys[8].pkt_type = PktType::TCP;
+    flow_keys[9].pkt_type = PktType::ICMP;
+
+    //flow count 4 TCP, 1 UDP, 1 USER, 1 FILE, 1 PDU, 2 ICMP = 10
+    // Add the flows to the hash_table
+    for (int i = 0; i < 10; ++i)
+    {
+        flow_keys[i].port_l = i;
+        Flow* flow = cache->allocate(&flow_keys[i]);
+        CHECK(flow != nullptr);
+    }
+
+    CHECK_EQUAL(10, cache->get_count());  // Verify 10 flows in cache
+    CHECK_EQUAL(4, cache->count_flows_in_lru(to_utype(PktType::TCP)));  // 4 TCP flows
+    CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::UDP)));  // 1 UDP flow
+    CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::USER)));  // 1 USER flow
+    CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::FILE)));  // 1 FILE flow
+    CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::PDU)));  // 1 PDU flow
+    CHECK_EQUAL(2, cache->count_flows_in_lru(to_utype(PktType::ICMP)));  // 2 ICMP flow
+
+    Flow* flow1 = cache->find(&flow_keys[0]);
+    Flow* flow2 = cache->find(&flow_keys[1]);
+    Flow* flow3 = cache->find(&flow_keys[6]);
+    CHECK(cache->move_to_allowlist(flow1));
+    CHECK(cache->move_to_allowlist(flow2));
+    CHECK(cache->move_to_allowlist(flow3));
+
+    CHECK_EQUAL(10, cache->get_count());
+    CHECK_EQUAL(3, cache->count_flows_in_lru(to_utype(PktType::TCP)));  // 3 TCP flows
+    CHECK_EQUAL(0, cache->count_flows_in_lru(to_utype(PktType::UDP)));  // 0 UDP flows
+    CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::USER)));  // 1 USER flow
+    CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::FILE)));  // 1 FILE flow
+    CHECK_EQUAL(0, cache->count_flows_in_lru(to_utype(PktType::PDU)));  // 0 PDU flows
+    CHECK_EQUAL(2, cache->count_flows_in_lru(to_utype(PktType::ICMP)));  // 2 ICMP flow
+    // Check the allow listed flows
+    CHECK_EQUAL(3, cache->count_flows_in_lru(allowlist_lru_index));  // 3 allowlist flows
+
+}
+
 int main(int argc, char** argv)
 {
     return CommandLineTestRunner::RunAllTests(argc, argv);
index c238da128b695ab20c3f8f8f6d0875005517e66b..5785eb84a5278e2bebe29594e4767aabf18ea92b 100644 (file)
@@ -85,6 +85,8 @@ ExpectCache::ExpectCache(uint32_t) { }
 ExpectCache::~ExpectCache() = default;
 bool ExpectCache::check(Packet*, Flow*) { return true; }
 Flow* HighAvailabilityManager::import(Packet&, FlowKey&) { return nullptr; }
+bool FlowCache::move_to_allowlist(snort::Flow*) { return true; }
+uint64_t FlowCache::get_lru_flow_count(uint8_t) const { return 0; }
 
 namespace snort
 {
index 560eeb24dfc917fbc700a3e5c3c2acb4c1b74a2b..d5d5f024089b05d25a6d73b8b59003b00237dc34 100644 (file)
@@ -26,6 +26,8 @@
 #include "detection/context_switcher.h"
 #include "detection/detection_engine.h"
 #include "flow/flow.h"
+#include "flow/flow_config.h"
+#include "flow/flow_control.h"
 #include "flow/flow_stash.h"
 #include "flow/ha.h"
 #include "framework/inspector.h"
 #include "flow_stubs.h"
 
 using namespace snort;
+THREAD_LOCAL class FlowControl* flow_con;
+
+const FlowCacheConfig& FlowControl::get_flow_cache_config() const
+{
+    static FlowCacheConfig fcc;
+    fcc.allowlist_cache = true;
+    return fcc;
+}
+
+bool FlowControl:: move_to_allowlist(snort::Flow*) { return true; }
 
 void Inspector::rem_ref() {}
 
index 64a3522aa8dbb25a45f606729fb4e506007e862f..3e87718bc303c5733e4ab5cf1740c1b718e115be 100644 (file)
@@ -47,6 +47,20 @@ enum class PktType : std::uint8_t
     NONE, IP, TCP, UDP, ICMP, USER, FILE, PDU, MAX
 };
 
+enum class LRUType : std::uint8_t 
+{
+    NONE = static_cast<std::uint8_t>(PktType::NONE),
+    IP = static_cast<std::uint8_t>(PktType::IP),
+    TCP = static_cast<std::uint8_t>(PktType::TCP),
+    UDP = static_cast<std::uint8_t>(PktType::UDP),
+    ICMP = static_cast<std::uint8_t>(PktType::ICMP),
+    USER = static_cast<std::uint8_t>(PktType::USER),
+    FILE = static_cast<std::uint8_t>(PktType::FILE),
+    PDU = static_cast<std::uint8_t>(PktType::PDU),
+    ALLOW_LIST,
+    MAX
+};
+
 // the first several of these bits must map to PktType
 // eg PROTO_BIT__IP == BIT(PktType::IP), etc.
 #define PROTO_BIT__NONE             0x000000
index 43764bbc9c3d64864b441fe721a3676db5d0ea19..c6fe321495cc22b6d30b270231ec192e9a68aaee 100644 (file)
@@ -44,6 +44,7 @@ void HashLruCache::insert(HashNode* hnode)
     else
         tail = hnode;
     head = hnode;
+    node_count++;
 }
 
 void HashLruCache::touch(HashNode* hnode)
@@ -83,4 +84,5 @@ void HashLruCache::remove_node(HashNode* hnode)
 
     if ( tail == hnode )
         tail = hnode->gprev;
+    node_count--;
 }
index c5ed0e88b406b9b813665bb541af5fc1404a34b7..9b1d2f8ef895f6e327b161d583e1e47ab56411dc 100644 (file)
@@ -78,12 +78,16 @@ public:
         return rnode;
     }
 
+    inline uint64_t get_node_count()
+    { return node_count; }
+
 private:
     snort::HashNode* head = nullptr;
     snort::HashNode* tail = nullptr;
     snort::HashNode* cursor = nullptr;
     //walk_cursor is used to traverse from tail to head while dumping the flows.
     snort::HashNode* walk_cursor = nullptr;
+    uint64_t node_count = 0;
 };
 
 #endif
index b4b206f4b40230772db958bd987232444402490d..187fc76a5730d2c6627abde083acc5772e7af0c6 100644 (file)
@@ -305,13 +305,13 @@ void XHash::update_cursor()
     }
 }
 
-void* XHash::get_user_data(const void* key, uint8_t type)
+void* XHash::get_user_data(const void* key, uint8_t type, bool touch)
 {
     assert(key);
     assert(type < num_lru_caches);
 
     int rindex = 0;
-    HashNode* hnode = find_node_row(key, rindex, type);
+    HashNode* hnode = find_node_row(key, rindex, type, touch);
     return ( hnode ) ? hnode->data : nullptr;
 }
 
@@ -407,7 +407,14 @@ void XHash::move_to_front(HashNode* node,uint8_t type)
     lru_caches[type]->touch(node);
 }
 
-HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type)
+void XHash::touch_last_found(uint8_t type)
+{
+    assert(type < num_lru_caches);
+    if ( lfind )
+        move_to_front(lfind, type);
+}
+
+HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type, bool touch)
 {
     assert(type < num_lru_caches);
     unsigned hashkey = hashkey_ops->do_hash((const unsigned char*)key, keysize);
@@ -418,7 +425,9 @@ HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type)
     {
         if ( hashkey_ops->key_compare(hnode->key, key, keysize) )
         {
-            move_to_front(hnode,type);
+            lfind = hnode;
+            if ( touch )
+                move_to_front(hnode, type);
             return hnode;
         }
     }
@@ -544,6 +553,28 @@ HashNode* XHash::release_lru_node(uint8_t type)
     return hnode;
 }
 
+void XHash::switch_node_lru_cache(HashNode* hnode, uint8_t old_type, uint8_t new_type)
+{
+    lru_caches[old_type]->remove_node(hnode);
+    lru_caches[new_type]->insert(hnode);
+}
+
+bool XHash::switch_lru_cache(const void* key, uint8_t old_type, uint8_t new_type)
+{
+    assert(old_type < num_lru_caches);
+    assert(new_type < num_lru_caches);
+
+    int rindex = 0;
+    HashNode* hnode = (HashNode*)find_node_row(key, rindex, old_type);
+    if ( hnode )
+    {
+        switch_node_lru_cache(hnode, old_type, new_type);
+        return true;
+    }
+
+    return false;
+}
+
 bool XHash::delete_lru_node(uint8_t type)
 {
     assert(type < num_lru_caches);
index 6ecf6206c9206f4bca1fcd15f9577c56a303296c..ffa716eac103cf1ffc152c292205cb7b7e1326d0 100644 (file)
@@ -58,7 +58,7 @@ public:
     HashNode* find_first_node();
     HashNode* find_next_node();
     void* get_user_data();
-    void* get_user_data(const void* key, uint8_t type = 0);
+    void* get_user_data(const void* key, uint8_t type = 0, bool touch = true);
     void release(uint8_t type = 0);
     int release_node(const void* key, uint8_t type = 0);
     int release_node(HashNode* node, uint8_t type = 0);
@@ -69,6 +69,8 @@ public:
     bool delete_lru_node(uint8_t type = 0);
     void clear_hash();
     bool full() const { return !fhead; }
+    bool switch_lru_cache(const void* key, uint8_t old_type, uint8_t new_type);
+    void touch_last_found(uint8_t type = 0);
 
     // set max hash nodes, 0 == no limit
     void set_max_nodes(int max)
@@ -100,7 +102,7 @@ protected:
 
     void initialize_node(HashNode*, const void* key, void* data, int index, uint8_t type = 0);
     HashNode* allocate_node(const void* key, void* data, int index);
-    HashNode* find_node_row(const void* key, int& rindex, uint8_t type = 0);
+    HashNode* find_node_row(const void* key, int& rindex, uint8_t type = 0, bool touch = true);
     void link_node(HashNode*);
     void unlink_node(HashNode*);
     bool delete_a_node();
@@ -130,6 +132,7 @@ private:
     HashKeyOperations* hashkey_ops = nullptr;
     HashNode* cursor = nullptr;
     HashNode* fhead = nullptr;
+    HashNode* lfind = nullptr;
     unsigned datasize = 0;
     unsigned long mem_cap = 0;
     unsigned max_nodes = 0;
@@ -142,6 +145,7 @@ private:
     HashNode* release_lru_node(uint8_t type = 0);
     void update_cursor();
     void purge_free_list();
+    void switch_node_lru_cache(HashNode* hnode, uint8_t old_type, uint8_t new_type);
 };
 
 } // namespace snort
index e3c937d57ee3470f0242663be1847c2c3f46de7b..79713db46f841122e54664dde0a1bf4b613b1984 100644 (file)
@@ -134,3 +134,9 @@ void ZHash::lru_touch(uint8_t type)
     assert(node);
     lru_caches[type]->touch(node);
 }
+
+uint64_t ZHash::get_node_count(uint8_t type)
+{
+    assert(type < num_lru_caches);
+    return lru_caches[type]->get_node_count();
+}
index e43076cc2ca7af8d67fc30664dca76c1bfd30420..52d7c43e9730ee2fabb7942d89626b9df73e5708 100644 (file)
@@ -36,6 +36,7 @@ public:
     void* pop();
 
     void* get(const void* key, uint8_t type = 0);
+    uint64_t get_node_count(uint8_t type);
     void* remove(uint8_t type = 0);
 
     void* lru_first(uint8_t type = 0);
index 05f173eac35697b74c835c03db03f88ab38ea763..15f8236289f5f1434252ab287716e0acaf0dfffe 100644 (file)
@@ -297,6 +297,8 @@ static DAQ_Verdict distill_verdict(Packet* p)
             verdict = DAQ_VERDICT_PASS;
             daq_stats.internal_whitelist++;
         }
+        else if ( p->flow )
+            p->flow->handle_allowlist();
     }
 
     if ( p->flow )
index 8b2ef50a306894dd22adaaac3f0530b73cd396b7..34a1ea037eb9663bbd4efd015b309e9a9aa46878 100644 (file)
@@ -27,6 +27,7 @@
 #include "filters/rate_filter.h"
 #include "filters/sfrf.h"
 #include "filters/sfthreshold.h"
+#include "flow/flow_control.h"
 #include "flow/ha.h"
 #include "framework/data_bus.h"
 #include "latency/packet_latency.h"
@@ -68,6 +69,7 @@ THREAD_LOCAL DAQStats daq_stats;
 THREAD_LOCAL bool RuleContext::enabled = false;
 THREAD_LOCAL bool snort::TimeProfilerStats::enabled;
 THREAD_LOCAL snort::PacketTracer* snort::PacketTracer::s_pkt_trace;
+THREAD_LOCAL class FlowControl* flow_con;
 
 void Profiler::start() { }
 void Profiler::stop(uint64_t) { }
@@ -230,11 +232,13 @@ IpsContext::IpsContext(unsigned) { }
 NetworkPolicy* get_network_policy() { return nullptr; }
 InspectionPolicy* get_inspection_policy() { return nullptr; }
 Flow::~Flow() = default;
+bool Flow::handle_allowlist() { return true; }
 void ThreadConfig::implement_thread_affinity(SThreadType, unsigned) { }
 void ThreadConfig::apply_thread_policy(SThreadType , unsigned ) { }
 void ThreadConfig::set_instance_tid(int) { }
 }
 
+bool FlowControl::move_to_allowlist(snort::Flow*) { return true; }
 void memory::MemoryCap::thread_init() { }
 void memory::MemoryCap::thread_term() { }
 void memory::MemoryCap::free_space() { }
index 0159654d8bd4a3ea0f031d330598007f6b748420..0a69a0cae05c5a7b78a4ec6a0c126e92a57eae2a 100644 (file)
@@ -51,12 +51,20 @@ void Flow::trust() { }
 
 SFDAQInstance* SFDAQ::get_local_instance() { return nullptr; }
 
+
 unsigned int get_random_seed()
 { return 3193; }
 unsigned DataBus::get_id(const PubKey&)
 { return 0; }
 }
 
+const FlowCacheConfig& FlowControl::get_flow_cache_config() const
+{
+    static FlowCacheConfig cfg;
+    cfg.allowlist_cache = true;
+    return cfg;
+}
+
 using namespace snort;
 
 //--------------------------------------------------------------------------
index 5f0df34c1278509c6cd94f4f79ee4063b10db29c..084adb43cddc34e93f167fa375b8329873b1863d 100644 (file)
@@ -96,15 +96,17 @@ const PegInfo base_pegs[] =
     { CountType::SUM, "user_memcap_prunes", "number of USER flows pruned due to memcap" },
     { CountType::SUM, "file_memcap_prunes", "number of FILE flows pruned due to memcap" },
     { CountType::SUM, "pdu_memcap_prunes", "number of PDU flows pruned due to memcap" },
+    { CountType::SUM, "allowlist_memcap_prunes", "number of allowlist flows pruned due to memcap" },
 
     // Keep the NOW stats at the bottom as it requires special sum_stats logic
+    { CountType::NOW, "allowlist_flows", "number of flows moved to the allow list" },
     { CountType::NOW, "current_flows", "current number of flows in cache" },
     { CountType::NOW, "uni_flows", "number of uni flows in cache" },
     { CountType::NOW, "uni_ip_flows", "number of uni ip flows in cache" },
     { CountType::END, nullptr, nullptr }
 };
 
-#define NOW_PEGS_NUM 3
+#define NOW_PEGS_NUM 4
 
 // FIXIT-L dependency on stats define in another file
 void base_prep()
@@ -139,7 +141,9 @@ void base_prep()
     stream_base_stats.user_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, PktType::USER);
     stream_base_stats.file_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, PktType::FILE);
     stream_base_stats.pdu_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, PktType::PDU);
+    stream_base_stats.allowlist_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, static_cast<PktType>(allowlist_lru_index));
 
+    stream_base_stats.allowlist_flows = flow_con->get_allowlist_flow_count();
     stream_base_stats.current_flows = flow_con->get_num_flows();
     stream_base_stats.uni_flows = flow_con->get_uni_flows();
     stream_base_stats.uni_ip_flows = flow_con->get_uni_ip_flows();
index 895510569bec2c1620a088949385527d97a14766..af8bf25068efc7bdf6c8d024b6fd395193c2fff4 100644 (file)
@@ -67,6 +67,14 @@ static const Parameter name[] = \
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } \
 }
 
+static const Parameter allowlist_cache_params[] =
+{
+    { "enable", Parameter::PT_BOOL, nullptr, "false",
+      "enable allowlist cache" },
+
+    { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
+};
+
 FLOW_TYPE_PARAMS(ip_params, "180");
 FLOW_TYPE_PARAMS(icmp_params, "180");
 FLOW_TYPE_PARAMS(tcp_params, "3600");
@@ -103,6 +111,8 @@ static const Parameter s_params[] =
     { "require_3whs", Parameter::PT_INT, "-1:max31", "-1",
       "don't track midstream TCP sessions after given seconds from start up; -1 tracks all" },
 
+    { "allowlist_cache", Parameter::PT_TABLE, allowlist_cache_params, nullptr, "configure allowlist cache" },
+
     FLOW_TYPE_TABLE("ip_cache",   "ip",   ip_params),
     FLOW_TYPE_TABLE("icmp_cache", "icmp", icmp_params),
     FLOW_TYPE_TABLE("tcp_cache",  "tcp",  tcp_params),
@@ -317,6 +327,9 @@ bool StreamModule::set(const char* fqn, Value& v, SnortConfig* c)
     else if ( v.is("require_3whs") )
         config.hs_timeout = v.get_int32();
 
+    else if ( !strcmp(fqn, "stream.allowlist_cache.enable") )
+        config.flow_cache_cfg.allowlist_cache = v.get_bool();
+
     else if ( !strcmp(fqn, "stream.file_cache.idle_timeout") )
         config.flow_cache_cfg.proto[to_utype(PktType::FILE)].nominal_timeout = v.get_uint32();
 
@@ -493,6 +506,12 @@ void StreamModuleConfig::show() const
 
         ConfigLogger::log_value(flow_type_names[i], tmp.c_str());
     }
+    {
+        std::string tmp;
+        tmp += "{ enable = " + (flow_cache_cfg.allowlist_cache ? std::string("true") : std::string("false"));
+        tmp += " }";
+        ConfigLogger::log_value("allowlist_cache", tmp.c_str());
+    }
 }
 
 bool HPQReloadTuner::tinit()
index d4ccdb619b458a460c9714bd189a0c5059007233..ab7c4d6eea15e671c79121a702b17bfb6f24d502 100644 (file)
@@ -103,8 +103,10 @@ struct BaseStats
      PegCount user_memcap_prunes;
      PegCount file_memcap_prunes;
      PegCount pdu_memcap_prunes;
+     PegCount allowlist_memcap_prunes;
 
      // Keep the NOW stats at the bottom as it requires special sum_stats logic
+     PegCount allowlist_flows;
      PegCount current_flows;
      PegCount uni_flows;
      PegCount uni_ip_flows;