From: Steve Chew (stechew) Date: Fri, 26 Feb 2021 18:06:47 +0000 (+0000) Subject: Merge pull request #2747 in SNORT/snort3 from ~SBAIGAL/snort3:perf_ha to master X-Git-Tag: 3.1.2.0~23 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=a671b9aa74fe3ca37ea981e946ccf1776515833f;p=thirdparty%2Fsnort3.git Merge pull request #2747 in SNORT/snort3 from ~SBAIGAL/snort3:perf_ha to master Squashed commit of the following: commit 8a93f67c57c000a089e52459f3f6ddd425387a28 Author: Steven Baigal (sbaigal) Date: Thu Feb 18 16:31:11 2021 -0500 stream: do not update service from appid to host attributes if nothing is changed commit 58111934f03848ddb29be00ba9268ca93d801262 Author: Steven Baigal (sbaigal) Date: Thu Feb 18 13:40:20 2021 -0500 host_attributes: updated api to reduce use of shared_pointer commit 678f77983e959ac97e659ceb000dd3bcb4d05baa Author: Steven Baigal (sbaigal) Date: Thu Feb 18 12:43:56 2021 -0500 binder: use service inspector caching to improve get_gadget() performance --- diff --git a/src/flow/flow.h b/src/flow/flow.h index e12f18099..410b41c4d 100644 --- a/src/flow/flow.h +++ b/src/flow/flow.h @@ -474,8 +474,8 @@ public: // FIXIT-M privatize if possible bool trigger_detained_packet_event : 1; bool trigger_finalize_event : 1; bool use_direct_inject : 1; - bool data_decrypted : 1; // indicate data in current flow is decrypted TLS application - //data + bool data_decrypted : 1; // indicate data in current flow is decrypted TLS application data + bool snort_proto_id_set_by_ha : 1; } flags; FlowState flow_state; diff --git a/src/managers/inspector_manager.cc b/src/managers/inspector_manager.cc index 4235848e8..3765634b7 100644 --- a/src/managers/inspector_manager.cc +++ b/src/managers/inspector_manager.cc @@ -231,11 +231,50 @@ struct FrameworkPolicy Inspector* binder; Inspector* wizard; + std::unordered_map inspector_cache_by_id; + std::unordered_map inspector_cache_by_service; + bool default_binder; void vectorize(SnortConfig*); + void add_inspector_to_cache(PHInstance*, SnortConfig*); + void remove_inspector_from_cache(Inspector*); }; +void FrameworkPolicy::add_inspector_to_cache(PHInstance* p, SnortConfig* sc) +{ + if (p->pp_class.api.type == IT_SERVICE and p->pp_class.api.service and p->handler) + { + SnortProtocolId id = sc->proto_ref->find(p->pp_class.api.service); + if (id != UNKNOWN_PROTOCOL_ID) + inspector_cache_by_id[id] = p->handler; + inspector_cache_by_service[p->pp_class.api.service] = p->handler; + } +} + +void FrameworkPolicy::remove_inspector_from_cache(Inspector* ins) +{ + if (!ins) + return; + + for(auto i = inspector_cache_by_id.begin(); i != inspector_cache_by_id.end(); i++) + { + if (ins == i->second) + { + inspector_cache_by_id.erase(i); + break; + } + } + for(auto i = inspector_cache_by_service.begin(); i != inspector_cache_by_service.end(); i++) + { + if (ins == i->second) + { + inspector_cache_by_service.erase(i); + break; + } + } +} + void FrameworkPolicy::vectorize(SnortConfig* sc) { passive.alloc(ilist.size()); @@ -421,20 +460,6 @@ static bool get_instance( return false; } -static PHInstance* get_instance_by_service(FrameworkPolicy* fp, const char* keyword, - InspectorType type) -{ - std::vector::iterator it; - - for ( it = fp->ilist.begin(); it != fp->ilist.end(); ++it ) - { - if ( (*it)->pp_class.api.service && !strcmp((*it)->pp_class.api.service, keyword) && - (*it)->pp_class.api.type == type ) - return *it; - } - return nullptr; -} - static PHInstance* get_instance(FrameworkPolicy* fp, const char* keyword) { std::vector::iterator it; @@ -580,19 +605,26 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons return p->handler; } -Inspector* InspectorManager::get_inspector_by_service(const char* key, InspectorType type) +Inspector* InspectorManager::get_service_inspector_by_service(const char* key) { InspectionPolicy* pi = get_inspection_policy(); if ( !pi || !pi->framework_policy ) return nullptr; - PHInstance* p = get_instance_by_service(pi->framework_policy, key, type); + auto g = pi->framework_policy->inspector_cache_by_service.find(key); + return (g != pi->framework_policy->inspector_cache_by_service.end()) ? g->second : nullptr; +} - if ( !p ) - return nullptr; +Inspector* InspectorManager::get_service_inspector_by_id(const SnortProtocolId protocol_id) +{ + InspectionPolicy* pi = get_inspection_policy(); - return p->handler; + if ( !pi || !pi->framework_policy ) + return nullptr; + + auto g = pi->framework_policy->inspector_cache_by_id.find(protocol_id); + return (g != pi->framework_policy->inspector_cache_by_id.end()) ? g->second : nullptr; } bool InspectorManager::delete_inspector(SnortConfig* sc, const char* iname) @@ -604,6 +636,7 @@ bool InspectorManager::delete_inspector(SnortConfig* sc, const char* iname) if ( get_instance(fp, iname, old_it) ) { (*old_it)->set_reloaded(RELOAD_TYPE_DELETED); + fp->remove_inspector_from_cache((*old_it)->handler); fp->ilist.erase(old_it); ok = true; std::vector::iterator bind_it; @@ -951,6 +984,12 @@ static bool configure(SnortConfig* sc, FrameworkPolicy* fp, bool cloned) sort(fp->ilist.begin(), fp->ilist.end(), PHInstance::comp); fp->vectorize(sc); + // create cache + fp->inspector_cache_by_id.clear(); + fp->inspector_cache_by_service.clear(); + for ( auto* p : fp->ilist ) + fp->add_inspector_to_cache(p, sc); + if ( !fp->binder and (fp->session.num or fp->wizard) ) instantiate_default_binder(sc, fp); diff --git a/src/managers/inspector_manager.h b/src/managers/inspector_manager.h index dcc84724c..2b6f7a9c0 100644 --- a/src/managers/inspector_manager.h +++ b/src/managers/inspector_manager.h @@ -64,7 +64,8 @@ public: SO_PUBLIC static Inspector* get_inspector( const char* key, bool dflt_only = false, const SnortConfig* = nullptr); - SO_PUBLIC static Inspector* get_inspector_by_service(const char*, InspectorType type = IT_SERVICE); + static Inspector* get_service_inspector_by_service(const char*); + static Inspector* get_service_inspector_by_id(const SnortProtocolId); SO_PUBLIC static Binder* get_binder(); diff --git a/src/network_inspectors/binder/binder.cc b/src/network_inspectors/binder/binder.cc index fd5a67f90..018540697 100644 --- a/src/network_inspectors/binder/binder.cc +++ b/src/network_inspectors/binder/binder.cc @@ -42,24 +42,12 @@ THREAD_LOCAL ProfileStats bindPerfStats; // helpers //------------------------------------------------------------------------- -static Inspector* get_gadget(const Flow& flow) +static Inspector* get_gadget(const SnortProtocolId protocol_id) { - if (flow.ssn_state.snort_protocol_id == UNKNOWN_PROTOCOL_ID) + if (protocol_id == UNKNOWN_PROTOCOL_ID) return nullptr; - const SnortConfig* sc = SnortConfig::get_conf(); - const char* s = sc->proto_ref->get_name(flow.ssn_state.snort_protocol_id); - - return InspectorManager::get_inspector_by_service(s, IT_SERVICE); -} - -static Inspector* get_gadget_by_service(const char* service) -{ - const SnortConfig* sc = SnortConfig::get_conf(); - const SnortProtocolId id = sc->proto_ref->find(service); - const char* s = sc->proto_ref->get_name(id); - - return InspectorManager::get_inspector_by_service(s, IT_SERVICE); + return InspectorManager::get_service_inspector_by_id(protocol_id); } static std::string to_string(const sfip_var_t* list) @@ -416,7 +404,7 @@ void Stuff::apply_service(Flow& flow) flow.set_data(data); if (!gadget) - gadget = get_gadget(flow); + gadget = get_gadget(flow.ssn_state.snort_protocol_id); if (gadget) { @@ -437,7 +425,7 @@ void Stuff::apply_service(Flow& flow) void Stuff::apply_assistant(Flow& flow, const char* service) { if (!gadget) - gadget = get_gadget_by_service(service); + gadget = InspectorManager::get_service_inspector_by_service(service); if (gadget) flow.set_assistant_gadget(gadget); @@ -632,23 +620,27 @@ void Binder::handle_flow_setup(Flow& flow, bool standby) // FIXIT-M logic for applying information from the host attribute table likely doesn't belong // in binder, but it *does* need to occur before the binding lookup (for service information) - const HostAttributesEntry host = HostAttributesManager::find_host(flow.server_ip); - if (host) + HostAttriInfo host; + HostAttriInfo* p_host = nullptr; + if ( HostAttributesManager::get_host_attributes(flow.server_ip, flow.server_port, &host) ) + p_host = &host; + + if (p_host) { // Set the fragmentation (IP) or stream (TCP) policy from the host entry switch (flow.pkt_type) { case PktType::IP: - flow.ssn_policy = host->get_frag_policy(); + flow.ssn_policy = p_host->frag_policy; break; case PktType::TCP: - flow.ssn_policy = host->get_stream_policy(); + flow.ssn_policy = p_host->stream_policy; break; default: break; } - Stream::set_snort_protocol_id(&flow, host, FROM_SERVER); + Stream::set_snort_protocol_id_from_ha(&flow, p_host->snort_protocol_id); if (flow.ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) { const SnortConfig* sc = SnortConfig::get_conf(); diff --git a/src/stream/stream.cc b/src/stream/stream.cc index 2d106997b..131b21ef5 100644 --- a/src/stream/stream.cc +++ b/src/stream/stream.cc @@ -426,37 +426,22 @@ int Stream::set_snort_protocol_id_expected( swap_app_direction); } -void Stream::set_snort_protocol_id( - Flow* flow, const HostAttributesEntry& host, int /*direction*/) +void Stream::set_snort_protocol_id_from_ha( + Flow* flow, const SnortProtocolId snort_protocol_id) { - SnortProtocolId snort_protocol_id; - if (!flow ) return; - /* Cool, its already set! */ if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) return; if (flow->ssn_state.ipprotocol == 0) - { set_ip_protocol(flow); - } - - snort_protocol_id = host->get_snort_protocol_id - (flow->ssn_state.ipprotocol, flow->server_port); - -#if 0 - // FIXIT-M from client doesn't imply need to swap - if (direction == FROM_CLIENT) - { - if ( snort_protocol_id && - (flow->ssn_state.session_flags & SSNFLAG_MIDSTREAM) ) - flow->ssn_state.session_flags |= SSNFLAG_CLIENT_SWAP; - } -#endif flow->ssn_state.snort_protocol_id = snort_protocol_id; + if ( snort_protocol_id != UNKNOWN_PROTOCOL_ID && + snort_protocol_id != INVALID_PROTOCOL_ID ) + flow->flags.snort_proto_id_set_by_ha = true; } SnortProtocolId Stream::get_snort_protocol_id(Flow* flow) @@ -477,17 +462,18 @@ SnortProtocolId Stream::get_snort_protocol_id(Flow* flow) if (flow->ssn_state.ipprotocol == 0) set_ip_protocol(flow); - if ( HostAttributesEntry host = HostAttributesManager::find_host(flow->server_ip) ) + HostAttriInfo host; + if (HostAttributesManager::get_host_attributes(flow->server_ip, flow->server_port, &host)) { - set_snort_protocol_id(flow, host, FROM_SERVER); + set_snort_protocol_id_from_ha(flow, host.snort_protocol_id); if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) return flow->ssn_state.snort_protocol_id; } - if ( HostAttributesEntry host = HostAttributesManager::find_host(flow->client_ip) ) + if (HostAttributesManager::get_host_attributes(flow->client_ip, flow->client_port, &host)) { - set_snort_protocol_id(flow, host, FROM_CLIENT); + set_snort_protocol_id_from_ha(flow, host.snort_protocol_id); if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) return flow->ssn_state.snort_protocol_id; @@ -502,12 +488,15 @@ SnortProtocolId Stream::set_snort_protocol_id(Flow* flow, SnortProtocolId id, bo if (!flow) return UNKNOWN_PROTOCOL_ID; + if (flow->ssn_state.snort_protocol_id != id) + flow->flags.snort_proto_id_set_by_ha = false; + flow->ssn_state.snort_protocol_id = id; if (!flow->ssn_state.ipprotocol) set_ip_protocol(flow); - if ( !flow->is_proxied() ) + if ( !flow->is_proxied() and !flow->flags.snort_proto_id_set_by_ha ) { HostAttributesManager::update_service (flow->server_ip, flow->server_port, flow->ssn_state.ipprotocol, id, is_appid_service); diff --git a/src/stream/stream.h b/src/stream/stream.h index 749fda2c5..021cfa11b 100644 --- a/src/stream/stream.h +++ b/src/stream/stream.h @@ -221,8 +221,7 @@ public: // Populate a session key from the Packet static void populate_flow_key(Packet*, FlowKey*); - static void set_snort_protocol_id( - Flow*, const HostAttributesEntry&, int direction); + static void set_snort_protocol_id_from_ha(Flow*, const SnortProtocolId); static bool is_midstream(Flow* flow) { return ((flow->ssn_state.session_flags & SSNFLAG_MIDSTREAM) != 0); } diff --git a/src/target_based/host_attributes.cc b/src/target_based/host_attributes.cc index 5d3164aee..7e9748d42 100644 --- a/src/target_based/host_attributes.cc +++ b/src/target_based/host_attributes.cc @@ -89,8 +89,10 @@ bool HostAttributesDescriptor::update_service if ( s.ipproto == protocol && (uint16_t)s.port == port ) { if ( s.snort_protocol_id != snort_protocol_id ) + { + s.snort_protocol_id = snort_protocol_id; s.appid_service = is_appid_service; - s.snort_protocol_id = snort_protocol_id; + } updated = true; return true; } @@ -119,19 +121,21 @@ void HostAttributesDescriptor::clear_appid_services() } } -SnortProtocolId HostAttributesDescriptor::get_snort_protocol_id(int ipprotocol, uint16_t port) const +void HostAttributesDescriptor::get_host_attributes(uint16_t port,HostAttriInfo* host_info) const { - std::lock_guard lck(host_attributes_lock); - + std::lock_guard slk(host_attributes_lock); + host_info->frag_policy = policies.fragPolicy; + host_info->stream_policy = policies.streamPolicy; + host_info->snort_protocol_id = UNKNOWN_PROTOCOL_ID; for ( auto& s : services ) { - if ( (s.ipproto == ipprotocol) && (s.port == port) ) - return s.snort_protocol_id; + if ( s.port == port ) + { + host_info->snort_protocol_id = s.snort_protocol_id; + return; + } } - - return UNKNOWN_PROTOCOL_ID; } - bool HostAttributesManager::load_hosts_file(snort::SnortConfig* sc, const char* fname) { delete next_cache; @@ -188,12 +192,18 @@ void HostAttributesManager::swap_cleanup() void HostAttributesManager::term() { delete active_cache; } -HostAttributesEntry HostAttributesManager::find_host(const snort::SfIp& host_ip) +bool HostAttributesManager::get_host_attributes(const snort::SfIp& host_ip, uint16_t port, HostAttriInfo* host_info) { - if ( active_cache ) - return active_cache->find(host_ip); + if ( !active_cache ) + return false; - return nullptr; + HostAttributesEntry h = active_cache->find(host_ip); + if (h) + { + h->get_host_attributes(port, host_info); + return true; + } + return false; } void HostAttributesManager::update_service(const snort::SfIp& host_ip, uint16_t port, @@ -207,7 +217,6 @@ void HostAttributesManager::update_service(const snort::SfIp& host_ip, uint16_t { if ( created ) { - host->set_ip_addr(host_ip); host_attribute_stats.dynamic_host_adds++; } diff --git a/src/target_based/host_attributes.h b/src/target_based/host_attributes.h index e776fce4f..cb741f631 100644 --- a/src/target_based/host_attributes.h +++ b/src/target_based/host_attributes.h @@ -77,6 +77,13 @@ struct HostPolicyDescriptor uint8_t fragPolicy = 0; }; +struct HostAttriInfo +{ + SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; + uint8_t stream_policy = 0; + uint8_t frag_policy = 0; +}; + class HostAttributesDescriptor { public: @@ -86,32 +93,24 @@ public: bool update_service(uint16_t port, uint16_t protocol, SnortProtocolId, bool& updated, bool is_appid_service = false); void clear_appid_services(); - SnortProtocolId get_snort_protocol_id(int ipprotocol, uint16_t port) const; + void get_host_attributes(uint16_t, HostAttriInfo*) const; + // Note: the following get/set are only called from main thread on a temp LRU table const snort::SfIp& get_ip_addr() const { return ip_address; } void set_ip_addr(const snort::SfIp& host_ip_addr) { - std::lock_guard lck(host_attributes_lock); ip_address = host_ip_addr; } - uint8_t get_frag_policy() const - { return policies.fragPolicy; } - void set_frag_policy(const uint8_t frag_policy) { - std::lock_guard lck(host_attributes_lock); policies.fragPolicy = frag_policy; } - uint8_t get_stream_policy() const - { return policies.streamPolicy; } - void set_stream_policy(uint8_t stream_policy) { - std::lock_guard lck(host_attributes_lock); policies.streamPolicy = stream_policy; } @@ -150,7 +149,7 @@ public: static void term(); static bool add_host(HostAttributesEntry, snort::SnortConfig*); - static HostAttributesEntry find_host(const snort::SfIp&); + static bool get_host_attributes(const snort::SfIp&, uint16_t, HostAttriInfo*); static void update_service(const snort::SfIp&, uint16_t port, uint16_t protocol, SnortProtocolId, bool is_appid_service = false); static void clear_appid_services();