]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4014: flow_cache: added protocol-based LRU caches
authorRaza Shafiq (rshafiq) <rshafiq@cisco.com>
Fri, 29 Sep 2023 16:22:27 +0000 (16:22 +0000)
committerSteven Baigal (sbaigal) <sbaigal@cisco.com>
Fri, 29 Sep 2023 16:22:27 +0000 (16:22 +0000)
Merge in SNORT/snort3 from ~RSHAFIQ/snort3:proto_base_lru to master

Squashed commit of the following:

commit 792e5b3c3febeece0f174c16a84646a3fb2e8a94
Author: rshafiq <rshafiq@cisco.com>
Date:   Wed Sep 13 07:23:23 2023 -0400

    flow_cache: added protocol-based LRU caches

src/flow/dev_notes.txt
src/flow/flow_cache.cc
src/flow/flow_cache.h
src/flow/test/flow_cache_test.cc
src/flow/test/flow_control_test.cc
src/hash/dev_notes.txt
src/hash/xhash.cc
src/hash/xhash.h
src/hash/zhash.cc
src/hash/zhash.h

index d6ae957b848e5ed73314710a3ef54396889fe2c2..0d121440ef65daee40ce110fbfc430592ff3a724 100644 (file)
@@ -69,3 +69,20 @@ The HA subsystem implements these classes:
     and is handled as a special case.  Client 0 is the fundamental session HA
     state sync functionality.  Other clients are optional.
 
+
+09/25/2023
+In response to the need for more nuanced management of different protocol 
+types within the `flow_cache`, we've leveraged the newly added feature in 
+xhash. This enhancement includes the introduction of Protocol-Based 
+Least Recently Used (LRU) caches, which substantially improves the handling 
+of multiple protocol types in the cache.
+With this implementation, during the various pruning passes (idle, 
+excess, timeout, etc.), we've adopted a round-robin method to cycle through 
+the LRU lists designated for each protocol type. This methodology ensures 
+that every protocol, irrespective of its inherent timeout characteristics, 
+is given an equal opportunity for pruning and retirement from the cache.
+This adjustment to `flow_cache` not only addresses the previously observed 
+bias during pruning sessions, especially concerning UDP flows, 
+but lays down a framework for more advanced and controlled data management 
+within the cache moving forward.
+
index 6deb4a992d0a5fb5ed6ec5e228893f2640c235b5..d218fcb99b8bf1373354f0562346bb1784f44cc0 100644 (file)
@@ -51,6 +51,8 @@ static const unsigned OFFLOADED_FLOWS_TOO = 2;
 static const unsigned ALL_FLOWS = 3;
 static const unsigned WDT_MASK = 7; // kick watchdog once for every 8 flows deleted
 
+const uint8_t MAX_PROTOCOLS = (uint8_t)to_utype(PktType::MAX) - 1; //removing PktType::NONE from count
+
 //-------------------------------------------------------------------------
 // FlowCache stuff
 //-------------------------------------------------------------------------
@@ -59,7 +61,7 @@ extern THREAD_LOCAL const snort::Trace* stream_trace;
 
 FlowCache::FlowCache(const FlowCacheConfig& cfg) : config(cfg)
 {
-    hash_table = new ZHash(config.max_flows, sizeof(FlowKey), false);
+    hash_table = new ZHash(config.max_flows, sizeof(FlowKey), MAX_PROTOCOLS, false);
     uni_flows = new FlowUniList;
     uni_ip_flows = new FlowUniList;
     flags = 0x0;
@@ -100,8 +102,7 @@ unsigned FlowCache::get_count()
 
 Flow* FlowCache::find(const FlowKey* key)
 {
-    Flow* flow = (Flow*)hash_table->get_user_data(key);
-
+    Flow* flow = (Flow*)hash_table->get_user_data(key,to_utype(key->pkt_type));
     if ( flow )
     {
         time_t t = packet_time();
@@ -172,7 +173,7 @@ Flow* FlowCache::allocate(const FlowKey* key)
     Flow* flow = new Flow;
     push(flow);
 
-    flow = (Flow*)hash_table->get(key);
+    flow = (Flow*)hash_table->get(key, to_utype(key->pkt_type));
     assert(flow);
     link_uni(flow);
     flow->last_data_seen = timestamp;
@@ -187,7 +188,7 @@ void FlowCache::remove(Flow* flow)
     const snort::FlowKey* key = flow->key;
     // Delete before releasing the node, so that the key is valid until the flow is completely freed
     delete flow;
-    hash_table->release_node(key);
+    hash_table->release_node(key, to_utype(key->pkt_type));
 }
 
 bool FlowCache::release(Flow* flow, PruneReason reason, bool do_cleanup)
@@ -220,39 +221,45 @@ unsigned FlowCache::prune_idle(uint32_t thetime, const Flow* save_me)
     ActiveSuspendContext act_susp(Active::ASP_PRUNE);
 
     unsigned pruned = 0;
-    auto flow = static_cast<Flow*>(hash_table->lru_first());
+    uint64_t skip_protos = 0;
+
+    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+
+    const uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1;
 
     {
         PacketTracerSuspend pt_susp;
-
-        while ( flow and pruned <= cleanup_flows )
+        while ( pruned <= cleanup_flows and 
+                skip_protos != max_skip_protos )
         {
-#if 0
-            // FIXIT-RC this loops forever if 1 flow in cache
-            if (flow == save_me)
+            // Round-robin through the proto types
+            for( uint8_t proto_idx = 0; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
             {
-                break;
-                if ( hash_table->get_count() == 1 )
+                if( pruned > cleanup_flows )
                     break;
 
-                hash_table->lru_touch();
+                if ( skip_protos & (1ULL << proto_idx) )
+                    continue;
+
+                auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+                if ( !flow )
+                {
+                    skip_protos |= (1ULL << proto_idx);
+                    continue;
+                }
+                
+                if ( flow == save_me // Reached the current flow. This *should* be the newest flow
+                    or flow->is_suspended()
+                    or flow->last_data_seen + config.pruning_timeout >= thetime )
+                {
+                    skip_protos |= (1ULL << proto_idx);
+                    continue;
+                }
+
+                flow->ssn_state.session_flags |= SSNFLAG_TIMEDOUT;
+                if ( release(flow, PruneReason::IDLE_MAX_FLOWS) )
+                    ++pruned;
             }
-#else
-            // Reached the current flow. This *should* be the newest flow
-            if ( flow == save_me )
-                break;
-#endif
-            if ( flow->is_suspended() )
-                break;
-
-            if ( flow->last_data_seen + config.pruning_timeout >= thetime )
-                break;
-
-            flow->ssn_state.session_flags |= SSNFLAG_TIMEDOUT;
-            if ( release(flow, PruneReason::IDLE_MAX_FLOWS) )
-                ++pruned;
-
-            flow = static_cast<Flow*>(hash_table->lru_first());
         }
     }
 
@@ -312,39 +319,62 @@ unsigned FlowCache::prune_excess(const Flow* save_me)
     // initially skip offloads but if that doesn't work the hash table is iterated from the
     // beginning again. prune offloads at that point.
     unsigned ignore_offloads = hash_table->get_num_nodes();
+    uint64_t skip_protos = 0;
+
+    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+
+    const uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1;
 
     {
         PacketTracerSuspend pt_susp;
         unsigned blocks = 0;
 
-        while ( hash_table->get_num_nodes() > max_cap and hash_table->get_num_nodes() > blocks )
+        while ( true )
         {
-            auto flow = static_cast<Flow*>(hash_table->lru_first());
-            assert(flow); // holds true because hash_table->get_count() > 0
-
-            if ( (save_me and flow == save_me) or flow->was_blocked() or
-                    (flow->is_suspended() and ignore_offloads) )
-            {
-                // check for non-null save_me above to silence analyzer
-                // "called C++ object pointer is null" here
-                if ( flow->was_blocked() )
-                    ++blocks;
-
-                // FIXIT-M we should update last_data_seen upon touch to ensure
-                // the hash_table LRU list remains sorted by time
-                hash_table->lru_touch();
-            }
-            else
+            auto num_nodes = hash_table->get_num_nodes();
+            if ( num_nodes <= max_cap or num_nodes <= blocks or 
+                    ignore_offloads == 0 or skip_protos == max_skip_protos )
+                    break;
+            
+            for( uint8_t proto_idx = 0; proto_idx < MAX_PROTOCOLS; ++proto_idx )  
             {
-                flow->ssn_state.session_flags |= SSNFLAG_PRUNED;
-                if ( release(flow, PruneReason::EXCESS) )
-                    ++pruned;
+                num_nodes = hash_table->get_num_nodes();
+                if ( num_nodes <= max_cap or num_nodes <= blocks )
+                    break;
+                
+                if ( skip_protos & (1ULL << proto_idx) ) 
+                    continue;
+
+                auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+                if ( !flow )
+                {
+                    skip_protos |= (1ULL << proto_idx);
+                    continue;
+                }
+
+                if ( (save_me and flow == save_me) or flow->was_blocked() or 
+                        (flow->is_suspended() and ignore_offloads) )
+                {
+                    // check for non-null save_me above to silence analyzer
+                    // "called C++ object pointer is null" here
+                    if ( flow->was_blocked() )
+                        ++blocks;
+                    // FIXIT-M we should update last_data_seen upon touch to ensure
+                    // the hash_table LRU list remains sorted by time
+                    hash_table->lru_touch(proto_idx);
+                }
+                else
+                {
+                    flow->ssn_state.session_flags |= SSNFLAG_PRUNED;
+                    if ( release(flow, PruneReason::EXCESS) )
+                        ++pruned;
+                }
+                if ( ignore_offloads > 0 )
+                    --ignore_offloads;
             }
-            if ( ignore_offloads > 0 )
-                --ignore_offloads;
         }
 
-        if (!pruned and hash_table->get_num_nodes() > max_cap)
+        if ( !pruned and hash_table->get_num_nodes() > max_cap )
         {
             pruned += prune_multiple(PruneReason::EXCESS, true);
         }
@@ -356,15 +386,16 @@ unsigned FlowCache::prune_excess(const Flow* save_me)
     return pruned;
 }
 
-bool FlowCache::prune_one(PruneReason reason, bool do_cleanup)
+bool FlowCache::prune_one(PruneReason reason, bool do_cleanup, uint8_t type)
 {
     // so we don't prune the current flow (assume current == MRU)
     if ( hash_table->get_num_nodes() <= 1 )
         return false;
 
     // ZHash returns in LRU order, which is updated per packet via find --> move_to_front call
-    auto flow = static_cast<Flow*>(hash_table->lru_first());
-    assert(flow);
+    auto flow = static_cast<Flow*>(hash_table->lru_first(type));
+    if( !flow )
+        return false;
 
     flow->ssn_state.session_flags |= SSNFLAG_PRUNED;
     release(flow, reason, do_cleanup);
@@ -378,8 +409,29 @@ unsigned FlowCache::prune_multiple(PruneReason reason, bool do_cleanup)
     // so we don't prune the current flow (assume current == MRU)
     if ( hash_table->get_num_nodes() <= 1 )
         return 0;
+    
+    uint8_t proto = 0;
+    uint64_t skip_protos = 0;
 
-    for (pruned = 0; pruned < config.prune_flows && prune_one(reason, do_cleanup); pruned++);
+    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+
+    const uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1;
+
+    while ( pruned < config.prune_flows )
+    {
+        if ( (skip_protos & (1ULL << proto)) or !prune_one(reason, do_cleanup, proto) )
+        {
+
+            skip_protos |= (1ULL << proto);
+            if ( skip_protos == max_skip_protos )
+                break;
+        }
+        else
+            pruned++;
+       
+        if ( ++proto >= MAX_PROTOCOLS )
+            proto = 0;
+    }
 
     if ( PacketTracer::is_active() and pruned )
         PacketTracer::log("Flow: Pruned memcap %u flows\n", pruned);
@@ -392,37 +444,54 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime)
     ActiveSuspendContext act_susp(Active::ASP_TIMEOUT);
 
     unsigned retired = 0;
+    uint64_t skip_protos = 0;
+
+    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
 
+    const uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1;
     {
         PacketTracerSuspend pt_susp;
 
-        auto flow = static_cast<Flow*>(hash_table->lru_current());
-
-        if ( !flow )
-            flow = static_cast<Flow*>(hash_table->lru_first());
-
-        while ( flow and (retired < num_flows) )
+        while ( retired < num_flows and skip_protos != max_skip_protos )
         {
-            if ( flow->is_hard_expiration() )
+            for( uint8_t proto_idx = 0; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
             {
-                if ( flow->expire_time > (uint64_t) thetime )
+                if( retired >= num_flows )
                     break;
-            }
-            else if ( flow->last_data_seen + flow->idle_timeout > thetime )
-                break;
 
-            if ( HighAvailabilityManager::in_standby(flow) or
-                    flow->is_suspended() )
-            {
-                flow = static_cast<Flow*>(hash_table->lru_next());
-                continue;
+                if ( skip_protos & (1ULL << proto_idx) ) 
+                    continue;
+
+                auto flow = static_cast<Flow*>(hash_table->lru_current(proto_idx));
+                if ( !flow )
+                    flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+                if ( !flow )
+                {
+                    skip_protos |= (1ULL << proto_idx);
+                    continue;
+                }
+
+                if ( flow->is_hard_expiration() )
+                {
+                    if ( flow->expire_time > static_cast<uint64_t>(thetime) )
+                    {
+                        skip_protos |= (1ULL << proto_idx);
+                        continue;
+                    }
+                }
+                else if ( flow->last_data_seen + flow->idle_timeout > thetime )
+                {
+                    skip_protos |= (1ULL << proto_idx);
+                    continue;
+                }
+
+                if ( HighAvailabilityManager::in_standby(flow) or flow->is_suspended() )
+                    continue;
+
+                flow->ssn_state.session_flags |= SSNFLAG_TIMEDOUT;
+                if ( release(flow, PruneReason::IDLE_PROTOCOL_TIMEOUT) )
+                    ++retired;
             }
-
-            flow->ssn_state.session_flags |= SSNFLAG_TIMEDOUT;
-            if ( release(flow, PruneReason::IDLE_PROTOCOL_TIMEOUT) )
-                ++retired;
-
-            flow = static_cast<Flow*>(hash_table->lru_current());
         }
     }
 
@@ -434,39 +503,60 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime)
 
 unsigned FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, unsigned &deleted)
 {
-    unsigned flows_to_check = hash_table->get_num_nodes();
-    while ( num_to_delete && flows_to_check-- )
+    uint64_t skip_protos = 0;
+    uint64_t undeletable = 0;
+
+    assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos));
+
+    const uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1;
+
+    while ( num_to_delete and skip_protos != max_skip_protos and
+            undeletable < hash_table->get_num_nodes() )
     {
-        auto flow = static_cast<Flow*>(hash_table->lru_first());
-        assert(flow);
-        if ( (mode == ALLOWED_FLOWS_ONLY and (flow->was_blocked() || flow->is_suspended()))
-                or (mode == OFFLOADED_FLOWS_TOO and flow->was_blocked()) )
+        for( uint8_t proto_idx = 0; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
         {
-            hash_table->lru_touch();
-            continue;
-        }
+            if( num_to_delete == 0)
+                break;
+            
+            if ( skip_protos & (1ULL << proto_idx) )
+                continue;
+            
+            auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx));
+            if ( !flow )
+            {
+                skip_protos |= (1ULL << proto_idx);
+                continue;
+            }
 
-        if ( (deleted & WDT_MASK) == 0 )
-            ThreadConfig::preemptive_kick();
+            if ( (mode == ALLOWED_FLOWS_ONLY and (flow->was_blocked() or flow->is_suspended()))
+                or (mode == OFFLOADED_FLOWS_TOO and flow->was_blocked()) )
+            {
+                undeletable++;
+                hash_table->lru_touch(proto_idx);
+                continue;
+            }
 
-        unlink_uni(flow);
+            if ( (deleted & WDT_MASK) == 0 )
+                ThreadConfig::preemptive_kick();
 
-        if ( flow->was_blocked() )
-            delete_stats.update(FlowDeleteState::BLOCKED);
-        else if ( flow->is_suspended() )
-            delete_stats.update(FlowDeleteState::OFFLOADED);
-        else
-            delete_stats.update(FlowDeleteState::ALLOWED);
-
-        flow->reset(true);
-        // Delete before removing the node, so that the key is valid until the flow is completely freed
-        delete flow;
-        //The flow should not be removed from the hash before reset
-        hash_table->remove();
-        ++deleted;
-        --num_to_delete;
-    }
+            unlink_uni(flow);
 
+            if ( flow->was_blocked() )
+                delete_stats.update(FlowDeleteState::BLOCKED);
+            else if ( flow->is_suspended() )
+                delete_stats.update(FlowDeleteState::OFFLOADED);
+            else
+                delete_stats.update(FlowDeleteState::ALLOWED);
+
+            flow->reset(true);
+            // Delete before removing the node, so that the key is valid until the flow is completely freed
+            delete flow;
+            // The flow should not be removed from the hash before reset
+            hash_table->remove(proto_idx);
+            ++deleted;
+            --num_to_delete;
+        }
+    }
     return num_to_delete;
 }
 
@@ -496,15 +586,17 @@ unsigned FlowCache::purge()
     FlagContext<decltype(flags)>(flags, SESSION_CACHE_FLAG_PURGING);
 
     unsigned retired = 0;
-    while ( auto flow = static_cast<Flow*>(hash_table->lru_first()) )
+
+    for( uint8_t proto_idx = 0; proto_idx < MAX_PROTOCOLS; ++proto_idx ) 
     {
-        retire(flow);
-        ++retired;
+        while ( auto flow = static_cast<Flow*>(hash_table->lru_first(proto_idx)) )
+        {
+            retire(flow);
+            ++retired;
+        }
     }
-
     // Remove these here so alloc/dealloc counts are right when Memory::get_pegs is called
     delete_uni();
-
     return retired;
 }
 
index 9ea5acced709bcffd647b0641887234e364aa651..c7db2371d3bbe070b149018b974c955a89b2189e 100644 (file)
@@ -58,7 +58,7 @@ public:
 
     unsigned prune_idle(uint32_t thetime, const snort::Flow* save_me);
     unsigned prune_excess(const snort::Flow* save_me);
-    bool prune_one(PruneReason, bool do_cleanup);
+    bool prune_one(PruneReason, bool do_cleanup, uint8_t type = 0);
     unsigned timeout(unsigned num_flows, time_t cur_time);
     unsigned delete_flows(unsigned num_to_delete);
     unsigned prune_multiple(PruneReason, bool do_cleanup);
index 777ed99dcedc8c116201aa3210d9d9daceb9d350..6ab9d5bb97691bee9e4a19c779c9481016328daf 100644 (file)
@@ -94,7 +94,7 @@ ExpectCache::ExpectCache(uint32_t) { }
 bool ExpectCache::check(Packet*, Flow*) { return true; }
 bool ExpectCache::is_expected(Packet*) { return true; }
 Flow* HighAvailabilityManager::import(Packet&, FlowKey&) { return nullptr; }
-bool HighAvailabilityManager::in_standby(Flow*) { return true; }
+bool HighAvailabilityManager::in_standby(Flow*) { return false; }
 SfIpRet SfIp::set(void const*, int) { return SFIP_SUCCESS; }
 void snort::trace_vprintf(const char*, TraceLevel, const char*, const Packet*, const char*, va_list) {}
 uint8_t snort::TraceApi::get_constraints_generation() { return 0; }
@@ -265,6 +265,82 @@ TEST(flow_prune, prune_all_blocked_flows)
     delete cache;
 }
 
+
+// prune based on the proto type of the flow
+TEST(flow_prune, prune_proto)
+{
+    FlowCacheConfig fcg;
+    fcg.max_flows = 5;
+    fcg.prune_flows = 3;
+    
+    for(uint8_t i = to_utype(PktType::NONE); i <= to_utype(PktType::MAX); i++)
+        fcg.proto[i].nominal_timeout = 5;
+    
+    FlowCache *cache = new FlowCache(fcg);
+    int port = 1;
+
+    for ( unsigned i = 0; i < 2; i++ )
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::UDP;
+        cache->allocate(&flow_key);
+    }
+
+    CHECK (cache->get_count() == 2);
+
+    //pruning should not happen for any proto other than UDP
+    for(uint8_t i = 0; i < to_utype(PktType::MAX) - 1; i++)
+    {
+        if (i == to_utype(PktType::UDP))
+            continue;
+        CHECK(cache->prune_one(PruneReason::NONE, true, i) == false);
+    }
+    
+    //pruning should happen for UDP
+    CHECK(cache->prune_one(PruneReason::NONE, true, to_utype(PktType::UDP)) == true);
+
+    FlowKey flow_key2;
+    flow_key2.port_l = port++;
+    flow_key2.pkt_type = PktType::ICMP;
+    cache->allocate(&flow_key2);
+
+    CHECK (cache->get_count() == 2);
+
+    //target flow is ICMP
+    CHECK(cache->prune_multiple(PruneReason::NONE, true) == 1);
+
+    //adding UDP flow it will become LRU
+    for ( unsigned i = 0; i < 2; i++ )
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::UDP;
+        Flow* flow = cache->allocate(&flow_key);
+        flow->last_data_seen = 2+i;
+    }
+
+    //adding TCP flow it will become MRU and put UDP flow to LRU
+    for ( unsigned i = 0; i < 3; i++ )
+    {
+        FlowKey flow_key;
+        flow_key.port_l = port++;
+        flow_key.pkt_type = PktType::TCP;
+        Flow* flow = cache->allocate(&flow_key);
+        flow->last_data_seen = 4+i; //this will force to timeout later than UDP
+    }
+
+    //timeout should happen for 2 UDP and 1 TCP flow
+    CHECK( 3 == cache->timeout(5,9));
+
+    //targeting a UDP flow will fail because no UDP flow is present
+    CHECK(cache->prune_one(PruneReason::NONE, true, to_utype(PktType::UDP)) == false);
+
+    cache->purge();
+    CHECK(cache->get_flows_allocated() == 0);
+    delete cache;
+}
+
 int main(int argc, char** argv)
 {
     return CommandLineTestRunner::RunAllTests(argc, argv);
index 267eaeb27ba56896dc76aa3df1683ed504f064a4..cf83b6b8ffcb97a4b280834ccfafd8aa48ba9116 100644 (file)
@@ -77,7 +77,7 @@ unsigned FlowCache::get_flows_allocated() const { return 0; }
 Flow* FlowCache::find(const FlowKey*) { return nullptr; }
 Flow* FlowCache::allocate(const FlowKey*) { return nullptr; }
 void FlowCache::push(Flow*) { }
-bool FlowCache::prune_one(PruneReason, bool) { return true; }
+bool FlowCache::prune_one(PruneReason, bool, uint8_t) { return true; }
 unsigned FlowCache::prune_multiple(PruneReason , bool) { return 0; }
 unsigned FlowCache::delete_flows(unsigned) { return 0; }
 unsigned FlowCache::timeout(unsigned, time_t) { return 1; }
index 7765b6c6b1734df0ae236c69ac8befd90fb049b0..e4cff8b75bf661aec3cbb0798b857c3fdbee43e8 100644 (file)
@@ -18,3 +18,17 @@ For thread-safe shared caches:
 
 * lru_cache_shared: A thread-safe LRU map.
 
+09/25/2023
+
+A vector of pointers to HashLruCache objects, `vector<HashLruCache*>`, 
+has been introduced to manage multiple types of LRUs within xhash. 
+This enhancement facilitates more sophisticated handling of varied 
+data types within the same xhash. With this advancement, greater 
+control over data pruning in the LRU is achieved, depending on the 
+type of data. This feature is valuable when there's a necessity to 
+retain certain data in the LRU for durations longer or shorter than 
+other data. The utilization of this feature is optional. 
+During initialization, the number of LRUs to be created can be specified. 
+If not specified, a single LRU will be created by default.
+
+
index 654de28bead517d014bdc706a48c36b90365f1a9..dfab41c19857d9fdceaecca68c29c263a3225686 100644 (file)
@@ -108,7 +108,11 @@ void XHash::initialize(HashKeyOperations* hk_ops)
 {
     hashkey_ops = hk_ops;
     table = (HashNode**)snort_calloc(sizeof(HashNode*) * nrows);
-    lru_cache = new HashLruCache();
+
+    lru_caches.resize(num_lru_caches);
+    for (size_t i = 0; i < num_lru_caches; ++i)
+        lru_caches[i] = new HashLruCache();
+    
     mem_allocator = new MemCapAllocator(mem_cap, sizeof(HashNode) + keysize + datasize);
 }
 
@@ -125,16 +129,17 @@ void XHash::set_number_of_rows (int rows)
         nrows = -rows;
 }
 
-XHash::XHash(int rows, int keysize)
-    : keysize(keysize)
+XHash::XHash(int rows, int keysize, uint8_t num_lru_caches)
+    : keysize(keysize), num_lru_caches(num_lru_caches)
 
 {
     set_number_of_rows(rows);
 }
 
-XHash::XHash(int rows, int keysize, int datasize, unsigned long memcap)
-    : keysize(keysize), datasize(datasize), mem_cap(memcap)
+XHash::XHash(int rows, int keysize, int datasize, unsigned long memcap, uint8_t num_lru_caches)
+    : keysize(keysize), num_lru_caches(num_lru_caches), datasize(datasize), mem_cap(memcap)
 {
+    assert(num_lru_caches > 0);
     set_number_of_rows(rows);
     initialize();
 }
@@ -157,7 +162,10 @@ XHash::~XHash()
 
     purge_free_list();
     delete hashkey_ops;
-    delete lru_cache;
+    for (auto lru : lru_caches)
+    {
+        delete lru;
+    }
     delete mem_allocator;
 }
 
@@ -179,8 +187,10 @@ void XHash::delete_hash_table()
     table = nullptr;
 }
 
-void XHash::initialize_node(HashNode *hnode, const void *key, void *data, int index)
+void XHash::initialize_node(HashNode *hnode, const void *key, void *data, int index, uint8_t type)
 {
+    assert(type < num_lru_caches);
+
     hnode->key = (char*) (hnode) + sizeof(HashNode);
     memcpy(hnode->key, key, keysize);
     if ( datasize )
@@ -194,7 +204,7 @@ void XHash::initialize_node(HashNode *hnode, const void *key, void *data, int in
 
     hnode->rindex = index;
     link_node(hnode);
-    lru_cache->insert(hnode);
+    lru_caches[type]->insert(hnode);
 }
 
 HashNode* XHash::allocate_node(const void* key, void* data, int index)
@@ -295,29 +305,31 @@ void XHash::update_cursor()
     }
 }
 
-void* XHash::get_user_data(const void* key)
+void* XHash::get_user_data(const void* key, uint8_t type)
 {
     assert(key);
+    assert(type < num_lru_caches);
 
     int rindex = 0;
-    HashNode* hnode = find_node_row(key, rindex);
+    HashNode* hnode = find_node_row(key, rindex, type);
     return ( hnode ) ? hnode->data : nullptr;
 }
 
-void XHash::release()
+void XHash::release(uint8_t type)
 {
-    HashNode* node = lru_cache->get_current_node();
+    assert(type < num_lru_caches);
+    HashNode* node = lru_caches[type]->get_current_node();
     assert(node);
     release_node(node);
 }
-
-int XHash::release_node(HashNode* hnode)
+int XHash::release_node(HashNode* hnode, uint8_t type)
 {
     assert(hnode);
+    assert(type < num_lru_caches);
 
     free_user_data(hnode);
     unlink_node(hnode);
-    lru_cache->remove_node(hnode);
+    lru_caches[type]->remove_node(hnode);
     num_nodes--;
 
     if ( recycle_nodes )
@@ -334,9 +346,10 @@ int XHash::release_node(HashNode* hnode)
     return HASH_OK;
 }
 
-int XHash::release_node(const void* key)
+int XHash::release_node(const void* key, uint8_t type)
 {
     assert(key);
+    assert(type < num_lru_caches);
 
     unsigned hashkey = hashkey_ops->do_hash((const unsigned char*)key, keysize);
 
@@ -344,7 +357,7 @@ int XHash::release_node(const void* key)
     for (HashNode* hnode = table[index]; hnode; hnode = hnode->next)
     {
         if ( hashkey_ops->key_compare(hnode->key, key, keysize) )
-            return release_node(hnode);
+            return release_node(hnode, type);
     }
 
     return HASH_NOT_FOUND;
@@ -383,19 +396,20 @@ void XHash::unlink_node(HashNode* hnode)
     }
 }
 
-void XHash::move_to_front(HashNode* node)
+void XHash::move_to_front(HashNode* node,uint8_t type)
 {
+    assert(type < num_lru_caches);
     if ( table[node->rindex] != node )
     {
         unlink_node(node);
         link_node(node);
     }
-
-    lru_cache->touch(node);
+    lru_caches[type]->touch(node);
 }
 
-HashNode* XHash::find_node_row(const void* key, int& rindex)
+HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type)
 {
+    assert(type < num_lru_caches);
     unsigned hashkey = hashkey_ops->do_hash((const unsigned char*)key, keysize);
 
     /* Modulus is slow. Switched to a table size that is a power of 2. */
@@ -404,7 +418,7 @@ HashNode* XHash::find_node_row(const void* key, int& rindex)
     {
         if ( hashkey_ops->key_compare(hnode->key, key, keysize) )
         {
-            move_to_front(hnode);
+            move_to_front(hnode,type);
             return hnode;
         }
     }
@@ -485,24 +499,27 @@ void XHash::clear_hash()
     cursor = nullptr;
 }
 
-void* XHash::get_mru_user_data()
+void* XHash::get_mru_user_data(uint8_t type)
 {
-    return lru_cache->get_mru_user_data();
+    assert(type < num_lru_caches);
+    return lru_caches[type]->get_mru_user_data();
 }
 
-void* XHash::get_lru_user_data()
+void* XHash::get_lru_user_data(uint8_t type)
 {
-    return lru_cache->get_lru_user_data();
+    assert(type < num_lru_caches);
+    return lru_caches[type]->get_lru_user_data();
 }
 
-HashNode* XHash::release_lru_node()
+HashNode* XHash::release_lru_node(uint8_t type)
 {
-    HashNode* hnode = lru_cache->get_lru_node();
+    assert(type < num_lru_caches);
+    HashNode* hnode = lru_caches[type]->get_lru_node();
     while ( hnode )
     {
         if ( is_node_recovery_ok(hnode) )
         {
-            lru_cache->remove_node(hnode);
+            lru_caches[type]->remove_node(hnode);
             free_user_data(hnode);
             unlink_node(hnode);
             --num_nodes;
@@ -510,14 +527,16 @@ HashNode* XHash::release_lru_node()
             break;
         }
         else
-            hnode = lru_cache->get_next_lru_node ();
+            hnode = lru_caches[type]->get_next_lru_node ();
     }
     return hnode;
 }
 
-bool XHash::delete_lru_node()
+bool XHash::delete_lru_node(uint8_t type)
 {
-    if ( HashNode* hnode = lru_cache->remove_lru_node() )
+    assert(type < num_lru_caches);
+
+    if ( HashNode* hnode = lru_caches[type]->remove_lru_node() )
     {
         unlink_node(hnode);
         free_user_data(hnode);
index f823e31718a4286f01965c6ec4570908f38b539d..f77059b4ddc2b18ac6a04838e99562bb939c1017 100644 (file)
@@ -24,6 +24,7 @@
 
 // generic hash table - stores and maps key + data pairs
 // (supports memcap and automatic memory recovery when out of memory)
+#include <vector>
 
 #include "framework/counts.h"
 #include "main/snort_types.h"
@@ -48,8 +49,8 @@ struct XHashStats
 class SO_PUBLIC XHash
 {
 public:
-    XHash(int rows, int keysize);
-    XHash(int rows, int keysize, int datasize, unsigned long memcap);
+    XHash(int rows, int keysize, uint8_t num_lru_caches = 1);
+    XHash(int rows, int keysize, int datasize, unsigned long memcap, uint8_t num_lru_caches = 1);
     virtual ~XHash();
 
     int insert(const void* key, void* data);
@@ -57,13 +58,13 @@ public:
     HashNode* find_first_node();
     HashNode* find_next_node();
     void* get_user_data();
-    void* get_user_data(const void* key);
-    void release();
-    int release_node(const void* key);
-    int release_node(HashNode* node);
-    void* get_mru_user_data();
-    void* get_lru_user_data();
-    bool delete_lru_node();
+    void* get_user_data(const void* key, uint8_t type = 0);
+    void release(uint8_t type = 0);
+    int release_node(const void* key, uint8_t type = 0);
+    int release_node(HashNode* node, uint8_t type = 0);
+    void* get_mru_user_data(uint8_t type = 0);
+    void* get_lru_user_data(uint8_t type = 0);
+    bool delete_lru_node(uint8_t type = 0);
     void clear_hash();
     bool full() const { return !fhead; }
 
@@ -95,9 +96,9 @@ protected:
     void initialize(HashKeyOperations*);
     void initialize();
 
-    void initialize_node (HashNode*, const void* key, void* data, int index);
+    void initialize_node(HashNode*, const void* key, void* data, int index, uint8_t type = 0);
     HashNode* allocate_node(const void* key, void* data, int index);
-    HashNode* find_node_row(const void* key, int& rindex);
+    HashNode* find_node_row(const void* key, int& rindex, uint8_t type = 0);
     void link_node(HashNode*);
     void unlink_node(HashNode*);
     bool delete_a_node();
@@ -112,13 +113,15 @@ protected:
     { }
 
     MemCapAllocator* mem_allocator = nullptr;
-    HashLruCache* lru_cache = nullptr;
+    std::vector<HashLruCache*> lru_caches;  // Multiple LRU caches
+
     unsigned nrows = 0;
     unsigned keysize = 0;
     unsigned num_nodes = 0;
     unsigned num_free_nodes = 0;
     bool recycle_nodes = true;
     bool anr_enabled = true;
+    uint8_t num_lru_caches = 1;
 
 private:
     HashNode** table = nullptr;
@@ -132,9 +135,9 @@ private:
     XHashStats stats;
 
     void set_number_of_rows(int nrows);
-    void move_to_front(HashNode*);
+    void move_to_front(HashNode*, uint8_t type = 0);
     bool delete_free_node();
-    HashNode* release_lru_node();
+    HashNode* release_lru_node(uint8_t type = 0);
     void update_cursor();
     void purge_free_list();
 };
index ed928e2dcf5edeb9282ecf71243805bff0c546f1..98c0acf57f7931ac44d3c85a0517c4ab0738e732 100644 (file)
@@ -41,18 +41,19 @@ using namespace snort;
 //-------------------------------------------------------------------------
 
 
-ZHash::ZHash(int rows, int key_len, bool recycle)
-    : XHash(rows, key_len)
+ZHash::ZHash(int rows, int key_len, uint8_t lru_count, bool recycle)
+    : XHash(rows, key_len, lru_count)
 {
     initialize(new FlowHashKeyOps(nrows));
     anr_enabled = false;
     recycle_nodes = recycle;
 }
 
-void* ZHash::get(const void* key)
+void* ZHash::get(const void* key, uint8_t type)
 {
     assert(key);
-
+    assert(type < num_lru_caches);
+
     int index;
     HashNode* node = find_node_row(key, index);
     if ( node )
@@ -65,19 +66,20 @@ void* ZHash::get(const void* key)
     memcpy(node->key, key, keysize);
     node->rindex = index;
     link_node(node);
-    lru_cache->insert(node);
+    lru_caches[type]->insert(node);
     num_nodes++;
     return node->data;
 }
 
-void* ZHash::remove()
+void* ZHash::remove(uint8_t type)
 {
-    HashNode* node = lru_cache->get_current_node();
+    assert(type < num_lru_caches);
+    HashNode* node = lru_caches[type]->get_current_node();
     assert(node);
     void* pv = node->data;
 
     unlink_node(node);
-    lru_cache->remove_node(node);
+    lru_caches[type]->remove_node(node);
     num_nodes--;
     mem_allocator->free(node);
     return pv;
@@ -104,27 +106,31 @@ void* ZHash::pop()
     return pv;
 }
 
-void* ZHash::lru_first()
+void* ZHash::lru_first(uint8_t type)
 {
-    HashNode* node = lru_cache->get_lru_node();
+    assert(type < num_lru_caches);
+    HashNode* node = lru_caches[type]->get_lru_node();
     return node ? node->data : nullptr;
 }
 
-void* ZHash::lru_next()
+void* ZHash::lru_next(uint8_t type)
 {
-    HashNode* node = lru_cache->get_next_lru_node();
+    assert(type < num_lru_caches);
+    HashNode* node = lru_caches[type]->get_next_lru_node();
     return node ? node->data : nullptr;
 }
 
-void* ZHash::lru_current()
+void* ZHash::lru_current(uint8_t type)
 {
-    HashNode* node = lru_cache->get_current_node();
+    assert(type < num_lru_caches);
+    HashNode* node = lru_caches[type]->get_current_node();
     return node ? node->data : nullptr;
 }
 
-void ZHash::lru_touch()
+void ZHash::lru_touch(uint8_t type)
 {
-    HashNode* node = lru_cache->get_current_node();
+    assert(type < num_lru_caches);
+    HashNode* node = lru_caches[type]->get_current_node();
     assert(node);
-    lru_cache->touch(node);
+    lru_caches[type]->touch(node);
 }
index 45c2df455e083952c3ec2a701e949d1b0bebfa9c..143e21ae2099a8dea05a69c9c88937d33f6ea7a8 100644 (file)
@@ -27,7 +27,7 @@
 class ZHash : public snort::XHash
 {
 public:
-    ZHash(int nrows, int keysize, bool recycle = true);
+    ZHash(int nrows, int keysize, uint8_t lru_count = 1, bool recycle = true);
 
     ZHash(const ZHash&) = delete;
     ZHash& operator=(const ZHash&) = delete;
@@ -35,13 +35,13 @@ public:
     void* push(void* p);
     void* pop();
 
-    void* get(const void* key);
-    void* remove();
+    void* get(const void* key, uint8_t type = 0);
+    void* remove(uint8_t type = 0);
 
-    void* lru_first();
-    void* lru_next();
-    void* lru_current();
-    void lru_touch();
+    void* lru_first(uint8_t type = 0);
+    void* lru_next(uint8_t type = 0);
+    void* lru_current(uint8_t type = 0);
+    void lru_touch(uint8_t type = 0);
 };
 
 #endif