From: Raza Shafiq (rshafiq) Date: Tue, 29 Oct 2024 15:22:31 +0000 (+0000) Subject: Pull request #4483: flow: new allowlist LRU X-Git-Tag: 3.5.1.0~6 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=db89d6cf15378f34d9051806a0260182f42d93d5;p=thirdparty%2Fsnort3.git Pull request #4483: flow: new allowlist LRU Merge in SNORT/snort3 from ~RSHAFIQ/snort3:whitelist_cache to master Squashed commit of the following: commit a1647130533346a651396d00c1d251c294094416 Author: rshafiq Date: Wed Oct 2 19:08:52 2024 -0400 flow: new allowlist LRU --- diff --git a/src/flow/dev_notes.txt b/src/flow/dev_notes.txt index 0d121440e..9c8725719 100644 --- a/src/flow/dev_notes.txt +++ b/src/flow/dev_notes.txt @@ -86,3 +86,21 @@ bias during pruning sessions, especially concerning UDP flows, but lays down a framework for more advanced and controlled data management within the cache moving forward. +10/25/2024 +Allowlist LRU +To address the need for preserving flows that have been allowlisted and are +at risk of timing out, we've introduced a configurable Allowlist LRU cache +within the flow_cache. This enhancement enables the retention of flows marked +with a whitelist verdict, preventing them from being prematurely pruned due +to inactivity timeouts. This is particularly beneficial in scenarios where +Snort ceases to observe traffic for a flow after the whitelist decision, +especially if that flow is long-lived. Without this adjustment, +such flows may be pruned by the cache upon timeout, potentially impacting +event logging at the flow’s end-of-life (EOF) due to missing +pruned flow information. + +The Allowlist LRU cache is disabled by default but can be enabled by adding +allowlist_cache = { enable = true } in the stream configuration. +Like the protocol-based LRUs, this allowlist functionality is an +additional LRU rather than a new hash_table, thereby maintaining +consistent performance with previous configurations. 
\ No newline at end of file diff --git a/src/flow/flow.cc b/src/flow/flow.cc index de7f2b078..f9ec73a28 100644 --- a/src/flow/flow.cc +++ b/src/flow/flow.cc @@ -26,12 +26,14 @@ #include "detection/context_switcher.h" #include "detection/detection_continuation.h" #include "detection/detection_engine.h" +#include "flow/flow_control.h" #include "flow/flow_key.h" #include "flow/ha.h" #include "flow/session.h" #include "framework/data_bus.h" #include "helpers/bitop.h" #include "main/analyzer.h" +#include "packet_io/packet_tracer.h" #include "protocols/packet.h" #include "protocols/tcp.h" #include "pub_sub/intrinsic_event_ids.h" @@ -41,6 +43,7 @@ #include "utils/util.h" using namespace snort; +extern THREAD_LOCAL class FlowControl* flow_con; Flow::~Flow() { @@ -538,3 +541,16 @@ void Flow::swap_roles() std::swap(outer_client_ttl, outer_server_ttl); flags.client_initiated = !flags.client_initiated; } + +bool Flow::handle_allowlist() +{ + if ( flow_con->get_flow_cache_config().allowlist_cache and !flags.in_allowlist ) + { + if ( flow_con->move_to_allowlist(this) ) + { + PacketTracer::log("Flow: flow has been moved to allowlist cache\n"); + return true; + } + } + return false; +} diff --git a/src/flow/flow.h b/src/flow/flow.h index 6e4bb03a6..fd17d6454 100644 --- a/src/flow/flow.h +++ b/src/flow/flow.h @@ -428,6 +428,8 @@ public: uint64_t fetch_add_inspection_duration(); + bool handle_allowlist(); + public: // FIXIT-M privatize if possible // fields are organized by initialization and size to minimize // void space @@ -523,6 +525,7 @@ public: // FIXIT-M privatize if possible bool ips_pblock_event_suppressed : 1; // set if event filters have suppressed a partial block ips event bool binder_action_allow : 1; bool binder_action_block : 1; + bool in_allowlist : 1; // Set if the flow is in the allowlist } flags = {}; FlowState flow_state = FlowState::SETUP; diff --git a/src/flow/flow_cache.cc b/src/flow/flow_cache.cc index 54b0eca22..9e9278990 100644 --- a/src/flow/flow_cache.cc 
+++ b/src/flow/flow_cache.cc @@ -60,9 +60,6 @@ static const unsigned OFFLOADED_FLOWS_TOO = 2; static const unsigned ALL_FLOWS = 3; static const unsigned WDT_MASK = 7; // kick watchdog once for every 8 flows deleted -constexpr uint8_t MAX_PROTOCOLS = (uint8_t)to_utype(PktType::MAX) - 1; //removing PktType::NONE from count -constexpr uint64_t max_skip_protos = (1ULL << MAX_PROTOCOLS) - 1; -constexpr uint8_t first_proto = to_utype(PktType::NONE) + 1; uint8_t DumpFlows::dump_code = 0; @@ -202,11 +199,11 @@ bool DumpFlows::execute(Analyzer&, void**) FlowCache::FlowCache(const FlowCacheConfig& cfg) : config(cfg) { - hash_table = new ZHash(config.max_flows, sizeof(FlowKey), MAX_PROTOCOLS, false); + hash_table = new ZHash(config.max_flows, sizeof(FlowKey), total_lru_count, false); uni_flows = new FlowUniList; uni_ip_flows = new FlowUniList; flags = 0x0; - empty_proto = ( 1 << MAX_PROTOCOLS ) - 1; + empty_lru_mask = ( 1 << max_protocols ) - 1; timeout_idx = first_proto; assert(prune_stats.get_total() == 0); @@ -245,9 +242,14 @@ unsigned FlowCache::get_count() Flow* FlowCache::find(const FlowKey* key) { - Flow* flow = (Flow*)hash_table->get_user_data(key,to_utype(key->pkt_type)); + Flow* flow = (Flow*)hash_table->get_user_data(key,to_utype(key->pkt_type), false); if ( flow ) { + if ( flow->flags.in_allowlist ) + hash_table->touch_last_found(allowlist_lru_index); + else + hash_table->touch_last_found(to_utype(key->pkt_type)); + time_t t = packet_time(); if ( flow->last_data_seen < t ) @@ -321,7 +323,7 @@ Flow* FlowCache::allocate(const FlowKey* key) link_uni(flow); flow->last_data_seen = timestamp; flow->set_idle_timeout(config.proto[to_utype(flow->key->pkt_type)].nominal_timeout); - empty_proto &= ~(1ULL << to_utype(key->pkt_type)); // clear the bit for this protocol + empty_lru_mask &= ~(1ULL << to_utype(key->pkt_type)); // clear the bit for this protocol return flow; } @@ -330,9 +332,13 @@ void FlowCache::remove(Flow* flow) { unlink_uni(flow); const snort::FlowKey* key = 
flow->key; + uint8_t in_allowlist = flow->flags.in_allowlist; // Delete before releasing the node, so that the key is valid until the flow is completely freed delete flow; - hash_table->release_node(key, to_utype(key->pkt_type)); + if ( in_allowlist ) + hash_table->release_node(key, allowlist_lru_index); + else + hash_table->release_node(key, to_utype(key->pkt_type)); } bool FlowCache::release(Flow* flow, PruneReason reason, bool do_cleanup) @@ -347,8 +353,9 @@ bool FlowCache::release(Flow* flow, PruneReason reason, bool do_cleanup) } } + uint8_t in_allowlist = flow->flags.in_allowlist; flow->reset(do_cleanup); - prune_stats.update(reason, flow->key->pkt_type); + prune_stats.update(reason, ( in_allowlist ? static_cast(allowlist_lru_index) : flow->key->pkt_type )); remove(flow); return true; } @@ -365,31 +372,30 @@ unsigned FlowCache::prune_idle(uint32_t thetime, const Flow* save_me) ActiveSuspendContext act_susp(Active::ASP_PRUNE); unsigned pruned = 0; - uint64_t skip_protos = empty_proto; + uint64_t checked_lrus_mask = empty_lru_mask; - assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos)); + assert(max_protocols < 8 * sizeof(checked_lrus_mask)); { PacketTracerSuspend pt_susp; while ( pruned <= cleanup_flows and - skip_protos != max_skip_protos ) + !all_lrus_checked(checked_lrus_mask) ) { - // Round-robin through the proto types - for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) + // Round-robin through the LRU types + for( uint8_t lru_idx = first_proto; lru_idx < max_protocols; ++lru_idx ) { if( pruned > cleanup_flows ) break; - const uint64_t proto_mask = 1ULL << proto_idx; + const uint64_t lru_mask = get_lru_mask(lru_idx); - if ( skip_protos & proto_mask ) + if( is_lru_checked(checked_lrus_mask, lru_mask) ) continue; - auto flow = static_cast(hash_table->lru_first(proto_idx)); + auto flow = static_cast(hash_table->lru_first(lru_idx)); if ( !flow ) { - skip_protos |= proto_mask; - empty_proto |= proto_mask; + 
mark_lru_checked(checked_lrus_mask, empty_lru_mask, lru_mask); continue; } @@ -397,7 +403,7 @@ unsigned FlowCache::prune_idle(uint32_t thetime, const Flow* save_me) or flow->is_suspended() or flow->last_data_seen + config.pruning_timeout >= thetime ) { - skip_protos |= proto_mask; + mark_lru_checked(checked_lrus_mask, lru_mask); continue; } @@ -461,52 +467,53 @@ unsigned FlowCache::prune_excess(const Flow* save_me) unsigned pruned = 0; - // initially skip offloads but if that doesn't work the hash table is iterated from the - // beginning again. prune offloads at that point. + // Initially skip offloads but if that doesn't work, the hash table is iterated from the + // beginning again. Prune offloads at that point. unsigned ignore_offloads = hash_table->get_num_nodes(); - uint64_t skip_protos = 0; + uint64_t checked_lrus_mask = 0; - assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos)); + assert(total_lru_count < 8 * sizeof(checked_lrus_mask)); { PacketTracerSuspend pt_susp; unsigned blocks = 0; + // EXCESS pruning will start from the allowlist LRU + uint8_t lru_idx = allowlist_lru_index; while ( true ) { auto num_nodes = hash_table->get_num_nodes(); - if ( num_nodes <= max_cap or num_nodes <= blocks or - ignore_offloads == 0 or skip_protos == max_skip_protos ) - break; - - for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) + if ( num_nodes <= max_cap or num_nodes <= blocks or + ignore_offloads == 0 or all_lrus_checked(checked_lrus_mask) ) + break; + + for (; lru_idx < total_lru_count; ++lru_idx) { num_nodes = hash_table->get_num_nodes(); if ( num_nodes <= max_cap or num_nodes <= blocks ) break; - const uint64_t proto_mask = 1ULL << proto_idx; + const uint64_t lru_mask = get_lru_mask(lru_idx); - if ( skip_protos & proto_mask ) + if ( is_lru_checked(checked_lrus_mask, lru_mask) ) continue; - auto flow = static_cast(hash_table->lru_first(proto_idx)); + auto flow = static_cast(hash_table->lru_first(lru_idx)); if ( !flow ) { - skip_protos |= 
proto_mask; + mark_lru_checked(checked_lrus_mask, lru_mask); continue; } if ( (save_me and flow == save_me) or flow->was_blocked() or (flow->is_suspended() and ignore_offloads) ) { - // check for non-null save_me above to silence analyzer - // "called C++ object pointer is null" here + // Avoid pruning the current flow (save_me) or blocked/suspended flows if ( flow->was_blocked() ) ++blocks; - // FIXIT-M we should update last_data_seen upon touch to ensure - // the hash_table LRU list remains sorted by time - hash_table->lru_touch(proto_idx); + + // Ensure LRU list remains sorted by time on touch + hash_table->lru_touch(lru_idx); } else { @@ -517,6 +524,9 @@ unsigned FlowCache::prune_excess(const Flow* save_me) if ( ignore_offloads > 0 ) --ignore_offloads; } + + if ( lru_idx >= total_lru_count ) + lru_idx = first_proto; } if ( !pruned and hash_table->get_num_nodes() > max_cap ) @@ -533,19 +543,22 @@ unsigned FlowCache::prune_excess(const Flow* save_me) bool FlowCache::prune_one(PruneReason reason, bool do_cleanup, uint8_t type) { - // so we don't prune the current flow (assume current == MRU) - if ( hash_table->get_num_nodes() <= 1 ) + // Avoid pruning the current flow (assume current == MRU) + if (hash_table->get_num_nodes() <= 1) return false; - // ZHash returns in LRU order, which is updated per packet via find --> move_to_front call auto flow = static_cast(hash_table->lru_first(type)); - if( !flow ) + if ( !flow ) return false; flow->ssn_state.session_flags |= SSNFLAG_PRUNED; - release(flow, reason, do_cleanup); - return true; + if ( type != allowlist_lru_index ) + return release(flow, reason, do_cleanup); + else if ( reason == PruneReason::MEMCAP or reason == PruneReason::EXCESS ) + return release(flow, reason, do_cleanup); + + return false; } unsigned FlowCache::prune_multiple(PruneReason reason, bool do_cleanup) @@ -554,28 +567,38 @@ unsigned FlowCache::prune_multiple(PruneReason reason, bool do_cleanup) // so we don't prune the current flow (assume current 
== MRU) if ( hash_table->get_num_nodes() <= 1 ) return 0; - - uint8_t proto = 0; - uint64_t skip_protos = 0; - assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos)); + uint8_t lru_idx = 0; + uint64_t checked_lrus_mask = 0; + + assert(max_protocols < 8 * sizeof(checked_lrus_mask)); + + if( reason == PruneReason::MEMCAP or reason == PruneReason::EXCESS ) + { + // if MEMCAP or EXCESS, prune the allowlist first + while ( pruned < config.prune_flows ) + { + if ( !prune_one(reason, do_cleanup, allowlist_lru_index) ) + break; + pruned++; + } + } - while ( pruned < config.prune_flows ) { - const uint64_t proto_mask = 1ULL << proto; - if ( (skip_protos & proto_mask) or !prune_one(reason, do_cleanup, proto) ) + const uint64_t lru_mask = get_lru_mask(lru_idx); + if ( is_lru_checked(checked_lrus_mask, lru_mask) or !prune_one(reason, do_cleanup, lru_idx) ) { + mark_lru_checked(checked_lrus_mask, lru_mask); - skip_protos |= proto_mask; - if ( skip_protos == max_skip_protos ) + if ( all_lrus_checked(checked_lrus_mask) ) break; } else pruned++; - - if ( ++proto >= MAX_PROTOCOLS ) - proto = 0; + + if ( ++lru_idx >= max_protocols ) + lru_idx = 0; } if ( PacketTracer::is_active() and pruned ) @@ -589,21 +612,37 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime) ActiveSuspendContext act_susp(Active::ASP_TIMEOUT); unsigned retired = 0; - uint64_t skip_protos = empty_proto; + uint64_t checked_lrus_mask = empty_lru_mask; // Start by skipping any protocols that have no flows. 
- assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos)); +#ifdef REG_TEST + if ( hash_table->get_node_count(allowlist_lru_index) > 0 ) + { + uint64_t allowlist_timeout_count = 0; + auto flow = static_cast(hash_table->lru_first(allowlist_lru_index)); + while ( flow ) + { + if ( flow->last_data_seen + flow->idle_timeout > thetime ) + allowlist_timeout_count++; + flow = static_cast(hash_table->lru_next(allowlist_lru_index)); + } + if ( PacketTracer::is_active() and allowlist_timeout_count ) + PacketTracer::log("Flow: %lu allowlist flow(s) timed out but not pruned \n", allowlist_timeout_count); + } +#endif + + assert(max_protocols < 8 * sizeof(checked_lrus_mask)); { PacketTracerSuspend pt_susp; - while ( retired < num_flows and skip_protos != max_skip_protos ) + while ( retired < num_flows and !all_lrus_checked(checked_lrus_mask) ) { - for( ; timeout_idx < MAX_PROTOCOLS; ++timeout_idx ) + for( ; timeout_idx < max_protocols; ++timeout_idx ) { - const uint64_t proto_mask = 1ULL << timeout_idx; + const uint64_t lru_mask = get_lru_mask(timeout_idx); - if ( skip_protos & proto_mask ) + if ( is_lru_checked(checked_lrus_mask, lru_mask) ) continue; auto flow = static_cast(hash_table->lru_current(timeout_idx)); @@ -612,8 +651,7 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime) flow = static_cast(hash_table->lru_first(timeout_idx)); if ( !flow ) { - skip_protos |= proto_mask; - empty_proto |= proto_mask; + mark_lru_checked(checked_lrus_mask, empty_lru_mask, lru_mask); continue; } } @@ -622,13 +660,13 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime) { if ( flow->expire_time > static_cast(thetime) ) { - skip_protos |= proto_mask; + mark_lru_checked(checked_lrus_mask, lru_mask); continue; } } else if ( flow->last_data_seen + flow->idle_timeout > thetime ) { - skip_protos |= proto_mask; + mark_lru_checked(checked_lrus_mask, lru_mask); continue; } @@ -655,30 +693,29 @@ unsigned FlowCache::timeout(unsigned num_flows, time_t thetime) unsigned 
FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, unsigned &deleted) { - uint64_t skip_protos = empty_proto; + uint64_t checked_lrus_mask = empty_lru_mask; uint64_t undeletable = 0; - assert(MAX_PROTOCOLS < 8 * sizeof(skip_protos)); + assert(max_protocols < 8 * sizeof(checked_lrus_mask)); - while ( num_to_delete and skip_protos != max_skip_protos and + while ( num_to_delete and !all_lrus_checked(checked_lrus_mask) and undeletable < hash_table->get_num_nodes() ) { - for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) + for ( uint8_t lru_idx = first_proto; lru_idx < max_protocols; ++lru_idx ) { - if( num_to_delete == 0) + if ( num_to_delete == 0 ) break; - - const uint64_t proto_mask = 1ULL << proto_idx; - if ( skip_protos & proto_mask ) + const uint64_t lru_mask = get_lru_mask(lru_idx); + + if ( is_lru_checked(checked_lrus_mask, lru_mask) ) continue; - - auto flow = static_cast(hash_table->lru_first(proto_idx)); + + auto flow = static_cast(hash_table->lru_first(lru_idx)); if ( !flow ) { - skip_protos |= proto_mask; - empty_proto |= proto_mask; + mark_lru_checked(checked_lrus_mask, empty_lru_mask, lru_mask); continue; } @@ -686,7 +723,7 @@ unsigned FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, u or (mode == OFFLOADED_FLOWS_TOO and flow->was_blocked()) ) { undeletable++; - hash_table->lru_touch(proto_idx); + hash_table->lru_touch(lru_idx); continue; } @@ -706,7 +743,7 @@ unsigned FlowCache::delete_active_flows(unsigned mode, unsigned num_to_delete, u // Delete before removing the node, so that the key is valid until the flow is completely freed delete flow; // The flow should not be removed from the hash before reset - hash_table->remove(proto_idx); + hash_table->remove(lru_idx); ++deleted; --num_to_delete; } @@ -741,7 +778,7 @@ unsigned FlowCache::purge() unsigned retired = 0; - for( uint8_t proto_idx = first_proto; proto_idx < MAX_PROTOCOLS; ++proto_idx ) + for( uint8_t proto_idx = first_proto; 
proto_idx < total_lru_count; ++proto_idx ) { while ( auto flow = static_cast(hash_table->lru_first(proto_idx)) ) { @@ -840,7 +877,7 @@ void FlowCache::output_flow(std::fstream& stream, const Flow& flow, const struct } std::stringstream out; std::stringstream proto; - switch ( flow.pkt_type ) + switch ( flow.key->pkt_type ) { case PktType::IP: out << "Instance-ID: " << get_relative_instance_number() << " IP " << flow.key->addressSpaceId << ": " << src_ip << " " << dst_ip; @@ -878,67 +915,51 @@ void FlowCache::output_flow(std::fstream& stream, const Flow& flow, const struct timeout_to_str(flow.expire_time - now.tv_sec) : timeout_to_str((flow.last_data_seen + config.proto[to_utype(flow.key->pkt_type)].nominal_timeout) - now.tv_sec); out << t; - stream << out.str() << proto.str() << std::endl; + stream << out.str() << proto.str() << (flow.flags.in_allowlist ? " (allowlist)" : "") << std::endl; } bool FlowCache::dump_flows(std::fstream& stream, unsigned count, const FilterFlowCriteria& ffc, bool first, uint8_t code) const { struct timeval now; packet_gettimeofday(&now); - unsigned i; - bool has_more_flows = false; Flow* walk_flow = nullptr; + bool has_more_flows = false; - for(uint8_t proto_id = to_utype(PktType::NONE)+1; proto_id <= to_utype(PktType::ICMP); proto_id++) + for (uint8_t proto_id = to_utype(PktType::NONE) + 1; proto_id < total_lru_count; ++proto_id) { - if (first) - { + if ( proto_id == to_utype(PktType::USER) or + proto_id == to_utype(PktType::FILE) or + proto_id == to_utype(PktType::PDU) ) + continue; + + unsigned i = 0; + + if ( first ) walk_flow = static_cast(hash_table->get_walk_user_data(proto_id)); - if (!walk_flow) - { - //Return only if all the protocol caches are processed. 
- if (proto_id < to_utype(PktType::ICMP)) - continue; - return !has_more_flows; - } - walk_flow->dump_code = code; - bool matched_filter = filter_flows(*walk_flow, ffc); - if (matched_filter) - output_flow(stream, *walk_flow, now); - i = 1; - } else - i = 0; - while (i < count) - { walk_flow = static_cast(hash_table->get_next_walk_user_data(proto_id)); - if (!walk_flow ) - { - //Return only if all the protocol caches are processed. - if (proto_id < to_utype(PktType::ICMP)) - break; - return !has_more_flows; - } - if (walk_flow->dump_code != code) + while ( walk_flow && i < count ) + { + if ( walk_flow->dump_code != code ) { walk_flow->dump_code = code; - bool matched_filter = filter_flows(*walk_flow, ffc); - if (matched_filter) + if( filter_flows(*walk_flow, ffc) ) output_flow(stream, *walk_flow, now); ++i; } -#ifdef REG_TEST - else - LogMessage("dump_flows skipping already dumped flow\n"); -#endif + if (i < count) + walk_flow = static_cast(hash_table->get_next_walk_user_data(proto_id)); } - if(walk_flow) // we have output 'count' flows, but the protocol cache still has more flows + + if ( walk_flow ) has_more_flows = true; } - return false; + + return !has_more_flows; } + size_t FlowCache::uni_flows_size() const { return uni_flows ? 
uni_flows->get_count() : 0; @@ -953,3 +974,32 @@ size_t FlowCache::flows_size() const { return hash_table->get_num_nodes(); } + +PegCount FlowCache::get_lru_flow_count(uint8_t lru_idx) const +{ + return hash_table->get_node_count(lru_idx); +} + +bool FlowCache::move_to_allowlist(snort::Flow* f) +{ + if( hash_table->switch_lru_cache(f->key, to_utype(f->key->pkt_type), allowlist_lru_index) ) + { + f->flags.in_allowlist = 1; + return true; + } + return false; +} + +#ifdef UNIT_TEST +size_t FlowCache::count_flows_in_lru(uint8_t lru_index) const +{ + size_t count = 0; + Flow* flow = static_cast(hash_table->get_walk_user_data(lru_index)); + while (flow) + { + ++count; + flow = static_cast(hash_table->get_next_walk_user_data(lru_index)); + } + return count; +} +#endif diff --git a/src/flow/flow_cache.h b/src/flow/flow_cache.h index 7756a16af..dd85c60c1 100644 --- a/src/flow/flow_cache.h +++ b/src/flow/flow_cache.h @@ -38,6 +38,12 @@ #include "prune_stats.h" #include "filter_flow_critera.h" +constexpr uint8_t max_protocols = static_cast(to_utype(PktType::MAX)); +constexpr uint8_t allowlist_lru_index = max_protocols; +constexpr uint8_t total_lru_count = max_protocols + 1; +constexpr uint64_t all_lru_mask = (1ULL << max_protocols) - 1; +constexpr uint8_t first_proto = to_utype(PktType::NONE) + 1; + namespace snort { class Flow; @@ -134,6 +140,8 @@ public: const FlowCacheConfig& get_flow_cache_config() const { return config; } + bool move_to_allowlist(snort::Flow* f); + virtual bool filter_flows(const snort::Flow&, const FilterFlowCriteria&) const; virtual void output_flow(std::fstream&, const snort::Flow&, const struct timeval&) const; @@ -142,6 +150,10 @@ public: size_t uni_flows_size() const; size_t uni_ip_flows_size() const; size_t flows_size() const; + PegCount get_lru_flow_count(uint8_t lru_idx) const; +#ifdef UNIT_TEST + size_t count_flows_in_lru(uint8_t lru_index) const; +#endif private: void delete_uni(); @@ -154,6 +166,24 @@ private: static std::string 
timeout_to_str(time_t); bool is_ip_match(const snort::SfIp& flow_ip, const snort::SfIp& filter_ip, const snort::SfIp& subnet) const; + inline bool is_lru_checked(uint64_t checked_lrus_mask, uint64_t lru_mask) + { return (checked_lrus_mask & lru_mask) != 0; } + + inline bool all_lrus_checked(uint64_t checked_lrus_mask) + { return checked_lrus_mask == all_lru_mask; } + + inline void mark_lru_checked(uint64_t& checked_lrus_mask, uint64_t lru_mask) + { checked_lrus_mask |= lru_mask; } + + inline uint64_t get_lru_mask(uint8_t lru_idx) + { return 1ULL << lru_idx; } + + inline void mark_lru_checked(uint64_t& checked_lrus_mask, uint64_t& empty_lru_masks, uint64_t lru_mask) + { + checked_lrus_mask |= lru_mask; + empty_lru_masks |= lru_mask; + } + private: uint8_t timeout_idx; static const unsigned cleanup_flows = 1; @@ -166,7 +196,8 @@ private: PruneStats prune_stats; FlowDeleteStats delete_stats; - uint64_t empty_proto; + uint64_t empty_lru_mask; + }; #endif diff --git a/src/flow/flow_config.h b/src/flow/flow_config.h index 5b42ffc9d..1638f4ff7 100644 --- a/src/flow/flow_config.h +++ b/src/flow/flow_config.h @@ -35,6 +35,7 @@ struct FlowCacheConfig unsigned pruning_timeout = 0; FlowTypeConfig proto[to_utype(PktType::MAX)]; unsigned prune_flows = 0; + bool allowlist_cache = false; }; #endif diff --git a/src/flow/flow_control.cc b/src/flow/flow_control.cc index 16373c774..ea6b7fad1 100644 --- a/src/flow/flow_control.cc +++ b/src/flow/flow_control.cc @@ -116,6 +116,18 @@ void FlowControl::release_flow(const FlowKey* key) cache->release(flow, PruneReason::HA); } +bool FlowControl::move_to_allowlist(Flow* f) +{ + // Preserve the flow only if it is a TCP or UDP flow, + // as only these flow types contain appid-related info needed at the EOF event. 
+ if ( f->key->pkt_type != PktType::TCP and f->key->pkt_type != PktType::UDP ) + return false; + return cache->move_to_allowlist(f); +} + +PegCount FlowControl::get_allowlist_flow_count() const +{ return cache->get_lru_flow_count(allowlist_lru_index); } + void FlowControl::release_flow(Flow* flow, PruneReason reason) { cache->release(flow, reason); } diff --git a/src/flow/flow_control.h b/src/flow/flow_control.h index b84c1abca..2f3d62383 100644 --- a/src/flow/flow_control.h +++ b/src/flow/flow_control.h @@ -72,6 +72,7 @@ public: void timeout_flows(unsigned int, time_t cur_time); void check_expected_flow(snort::Flow*, snort::Packet*); unsigned prune_multiple(PruneReason, bool do_cleanup); + bool move_to_allowlist(snort::Flow*); bool dump_flows(std::fstream&, unsigned count, const FilterFlowCriteria& ffc, bool first, uint8_t code) const; @@ -92,6 +93,7 @@ public: PegCount get_flows() { return num_flows; } + PegCount get_allowlist_flow_count() const; PegCount get_total_prunes() const; PegCount get_prunes(PruneReason) const; PegCount get_proto_prune_count(PruneReason, PktType) const; diff --git a/src/flow/prune_stats.h b/src/flow/prune_stats.h index 53a22e5cd..3367ddfae 100644 --- a/src/flow/prune_stats.h +++ b/src/flow/prune_stats.h @@ -39,25 +39,25 @@ enum class PruneReason : uint8_t MAX }; -struct ProtoPruneStats +struct LRUPruneStats { - using proto_t = std::underlying_type_t; - PegCount proto_counts[static_cast(PktType::MAX)] { }; + using lru_t = std::underlying_type_t; + PegCount lru_counts[static_cast(LRUType::MAX)] { }; PegCount get_total() const { PegCount total = 0; - for ( proto_t i = 0; i < static_cast(PktType::MAX); ++i ) - total += proto_counts[i]; + for ( lru_t i = 0; i < static_cast(LRUType::MAX); ++i ) + total += lru_counts[i]; return total; } PegCount& get(PktType type) - { return proto_counts[static_cast(type)]; } + { return lru_counts[static_cast(type)]; } const PegCount& get(PktType type) const - { return proto_counts[static_cast(type)]; } + { 
return lru_counts[static_cast(type)]; } void update(PktType type) { ++get(type); } @@ -68,7 +68,7 @@ struct PruneStats using reason_t = std::underlying_type::type; PegCount prunes[static_cast(PruneReason::MAX)] { }; - ProtoPruneStats protoPruneStats[static_cast(PruneReason::MAX)] { }; + LRUPruneStats lruPruneStats[static_cast(PruneReason::MAX)] { }; PegCount get_total() const { @@ -88,20 +88,20 @@ struct PruneStats void update(PruneReason reason, PktType type = PktType::NONE) { ++get(reason); - protoPruneStats[static_cast(reason)].update(type); + lruPruneStats[static_cast(reason)].update(type); } PegCount& get_proto_prune_count(PruneReason reason, PktType type) - { return protoPruneStats[static_cast(reason)].get(type); } + { return lruPruneStats[static_cast(reason)].get(type); } const PegCount& get_proto_prune_count(PruneReason reason, PktType type) const - { return protoPruneStats[static_cast(reason)].get(type); } + { return lruPruneStats[static_cast(reason)].get(type); } PegCount get_proto_prune_count(PktType type) const { PegCount total = 0; for ( reason_t i = 0; i < static_cast(PruneReason::NONE); ++i ) - total += protoPruneStats[i].get(type); + total += lruPruneStats[i].get(type); return total; } diff --git a/src/flow/test/flow_cache_test.cc b/src/flow/test/flow_cache_test.cc index dcd28e0c3..936b1b118 100644 --- a/src/flow/test/flow_cache_test.cc +++ b/src/flow/test/flow_cache_test.cc @@ -396,6 +396,243 @@ TEST(flow_prune, prune_counts) CHECK_EQUAL(3, stats.get_proto_prune_count(PruneReason::IDLE_PROTOCOL_TIMEOUT, PktType::IP)); } +TEST_GROUP(allowlist_test) { }; + +TEST(allowlist_test, move_to_allowlist) +{ + FlowCacheConfig fcg; + fcg.max_flows = 5; + DummyCache* cache = new DummyCache(fcg); + int port = 1; + + // Adding two UDP flows and moving them to allow list + for (unsigned i = 0; i < 2; ++i) { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type = PktType::UDP; + + Flow* flow = cache->allocate(&flow_key); + 
CHECK(cache->move_to_allowlist(flow) == true); // Move flow to allow list + + Flow* found_flow = cache->find(&flow_key); + CHECK(found_flow == flow); // Verify flow is found + CHECK(found_flow->flags.in_allowlist == 1); // Verify it's in allowlist + } + + CHECK_EQUAL(2, cache->get_count()); // Check two flows in cache + CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index)); // Check 2 allow list flows + + cache->purge(); + delete cache; +} + + +TEST(allowlist_test, allowlist_timeout_prune_fail) +{ + FlowCacheConfig fcg; + fcg.max_flows = 5; + DummyCache* cache = new DummyCache(fcg); + int port = 1; + + for (unsigned i = 0; i < 2; ++i) + { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type = PktType::TCP; + + Flow* flow = cache->allocate(&flow_key); + CHECK(cache->move_to_allowlist(flow) == true); + } + + CHECK_EQUAL(2, cache->get_count()); + CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index)); + + // Ensure pruning doesn't occur because all flows are allow listed + for (uint8_t i = 0; i < total_lru_count; ++i) + CHECK(cache->prune_one(PruneReason::IDLE_PROTOCOL_TIMEOUT, true, i) == false); + + CHECK_EQUAL(2, cache->get_count()); + CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index)); + + cache->purge(); + delete cache; +} + +TEST(allowlist_test, allowlist_memcap_prune_pass) +{ + PruneStats stats; + FlowCacheConfig fcg; + fcg.max_flows = 10; + fcg.prune_flows = 5; + DummyCache* cache = new DummyCache(fcg); + int port = 1; + + for (unsigned i = 0; i < 10; ++i) + { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type = PktType::TCP; + + Flow* flow = cache->allocate(&flow_key); + CHECK(cache->move_to_allowlist(flow) == true); + } + + CHECK_EQUAL(10, cache->get_count()); // Check 10 flows in cache + CHECK_EQUAL(10, cache->get_lru_flow_count(allowlist_lru_index)); // Check 2 allow listed flows + + CHECK_EQUAL(5, cache->prune_multiple(PruneReason::MEMCAP, true)); + CHECK_EQUAL(5, cache->get_count()); + 
CHECK_EQUAL(5, cache->get_proto_prune_count(PruneReason::MEMCAP, (PktType)allowlist_lru_index)); + CHECK_EQUAL(5, cache->get_lru_flow_count(allowlist_lru_index)); + + cache->purge(); + delete cache; +} + + +TEST(allowlist_test, allowlist_timeout_with_other_protos) +{ + FlowCacheConfig fcg; + fcg.max_flows = 10; + fcg.prune_flows = 10; + + for (uint8_t i = to_utype(PktType::NONE); i <= to_utype(PktType::MAX); ++i) + fcg.proto[i].nominal_timeout = 5; + + FlowCache* cache = new FlowCache(fcg); + int port = 1; + + for (unsigned i = 0; i < 2; ++i) + { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type = PktType::UDP; + + Flow* flow = cache->allocate(&flow_key); + CHECK(cache->move_to_allowlist(flow) == true); // Move flow to allow list + + Flow* found_flow = cache->find(&flow_key); + CHECK(found_flow == flow); + CHECK(found_flow->flags.in_allowlist == 1); + } + + CHECK_EQUAL(2, cache->get_count()); + + // Ensure pruning doesn't occur because all flows are allow listed + for (uint8_t i = 0; i < to_utype(PktType::MAX) - 1; ++i) + CHECK(cache->prune_one(PruneReason::NONE, true, i) == false); + + CHECK_EQUAL(2, cache->get_count()); // Ensure no flows were pruned + + // Add a new ICMP flow + FlowKey flow_key_icmp; + flow_key_icmp.port_l = port++; + flow_key_icmp.pkt_type = PktType::ICMP; + cache->allocate(&flow_key_icmp); + + CHECK_EQUAL(3, cache->get_count()); + CHECK_EQUAL(2, cache->get_lru_flow_count(allowlist_lru_index)); + + // Prune Reason::NONE will not be able to prune allow listed flow, only 1 UDP + CHECK_EQUAL(1, cache->prune_multiple(PruneReason::NONE, true)); + + // we can't prune to 0 so 1 flow will be pruned + CHECK_EQUAL(1, cache->prune_multiple(PruneReason::MEMCAP, true)); + + CHECK_EQUAL(1, cache->get_count()); + CHECK_EQUAL(1, cache->get_lru_flow_count(allowlist_lru_index)); + + // Adding five UDP flows, these will become the LRU flows + for (unsigned i = 0; i < 5; ++i) + { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type 
= PktType::UDP; + + Flow* flow = cache->allocate(&flow_key); + flow->last_data_seen = 2 + i; + } + + CHECK_EQUAL(6, cache->get_count()); + + // Adding three TCP flows, move two to allow list, making them MRU + for (unsigned i = 0; i < 3; ++i) + { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type = PktType::TCP; + + Flow* flow = cache->allocate(&flow_key); + flow->last_data_seen = 4 + i; // Set TCP flows to have later timeout + + if (i > 0) + { + CHECK(cache->move_to_allowlist(flow) == true); + + Flow* found_flow = cache->find(&flow_key); + CHECK(found_flow == flow); + CHECK(found_flow->flags.in_allowlist == 1); + } + } + + CHECK_EQUAL(5, cache->get_lru_flow_count(to_utype(PktType::UDP))); + CHECK_EQUAL(1, cache->get_lru_flow_count(to_utype(PktType::TCP))); + CHECK_EQUAL(3, cache->get_lru_flow_count(allowlist_lru_index)); + CHECK_EQUAL(9, cache->get_count()); // Verify total flows (5 UDP + 1 TCP + 3 allow list) + CHECK_EQUAL(3, cache->get_lru_flow_count(allowlist_lru_index)); // Verify 3 allow listed flows + + // Timeout 4 flows, 3 UDP and 1 TCP + CHECK_EQUAL(4, cache->timeout(5, 9)); + CHECK_EQUAL(5, cache->get_count()); // Ensure 5 flows remain (2 UDP + 3 allow listed TCP) + CHECK_EQUAL(3, cache->count_flows_in_lru(allowlist_lru_index)); + CHECK_EQUAL(0, cache->count_flows_in_lru(to_utype(PktType::TCP))); + CHECK_EQUAL(2, cache->count_flows_in_lru(to_utype(PktType::UDP))); + + //try multiple prune 2 UDP flow should be pruned as other flows are allow listed + CHECK_EQUAL(2, cache->prune_multiple(PruneReason::NONE, true)); + + //memcap prune can prune all the flows + CHECK_EQUAL(2, cache->prune_multiple(PruneReason::MEMCAP, true)); + + CHECK_EQUAL(1, cache->get_count()); + CHECK_EQUAL(1, cache->get_lru_flow_count(allowlist_lru_index)); + + // Clean up + cache->purge(); + delete cache; +} +TEST(allowlist_test, excess_prune) +{ + FlowCacheConfig fcg; + fcg.max_flows = 5; + fcg.prune_flows = 2; + DummyCache* cache = new DummyCache(fcg); + int port =
1; + + for (unsigned i = 0; i < 6; ++i) + { + FlowKey flow_key; + flow_key.port_l = port++; + flow_key.pkt_type = PktType::TCP; + + Flow* flow = cache->allocate(&flow_key); + CHECK(cache->move_to_allowlist(flow) == true); + } + + // allocating 6 flows and moving all to allowlist + // max_flows is 5 one flow should be pruned + CHECK_EQUAL(5, cache->get_count()); + CHECK_EQUAL(5, cache->get_lru_flow_count(allowlist_lru_index)); + + // Prune 3 flows, expect 2 flows pruned + CHECK_EQUAL(2, cache->prune_multiple(PruneReason::EXCESS, true)); + CHECK_EQUAL(3, cache->get_count()); + CHECK_EQUAL(3, cache->get_lru_flow_count(allowlist_lru_index)); + + cache->purge(); + delete cache; +} + TEST_GROUP(dump_flows) { }; TEST(dump_flows, dump_flows_with_all_empty_caches) @@ -495,6 +732,139 @@ TEST(dump_flows, dump_flows_with_102_tcp_flows_and_202_udp_flows) delete cache; } +TEST(dump_flows, dump_flows_with_allowlist) +{ + FlowCacheConfig fcg; + fcg.max_flows = 500; + FilterFlowCriteria ffc; + std::fstream dump_stream; + DummyCache* cache = new DummyCache(fcg); + int port = 1; + FlowKey flow_key[10]; + + // Add TCP flows and mark some as allow listed + for (unsigned i = 0; i < 10; ++i) + { + flow_key[i].port_l = port++; + flow_key[i].pkt_type = PktType::TCP; + Flow* flow = cache->allocate(&flow_key[i]); + // Mark the first 5 flows as allow listed + if (i < 5) + { + CHECK(cache->move_to_allowlist(flow) == true); + } + } + + CHECK(cache->get_count() == 10); + + //check flows are properly moved to allow list + CHECK(cache->count_flows_in_lru(to_utype(PktType::TCP)) == 5); // Check 5 TCP flows + CHECK(cache->count_flows_in_lru(allowlist_lru_index) == 5); // Check 5 allow listed flows + + // Check that the first dump call works (with allow listed and non-allow listed flows) + CHECK(cache->dump_flows(dump_stream, 10, ffc, true, 1) == true); + + + // Verify that allow listed flows exist and are correctly handled + for (unsigned i = 0; i < 5; ++i) + { + flow_key[i].port_l = i + 1; // allow 
listed flow ports + flow_key[i].pkt_type = PktType::TCP; + Flow* flow = cache->find(&flow_key[i]); + CHECK(flow != nullptr); // Ensure the flow is found + CHECK(flow->flags.in_allowlist == 1); // Ensure the flow is allow listed + } + + // Ensure cache cleanup and correct flow counts + cache->purge(); + CHECK(cache->get_flows_allocated() == 0); + CHECK(cache->get_count() == 0); + delete cache; +} + +TEST(dump_flows, dump_flows_no_flows_to_dump) +{ + FlowCacheConfig fcg; + FilterFlowCriteria ffc; + fcg.max_flows = 10; + std::fstream dump_stream; + + DummyCache* cache = new DummyCache(fcg); + CHECK(cache->dump_flows(dump_stream, 100, ffc, true, 1) == true); + + delete cache; +} + +TEST_GROUP(flow_cache_lrus) +{ + FlowCacheConfig fcg; + DummyCache* cache; + + void setup() + { + fcg.max_flows = 20; + cache = new DummyCache(fcg); + } + + void teardown() + { + cache->purge(); + delete cache; + } +}; + +TEST(flow_cache_lrus, count_flows_in_lru_test) +{ + FlowKey flow_keys[10]; + memset(flow_keys, 0, sizeof(flow_keys)); + + flow_keys[0].pkt_type = PktType::TCP; + flow_keys[1].pkt_type = PktType::UDP; + flow_keys[2].pkt_type = PktType::USER; + flow_keys[3].pkt_type = PktType::FILE; + flow_keys[4].pkt_type = PktType::TCP; + flow_keys[5].pkt_type = PktType::TCP; + flow_keys[6].pkt_type = PktType::PDU; + flow_keys[7].pkt_type = PktType::ICMP; + flow_keys[8].pkt_type = PktType::TCP; + flow_keys[9].pkt_type = PktType::ICMP; + + //flow count 4 TCP, 1 UDP, 1 USER, 1 FILE, 1 PDU, 2 ICMP = 10 + // Add the flows to the hash_table + for (int i = 0; i < 10; ++i) + { + flow_keys[i].port_l = i; + Flow* flow = cache->allocate(&flow_keys[i]); + CHECK(flow != nullptr); + } + + CHECK_EQUAL(10, cache->get_count()); // Verify 10 flows in + CHECK_EQUAL(4, cache->count_flows_in_lru(to_utype(PktType::TCP))); // 4 TCP flows + CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::UDP))); // 1 UDP flow + CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::USER))); // 1 USER flow + 
CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::FILE))); // 1 FILE flow + CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::PDU))); // 1 PDU flow + CHECK_EQUAL(2, cache->count_flows_in_lru(to_utype(PktType::ICMP))); // 2 ICMP flow + + Flow* flow1 = cache->find(&flow_keys[0]); + Flow* flow2 = cache->find(&flow_keys[1]); + Flow* flow3 = cache->find(&flow_keys[6]); + CHECK(cache->move_to_allowlist(flow1)); + CHECK(cache->move_to_allowlist(flow2)); + CHECK(cache->move_to_allowlist(flow3)); + + CHECK_EQUAL(10, cache->get_count()); + CHECK_EQUAL(3, cache->count_flows_in_lru(to_utype(PktType::TCP))); // 3 TCP flows + CHECK_EQUAL(0, cache->count_flows_in_lru(to_utype(PktType::UDP))); // 0 UDP flows + CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::USER))); // 1 USER flow + CHECK_EQUAL(1, cache->count_flows_in_lru(to_utype(PktType::FILE))); // 1 FILE flow + CHECK_EQUAL(0, cache->count_flows_in_lru(to_utype(PktType::PDU))); // 0 PDU flows + CHECK_EQUAL(2, cache->count_flows_in_lru(to_utype(PktType::ICMP))); // 2 ICMP flow + // Check the allow listed flows + CHECK_EQUAL(3, cache->count_flows_in_lru(allowlist_lru_index)); // 3 allowlist flows + +} + int main(int argc, char** argv) { return CommandLineTestRunner::RunAllTests(argc, argv); diff --git a/src/flow/test/flow_control_test.cc b/src/flow/test/flow_control_test.cc index c238da128..5785eb84a 100644 --- a/src/flow/test/flow_control_test.cc +++ b/src/flow/test/flow_control_test.cc @@ -85,6 +85,8 @@ ExpectCache::ExpectCache(uint32_t) { } ExpectCache::~ExpectCache() = default; bool ExpectCache::check(Packet*, Flow*) { return true; } Flow* HighAvailabilityManager::import(Packet&, FlowKey&) { return nullptr; } +bool FlowCache::move_to_allowlist(snort::Flow*) { return true; } +uint64_t FlowCache::get_lru_flow_count(uint8_t) const { return 0; } namespace snort { diff --git a/src/flow/test/flow_test.cc b/src/flow/test/flow_test.cc index 560eeb24d..d5d5f0240 100644 --- a/src/flow/test/flow_test.cc +++ 
b/src/flow/test/flow_test.cc @@ -26,6 +26,8 @@ #include "detection/context_switcher.h" #include "detection/detection_engine.h" #include "flow/flow.h" +#include "flow/flow_config.h" +#include "flow/flow_control.h" #include "flow/flow_stash.h" #include "flow/ha.h" #include "framework/inspector.h" @@ -44,6 +46,16 @@ #include "flow_stubs.h" using namespace snort; +THREAD_LOCAL class FlowControl* flow_con; + +const FlowCacheConfig& FlowControl::get_flow_cache_config() const +{ + static FlowCacheConfig fcc; + fcc.allowlist_cache = true; + return fcc; +} + +bool FlowControl:: move_to_allowlist(snort::Flow*) { return true; } void Inspector::rem_ref() {} diff --git a/src/framework/decode_data.h b/src/framework/decode_data.h index 64a3522aa..3e87718bc 100644 --- a/src/framework/decode_data.h +++ b/src/framework/decode_data.h @@ -47,6 +47,20 @@ enum class PktType : std::uint8_t NONE, IP, TCP, UDP, ICMP, USER, FILE, PDU, MAX }; +enum class LRUType : std::uint8_t +{ + NONE = static_cast(PktType::NONE), + IP = static_cast(PktType::IP), + TCP = static_cast(PktType::TCP), + UDP = static_cast(PktType::UDP), + ICMP = static_cast(PktType::ICMP), + USER = static_cast(PktType::USER), + FILE = static_cast(PktType::FILE), + PDU = static_cast(PktType::PDU), + ALLOW_LIST, + MAX +}; + // the first several of these bits must map to PktType // eg PROTO_BIT__IP == BIT(PktType::IP), etc. 
#define PROTO_BIT__NONE 0x000000 diff --git a/src/hash/hash_lru_cache.cc b/src/hash/hash_lru_cache.cc index 43764bbc9..c6fe32149 100644 --- a/src/hash/hash_lru_cache.cc +++ b/src/hash/hash_lru_cache.cc @@ -44,6 +44,7 @@ void HashLruCache::insert(HashNode* hnode) else tail = hnode; head = hnode; + node_count++; } void HashLruCache::touch(HashNode* hnode) @@ -83,4 +84,5 @@ void HashLruCache::remove_node(HashNode* hnode) if ( tail == hnode ) tail = hnode->gprev; + node_count--; } diff --git a/src/hash/hash_lru_cache.h b/src/hash/hash_lru_cache.h index c5ed0e88b..9b1d2f8ef 100644 --- a/src/hash/hash_lru_cache.h +++ b/src/hash/hash_lru_cache.h @@ -78,12 +78,16 @@ public: return rnode; } + inline uint64_t get_node_count() + { return node_count; } + private: snort::HashNode* head = nullptr; snort::HashNode* tail = nullptr; snort::HashNode* cursor = nullptr; //walk_cursor is used to traverse from tail to head while dumping the flows. snort::HashNode* walk_cursor = nullptr; + uint64_t node_count = 0; }; #endif diff --git a/src/hash/xhash.cc b/src/hash/xhash.cc index b4b206f4b..187fc76a5 100644 --- a/src/hash/xhash.cc +++ b/src/hash/xhash.cc @@ -305,13 +305,13 @@ void XHash::update_cursor() } } -void* XHash::get_user_data(const void* key, uint8_t type) +void* XHash::get_user_data(const void* key, uint8_t type, bool touch) { assert(key); assert(type < num_lru_caches); int rindex = 0; - HashNode* hnode = find_node_row(key, rindex, type); + HashNode* hnode = find_node_row(key, rindex, type, touch); return ( hnode ) ? 
hnode->data : nullptr; } @@ -407,7 +407,14 @@ void XHash::move_to_front(HashNode* node,uint8_t type) lru_caches[type]->touch(node); } -HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type) +void XHash::touch_last_found(uint8_t type) +{ + assert(type < num_lru_caches); + if ( lfind ) + move_to_front(lfind, type); +} + +HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type, bool touch) { assert(type < num_lru_caches); unsigned hashkey = hashkey_ops->do_hash((const unsigned char*)key, keysize); @@ -418,7 +425,9 @@ HashNode* XHash::find_node_row(const void* key, int& rindex, uint8_t type) { if ( hashkey_ops->key_compare(hnode->key, key, keysize) ) { - move_to_front(hnode,type); + lfind = hnode; + if ( touch ) + move_to_front(hnode, type); return hnode; } } @@ -544,6 +553,28 @@ HashNode* XHash::release_lru_node(uint8_t type) return hnode; } +void XHash::switch_node_lru_cache(HashNode* hnode, uint8_t old_type, uint8_t new_type) +{ + lru_caches[old_type]->remove_node(hnode); + lru_caches[new_type]->insert(hnode); +} + +bool XHash::switch_lru_cache(const void* key, uint8_t old_type, uint8_t new_type) +{ + assert(old_type < num_lru_caches); + assert(new_type < num_lru_caches); + + int rindex = 0; + HashNode* hnode = (HashNode*)find_node_row(key, rindex, old_type); + if ( hnode ) + { + switch_node_lru_cache(hnode, old_type, new_type); + return true; + } + + return false; +} + bool XHash::delete_lru_node(uint8_t type) { assert(type < num_lru_caches); diff --git a/src/hash/xhash.h b/src/hash/xhash.h index 6ecf6206c..ffa716eac 100644 --- a/src/hash/xhash.h +++ b/src/hash/xhash.h @@ -58,7 +58,7 @@ public: HashNode* find_first_node(); HashNode* find_next_node(); void* get_user_data(); - void* get_user_data(const void* key, uint8_t type = 0); + void* get_user_data(const void* key, uint8_t type = 0, bool touch = true); void release(uint8_t type = 0); int release_node(const void* key, uint8_t type = 0); int release_node(HashNode* node, 
uint8_t type = 0); @@ -69,6 +69,8 @@ public: bool delete_lru_node(uint8_t type = 0); void clear_hash(); bool full() const { return !fhead; } + bool switch_lru_cache(const void* key, uint8_t old_type, uint8_t new_type); + void touch_last_found(uint8_t type = 0); // set max hash nodes, 0 == no limit void set_max_nodes(int max) @@ -100,7 +102,7 @@ protected: void initialize_node(HashNode*, const void* key, void* data, int index, uint8_t type = 0); HashNode* allocate_node(const void* key, void* data, int index); - HashNode* find_node_row(const void* key, int& rindex, uint8_t type = 0); + HashNode* find_node_row(const void* key, int& rindex, uint8_t type = 0, bool touch = true); void link_node(HashNode*); void unlink_node(HashNode*); bool delete_a_node(); @@ -130,6 +132,7 @@ private: HashKeyOperations* hashkey_ops = nullptr; HashNode* cursor = nullptr; HashNode* fhead = nullptr; + HashNode* lfind = nullptr; unsigned datasize = 0; unsigned long mem_cap = 0; unsigned max_nodes = 0; @@ -142,6 +145,7 @@ private: HashNode* release_lru_node(uint8_t type = 0); void update_cursor(); void purge_free_list(); + void switch_node_lru_cache(HashNode* hnode, uint8_t old_type, uint8_t new_type); }; } // namespace snort diff --git a/src/hash/zhash.cc b/src/hash/zhash.cc index e3c937d57..79713db46 100644 --- a/src/hash/zhash.cc +++ b/src/hash/zhash.cc @@ -134,3 +134,9 @@ void ZHash::lru_touch(uint8_t type) assert(node); lru_caches[type]->touch(node); } + +uint64_t ZHash::get_node_count(uint8_t type) +{ + assert(type < num_lru_caches); + return lru_caches[type]->get_node_count(); +} diff --git a/src/hash/zhash.h b/src/hash/zhash.h index e43076cc2..52d7c43e9 100644 --- a/src/hash/zhash.h +++ b/src/hash/zhash.h @@ -36,6 +36,7 @@ public: void* pop(); void* get(const void* key, uint8_t type = 0); + uint64_t get_node_count(uint8_t type); void* remove(uint8_t type = 0); void* lru_first(uint8_t type = 0); diff --git a/src/main/analyzer.cc b/src/main/analyzer.cc index 05f173eac..15f823628 100644 
--- a/src/main/analyzer.cc +++ b/src/main/analyzer.cc @@ -297,6 +297,8 @@ static DAQ_Verdict distill_verdict(Packet* p) verdict = DAQ_VERDICT_PASS; daq_stats.internal_whitelist++; } + else if ( p->flow ) + p->flow->handle_allowlist(); } if ( p->flow ) diff --git a/src/main/test/distill_verdict_stubs.h b/src/main/test/distill_verdict_stubs.h index 8b2ef50a3..34a1ea037 100644 --- a/src/main/test/distill_verdict_stubs.h +++ b/src/main/test/distill_verdict_stubs.h @@ -27,6 +27,7 @@ #include "filters/rate_filter.h" #include "filters/sfrf.h" #include "filters/sfthreshold.h" +#include "flow/flow_control.h" #include "flow/ha.h" #include "framework/data_bus.h" #include "latency/packet_latency.h" @@ -68,6 +69,7 @@ THREAD_LOCAL DAQStats daq_stats; THREAD_LOCAL bool RuleContext::enabled = false; THREAD_LOCAL bool snort::TimeProfilerStats::enabled; THREAD_LOCAL snort::PacketTracer* snort::PacketTracer::s_pkt_trace; +THREAD_LOCAL class FlowControl* flow_con; void Profiler::start() { } void Profiler::stop(uint64_t) { } @@ -230,11 +232,13 @@ IpsContext::IpsContext(unsigned) { } NetworkPolicy* get_network_policy() { return nullptr; } InspectionPolicy* get_inspection_policy() { return nullptr; } Flow::~Flow() = default; +bool Flow::handle_allowlist() { return true; } void ThreadConfig::implement_thread_affinity(SThreadType, unsigned) { } void ThreadConfig::apply_thread_policy(SThreadType , unsigned ) { } void ThreadConfig::set_instance_tid(int) { } } +bool FlowControl::move_to_allowlist(snort::Flow*) { return true; } void memory::MemoryCap::thread_init() { } void memory::MemoryCap::thread_term() { } void memory::MemoryCap::free_space() { } diff --git a/src/main/test/distill_verdict_test.cc b/src/main/test/distill_verdict_test.cc index 0159654d8..0a69a0cae 100644 --- a/src/main/test/distill_verdict_test.cc +++ b/src/main/test/distill_verdict_test.cc @@ -51,12 +51,20 @@ void Flow::trust() { } SFDAQInstance* SFDAQ::get_local_instance() { return nullptr; } + unsigned int 
get_random_seed() { return 3193; } unsigned DataBus::get_id(const PubKey&) { return 0; } } +const FlowCacheConfig& FlowControl::get_flow_cache_config() const +{ + static FlowCacheConfig cfg; + cfg.allowlist_cache = true; + return cfg; +} + using namespace snort; //-------------------------------------------------------------------------- diff --git a/src/stream/base/stream_base.cc b/src/stream/base/stream_base.cc index 5f0df34c1..084adb43c 100644 --- a/src/stream/base/stream_base.cc +++ b/src/stream/base/stream_base.cc @@ -96,15 +96,17 @@ const PegInfo base_pegs[] = { CountType::SUM, "user_memcap_prunes", "number of USER flows pruned due to memcap" }, { CountType::SUM, "file_memcap_prunes", "number of FILE flows pruned due to memcap" }, { CountType::SUM, "pdu_memcap_prunes", "number of PDU flows pruned due to memcap" }, + { CountType::SUM, "allowlist_memcap_prunes", "number of allowlist flows pruned due to memcap" }, // Keep the NOW stats at the bottom as it requires special sum_stats logic + { CountType::NOW, "allowlist_flows", "number of flows moved to the allow list" }, { CountType::NOW, "current_flows", "current number of flows in cache" }, { CountType::NOW, "uni_flows", "number of uni flows in cache" }, { CountType::NOW, "uni_ip_flows", "number of uni ip flows in cache" }, { CountType::END, nullptr, nullptr } }; -#define NOW_PEGS_NUM 3 +#define NOW_PEGS_NUM 4 // FIXIT-L dependency on stats define in another file void base_prep() @@ -139,7 +141,9 @@ void base_prep() stream_base_stats.user_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, PktType::USER); stream_base_stats.file_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, PktType::FILE); stream_base_stats.pdu_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, PktType::PDU); + stream_base_stats.allowlist_memcap_prunes = flow_con->get_proto_prune_count(PruneReason::MEMCAP, static_cast(allowlist_lru_index)); + stream_base_stats.allowlist_flows = 
flow_con->get_allowlist_flow_count(); stream_base_stats.current_flows = flow_con->get_num_flows(); stream_base_stats.uni_flows = flow_con->get_uni_flows(); stream_base_stats.uni_ip_flows = flow_con->get_uni_ip_flows(); diff --git a/src/stream/base/stream_module.cc b/src/stream/base/stream_module.cc index 895510569..af8bf2506 100644 --- a/src/stream/base/stream_module.cc +++ b/src/stream/base/stream_module.cc @@ -67,6 +67,14 @@ static const Parameter name[] = \ { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } \ } +static const Parameter allowlist_cache_params[] = +{ + { "enable", Parameter::PT_BOOL, nullptr, "false", + "enable allowlist cache" }, + + { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } +}; + FLOW_TYPE_PARAMS(ip_params, "180"); FLOW_TYPE_PARAMS(icmp_params, "180"); FLOW_TYPE_PARAMS(tcp_params, "3600"); @@ -103,6 +111,8 @@ static const Parameter s_params[] = { "require_3whs", Parameter::PT_INT, "-1:max31", "-1", "don't track midstream TCP sessions after given seconds from start up; -1 tracks all" }, + { "allowlist_cache", Parameter::PT_TABLE, allowlist_cache_params, nullptr, "configure allowlist cache" }, + FLOW_TYPE_TABLE("ip_cache", "ip", ip_params), FLOW_TYPE_TABLE("icmp_cache", "icmp", icmp_params), FLOW_TYPE_TABLE("tcp_cache", "tcp", tcp_params), @@ -317,6 +327,9 @@ bool StreamModule::set(const char* fqn, Value& v, SnortConfig* c) else if ( v.is("require_3whs") ) config.hs_timeout = v.get_int32(); + else if ( !strcmp(fqn, "stream.allowlist_cache.enable") ) + config.flow_cache_cfg.allowlist_cache = v.get_bool(); + else if ( !strcmp(fqn, "stream.file_cache.idle_timeout") ) config.flow_cache_cfg.proto[to_utype(PktType::FILE)].nominal_timeout = v.get_uint32(); @@ -493,6 +506,12 @@ void StreamModuleConfig::show() const ConfigLogger::log_value(flow_type_names[i], tmp.c_str()); } + { + std::string tmp; + tmp += "{ enable = " + (flow_cache_cfg.allowlist_cache ? 
std::string("true") : std::string("false")); + tmp += " }"; + ConfigLogger::log_value("allowlist_cache", tmp.c_str()); + } } bool HPQReloadTuner::tinit() diff --git a/src/stream/base/stream_module.h b/src/stream/base/stream_module.h index d4ccdb619..ab7c4d6ee 100644 --- a/src/stream/base/stream_module.h +++ b/src/stream/base/stream_module.h @@ -103,8 +103,10 @@ struct BaseStats PegCount user_memcap_prunes; PegCount file_memcap_prunes; PegCount pdu_memcap_prunes; + PegCount allowlist_memcap_prunes; // Keep the NOW stats at the bottom as it requires special sum_stats logic + PegCount allowlist_flows; PegCount current_flows; PegCount uni_flows; PegCount uni_ip_flows;