]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2747 in SNORT/snort3 from ~SBAIGAL/snort3:perf_ha to master
authorSteve Chew (stechew) <stechew@cisco.com>
Fri, 26 Feb 2021 18:06:47 +0000 (18:06 +0000)
committerSteve Chew (stechew) <stechew@cisco.com>
Fri, 26 Feb 2021 18:06:47 +0000 (18:06 +0000)
Squashed commit of the following:

commit 8a93f67c57c000a089e52459f3f6ddd425387a28
Author: Steven Baigal (sbaigal) <sbaigal@cisco.com>
Date:   Thu Feb 18 16:31:11 2021 -0500

    stream: do not update service from appid to host attributes if nothing is changed

commit 58111934f03848ddb29be00ba9268ca93d801262
Author: Steven Baigal (sbaigal) <sbaigal@cisco.com>
Date:   Thu Feb 18 13:40:20 2021 -0500

    host_attributes: updated api to reduce use of shared_pointer

commit 678f77983e959ac97e659ceb000dd3bcb4d05baa
Author: Steven Baigal (sbaigal) <sbaigal@cisco.com>
Date:   Thu Feb 18 12:43:56 2021 -0500

    binder: use service inspector caching to improve get_gadget() performance

src/flow/flow.h
src/managers/inspector_manager.cc
src/managers/inspector_manager.h
src/network_inspectors/binder/binder.cc
src/stream/stream.cc
src/stream/stream.h
src/target_based/host_attributes.cc
src/target_based/host_attributes.h

index e12f1809924fe0acc90f181f2ac143fc2a3c437a..410b41c4daf2e2af3a8acbbb16fb79d935f39028 100644 (file)
@@ -474,8 +474,8 @@ public:  // FIXIT-M privatize if possible
         bool trigger_detained_packet_event : 1;
         bool trigger_finalize_event : 1;
         bool use_direct_inject : 1;
-        bool data_decrypted : 1;    // indicate data in current flow is decrypted TLS application
-                                    //data
+        bool data_decrypted : 1;    // indicate data in current flow is decrypted TLS application data
+        bool snort_proto_id_set_by_ha : 1;
     } flags;
 
     FlowState flow_state;
index 4235848e8b98266580bf56dad291675dc9060dd0..3765634b713935b744d59799a9efd367e8fa98a5 100644 (file)
@@ -231,11 +231,50 @@ struct FrameworkPolicy
     Inspector* binder;
     Inspector* wizard;
 
+    std::unordered_map<SnortProtocolId, Inspector*> inspector_cache_by_id;
+    std::unordered_map<std::string, Inspector*> inspector_cache_by_service;
+
     bool default_binder;
 
     void vectorize(SnortConfig*);
+    void add_inspector_to_cache(PHInstance*, SnortConfig*);
+    void remove_inspector_from_cache(Inspector*);
 };
 
+void FrameworkPolicy::add_inspector_to_cache(PHInstance* p, SnortConfig* sc)
+{
+    if (p->pp_class.api.type == IT_SERVICE and p->pp_class.api.service and p->handler)
+    {
+        SnortProtocolId id = sc->proto_ref->find(p->pp_class.api.service);
+        if (id != UNKNOWN_PROTOCOL_ID)
+            inspector_cache_by_id[id] = p->handler;
+        inspector_cache_by_service[p->pp_class.api.service] = p->handler;
+    }
+}
+
+void FrameworkPolicy::remove_inspector_from_cache(Inspector* ins)
+{
+    if (!ins)
+        return;
+
+    for(auto i = inspector_cache_by_id.begin(); i != inspector_cache_by_id.end(); i++)
+    {
+        if (ins == i->second)
+        {
+            inspector_cache_by_id.erase(i);
+            break;
+        }
+    }
+    for(auto i = inspector_cache_by_service.begin(); i != inspector_cache_by_service.end(); i++)
+    {
+        if (ins == i->second)
+        {
+            inspector_cache_by_service.erase(i);
+            break;
+        }
+    }
+}
+
 void FrameworkPolicy::vectorize(SnortConfig* sc)
 {
     passive.alloc(ilist.size());
@@ -421,20 +460,6 @@ static bool get_instance(
     return false;
 }
 
-static PHInstance* get_instance_by_service(FrameworkPolicy* fp, const char* keyword,
-    InspectorType type)
-{
-    std::vector<PHInstance*>::iterator it;
-
-    for ( it = fp->ilist.begin(); it != fp->ilist.end(); ++it )
-    {
-        if ( (*it)->pp_class.api.service && !strcmp((*it)->pp_class.api.service, keyword) &&
-            (*it)->pp_class.api.type == type )
-            return *it;
-    }
-    return nullptr;
-}
-
 static PHInstance* get_instance(FrameworkPolicy* fp, const char* keyword)
 {
     std::vector<PHInstance*>::iterator it;
@@ -580,19 +605,26 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons
     return p->handler;
 }
 
-Inspector* InspectorManager::get_inspector_by_service(const char* key, InspectorType type)
+Inspector* InspectorManager::get_service_inspector_by_service(const char* key)
 {
     InspectionPolicy* pi = get_inspection_policy();
 
     if ( !pi || !pi->framework_policy )
         return nullptr;
 
-    PHInstance* p = get_instance_by_service(pi->framework_policy, key, type);
+    auto g = pi->framework_policy->inspector_cache_by_service.find(key);
+    return (g != pi->framework_policy->inspector_cache_by_service.end()) ? g->second : nullptr;
+}
 
-    if ( !p )
-        return nullptr;
+Inspector* InspectorManager::get_service_inspector_by_id(const SnortProtocolId protocol_id)
+{
+    InspectionPolicy* pi = get_inspection_policy();
 
-    return p->handler;
+    if ( !pi || !pi->framework_policy )
+        return nullptr;
+    auto g = pi->framework_policy->inspector_cache_by_id.find(protocol_id);
+    return (g != pi->framework_policy->inspector_cache_by_id.end()) ? g->second : nullptr;
 }
 
 bool InspectorManager::delete_inspector(SnortConfig* sc, const char* iname)
@@ -604,6 +636,7 @@ bool InspectorManager::delete_inspector(SnortConfig* sc, const char* iname)
     if ( get_instance(fp, iname, old_it) )
     {
         (*old_it)->set_reloaded(RELOAD_TYPE_DELETED);
+        fp->remove_inspector_from_cache((*old_it)->handler);
         fp->ilist.erase(old_it);
         ok = true;
         std::vector<PHInstance*>::iterator bind_it;
@@ -951,6 +984,12 @@ static bool configure(SnortConfig* sc, FrameworkPolicy* fp, bool cloned)
     sort(fp->ilist.begin(), fp->ilist.end(), PHInstance::comp);
     fp->vectorize(sc);
 
+    // create cache
+    fp->inspector_cache_by_id.clear();
+    fp->inspector_cache_by_service.clear();
+    for ( auto* p : fp->ilist )
+        fp->add_inspector_to_cache(p, sc);
+
     if ( !fp->binder and (fp->session.num or fp->wizard) )
         instantiate_default_binder(sc, fp);
 
index dcc84724c3628f4edf1d9211aa5e5affc920555b..2b6f7a9c0bb81b9da5eb6d21e2110fe407f17ef4 100644 (file)
@@ -64,7 +64,8 @@ public:
     SO_PUBLIC static Inspector* get_inspector(
         const char* key, bool dflt_only = false, const SnortConfig* = nullptr);
 
-    SO_PUBLIC static Inspector* get_inspector_by_service(const char*, InspectorType type = IT_SERVICE);
+    static Inspector* get_service_inspector_by_service(const char*);
+    static Inspector* get_service_inspector_by_id(const SnortProtocolId);
 
     SO_PUBLIC static Binder* get_binder();
 
index fd5a67f908de482762a73dad6de2a0e5d81160d4..018540697383fb23e872f4850ee32ede4e44b91e 100644 (file)
@@ -42,24 +42,12 @@ THREAD_LOCAL ProfileStats bindPerfStats;
 // helpers
 //-------------------------------------------------------------------------
 
-static Inspector* get_gadget(const Flow& flow)
+static Inspector* get_gadget(const SnortProtocolId protocol_id)
 {
-    if (flow.ssn_state.snort_protocol_id == UNKNOWN_PROTOCOL_ID)
+    if (protocol_id == UNKNOWN_PROTOCOL_ID)
         return nullptr;
 
-    const SnortConfig* sc = SnortConfig::get_conf();
-    const char* s = sc->proto_ref->get_name(flow.ssn_state.snort_protocol_id);
-
-    return InspectorManager::get_inspector_by_service(s, IT_SERVICE);
-}
-
-static Inspector* get_gadget_by_service(const char* service)
-{
-    const SnortConfig* sc = SnortConfig::get_conf();
-    const SnortProtocolId id = sc->proto_ref->find(service);
-    const char* s = sc->proto_ref->get_name(id);
-
-    return InspectorManager::get_inspector_by_service(s, IT_SERVICE);
+    return InspectorManager::get_service_inspector_by_id(protocol_id);
 }
 
 static std::string to_string(const sfip_var_t* list)
@@ -416,7 +404,7 @@ void Stuff::apply_service(Flow& flow)
         flow.set_data(data);
 
     if (!gadget)
-        gadget = get_gadget(flow);
+        gadget = get_gadget(flow.ssn_state.snort_protocol_id);
 
     if (gadget)
     {
@@ -437,7 +425,7 @@ void Stuff::apply_service(Flow& flow)
 void Stuff::apply_assistant(Flow& flow, const char* service)
 {
     if (!gadget)
-        gadget = get_gadget_by_service(service);
+        gadget = InspectorManager::get_service_inspector_by_service(service);
 
     if (gadget)
         flow.set_assistant_gadget(gadget);
@@ -632,23 +620,27 @@ void Binder::handle_flow_setup(Flow& flow, bool standby)
 
     // FIXIT-M logic for applying information from the host attribute table likely doesn't belong
     // in binder, but it *does* need to occur before the binding lookup (for service information)
-    const HostAttributesEntry host = HostAttributesManager::find_host(flow.server_ip);
-    if (host)
+    HostAttriInfo host;
+    HostAttriInfo* p_host = nullptr;
+    if ( HostAttributesManager::get_host_attributes(flow.server_ip, flow.server_port, &host) )
+        p_host = &host;
+
+    if (p_host)
     {
         // Set the fragmentation (IP) or stream (TCP) policy from the host entry
         switch (flow.pkt_type)
         {
             case PktType::IP:
-                flow.ssn_policy = host->get_frag_policy();
+                flow.ssn_policy = p_host->frag_policy;
                 break;
             case PktType::TCP:
-                flow.ssn_policy = host->get_stream_policy();
+                flow.ssn_policy = p_host->stream_policy;
                 break;
             default:
                 break;
         }
 
-        Stream::set_snort_protocol_id(&flow, host, FROM_SERVER);
+        Stream::set_snort_protocol_id_from_ha(&flow, p_host->snort_protocol_id);
         if (flow.ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID)
         {
             const SnortConfig* sc = SnortConfig::get_conf();
index 2d106997bdad6b234ed0239b97d1fd9eb37ccfb7..131b21ef5a94fe922d2516d8c34bbb297c82a4a0 100644 (file)
@@ -426,37 +426,22 @@ int Stream::set_snort_protocol_id_expected(
         swap_app_direction);
 }
 
-void Stream::set_snort_protocol_id(
-    Flow* flow, const HostAttributesEntry& host, int /*direction*/)
+void Stream::set_snort_protocol_id_from_ha(
+    Flow* flow, const SnortProtocolId snort_protocol_id)
 {
-    SnortProtocolId snort_protocol_id;
-
     if (!flow )
         return;
 
-    /* Cool, its already set! */
     if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID)
         return;
 
     if (flow->ssn_state.ipprotocol == 0)
-    {
         set_ip_protocol(flow);
-    }
-
-    snort_protocol_id = host->get_snort_protocol_id
-        (flow->ssn_state.ipprotocol, flow->server_port);
-
-#if 0
-    // FIXIT-M from client doesn't imply need to swap
-    if (direction == FROM_CLIENT)
-    {
-        if ( snort_protocol_id &&
-            (flow->ssn_state.session_flags & SSNFLAG_MIDSTREAM) )
-            flow->ssn_state.session_flags |= SSNFLAG_CLIENT_SWAP;
-    }
-#endif
 
     flow->ssn_state.snort_protocol_id = snort_protocol_id;
+    if ( snort_protocol_id != UNKNOWN_PROTOCOL_ID &&
+         snort_protocol_id != INVALID_PROTOCOL_ID )
+        flow->flags.snort_proto_id_set_by_ha = true;
 }
 
 SnortProtocolId Stream::get_snort_protocol_id(Flow* flow)
@@ -477,17 +462,18 @@ SnortProtocolId Stream::get_snort_protocol_id(Flow* flow)
     if (flow->ssn_state.ipprotocol == 0)
         set_ip_protocol(flow);
 
-    if ( HostAttributesEntry host = HostAttributesManager::find_host(flow->server_ip) )
+    HostAttriInfo host;
+    if (HostAttributesManager::get_host_attributes(flow->server_ip, flow->server_port, &host))
     {
-        set_snort_protocol_id(flow, host, FROM_SERVER);
+        set_snort_protocol_id_from_ha(flow, host.snort_protocol_id);
 
         if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID)
             return flow->ssn_state.snort_protocol_id;
     }
 
-    if ( HostAttributesEntry host = HostAttributesManager::find_host(flow->client_ip) )
+    if (HostAttributesManager::get_host_attributes(flow->client_ip, flow->client_port, &host))
     {
-        set_snort_protocol_id(flow, host, FROM_CLIENT);
+        set_snort_protocol_id_from_ha(flow, host.snort_protocol_id);
 
         if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID)
             return flow->ssn_state.snort_protocol_id;
@@ -502,12 +488,15 @@ SnortProtocolId Stream::set_snort_protocol_id(Flow* flow, SnortProtocolId id, bo
     if (!flow)
         return UNKNOWN_PROTOCOL_ID;
 
+    if (flow->ssn_state.snort_protocol_id != id)
+        flow->flags.snort_proto_id_set_by_ha = false;
+
     flow->ssn_state.snort_protocol_id = id;
 
     if (!flow->ssn_state.ipprotocol)
         set_ip_protocol(flow);
 
-    if ( !flow->is_proxied() )
+    if ( !flow->is_proxied() and !flow->flags.snort_proto_id_set_by_ha )
     {
         HostAttributesManager::update_service
             (flow->server_ip, flow->server_port, flow->ssn_state.ipprotocol, id, is_appid_service);
index 749fda2c5e44f2be9252f2eb7b1aaa3590aef225..021cfa11b8a148f4a1f5d568816bf74435f36099 100644 (file)
@@ -221,8 +221,7 @@ public:
     //  Populate a session key from the Packet
     static void populate_flow_key(Packet*, FlowKey*);
 
-    static void set_snort_protocol_id(
-        Flow*, const HostAttributesEntry&, int direction);
+    static void set_snort_protocol_id_from_ha(Flow*, const SnortProtocolId);
 
     static bool is_midstream(Flow* flow)
     { return ((flow->ssn_state.session_flags & SSNFLAG_MIDSTREAM) != 0); }
index 5d3164aeeb27aa7d7349d5a4bf931e8abe147a85..7e9748d424bca1b12376e9d0e605126684d7b77b 100644 (file)
@@ -89,8 +89,10 @@ bool HostAttributesDescriptor::update_service
         if ( s.ipproto == protocol && (uint16_t)s.port == port )
         {
             if ( s.snort_protocol_id != snort_protocol_id )
+            {
+                s.snort_protocol_id = snort_protocol_id;
                 s.appid_service = is_appid_service;
-            s.snort_protocol_id = snort_protocol_id;
+            }
             updated = true;
             return true;
         }
@@ -119,19 +121,21 @@ void HostAttributesDescriptor::clear_appid_services()
     }
 }
 
-SnortProtocolId HostAttributesDescriptor::get_snort_protocol_id(int ipprotocol, uint16_t port) const
+void HostAttributesDescriptor::get_host_attributes(uint16_t port,HostAttriInfo* host_info) const
 {
-    std::lock_guard<std::mutex> lck(host_attributes_lock);
-
+    std::lock_guard<std::mutex> slk(host_attributes_lock);
+    host_info->frag_policy = policies.fragPolicy;
+    host_info->stream_policy = policies.streamPolicy;
+    host_info->snort_protocol_id = UNKNOWN_PROTOCOL_ID;
     for ( auto& s : services )
     {
-        if ( (s.ipproto == ipprotocol) && (s.port == port) )
-            return s.snort_protocol_id;
+        if ( s.port == port )
+        {
+            host_info->snort_protocol_id = s.snort_protocol_id;
+            return;
+        }
     }
-
-    return UNKNOWN_PROTOCOL_ID;
 }
-
 bool HostAttributesManager::load_hosts_file(snort::SnortConfig* sc, const char* fname)
 {
     delete next_cache;
@@ -188,12 +192,18 @@ void HostAttributesManager::swap_cleanup()
 void HostAttributesManager::term()
 { delete active_cache; }
 
-HostAttributesEntry HostAttributesManager::find_host(const snort::SfIp& host_ip)
+bool HostAttributesManager::get_host_attributes(const snort::SfIp& host_ip, uint16_t port, HostAttriInfo* host_info)
 {
-    if ( active_cache )
-        return active_cache->find(host_ip);
+    if ( !active_cache )
+        return false;
 
-    return nullptr;
+    HostAttributesEntry h = active_cache->find(host_ip);
+    if (h)
+    {
+        h->get_host_attributes(port, host_info);
+        return true;
+    }
+    return false;
 }
 
 void HostAttributesManager::update_service(const snort::SfIp& host_ip, uint16_t port,
@@ -207,7 +217,6 @@ void HostAttributesManager::update_service(const snort::SfIp& host_ip, uint16_t
         {
             if ( created )
             {
-                host->set_ip_addr(host_ip);
                 host_attribute_stats.dynamic_host_adds++;
             }
 
index e776fce4f4c7b021e13d1f74354ad8dc4eb24035..cb741f631c1401ebd5f3aa75590f1cd0787f040b 100644 (file)
@@ -77,6 +77,13 @@ struct HostPolicyDescriptor
     uint8_t fragPolicy = 0;
 };
 
+struct HostAttriInfo
+{
+    SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID;
+    uint8_t stream_policy = 0;
+    uint8_t frag_policy = 0;
+};
+
 class HostAttributesDescriptor
 {
 public:
@@ -86,32 +93,24 @@ public:
     bool update_service(uint16_t port, uint16_t protocol, SnortProtocolId, bool& updated,
         bool is_appid_service = false);
     void clear_appid_services();
-    SnortProtocolId get_snort_protocol_id(int ipprotocol, uint16_t port) const;
+    void get_host_attributes(uint16_t, HostAttriInfo*) const;
 
+    // Note: the following get/set are only called from main thread on a temp LRU table
     const snort::SfIp& get_ip_addr() const
     { return ip_address; }
 
     void set_ip_addr(const snort::SfIp& host_ip_addr)
     {
-        std::lock_guard<std::mutex> lck(host_attributes_lock);
         ip_address = host_ip_addr;
     }
 
-    uint8_t get_frag_policy() const
-    { return policies.fragPolicy; }
-
     void set_frag_policy(const uint8_t frag_policy)
     {
-        std::lock_guard<std::mutex> lck(host_attributes_lock);
         policies.fragPolicy = frag_policy;
     }
 
-    uint8_t get_stream_policy() const
-    { return policies.streamPolicy; }
-
     void set_stream_policy(uint8_t stream_policy)
     {
-        std::lock_guard<std::mutex> lck(host_attributes_lock);
         policies.streamPolicy = stream_policy;
     }
 
@@ -150,7 +149,7 @@ public:
     static void term();
 
     static bool add_host(HostAttributesEntry, snort::SnortConfig*);
-    static HostAttributesEntry find_host(const snort::SfIp&);
+    static bool get_host_attributes(const snort::SfIp&, uint16_t, HostAttriInfo*);
     static void update_service(const snort::SfIp&, uint16_t port, uint16_t protocol,
         SnortProtocolId, bool is_appid_service = false);
     static void clear_appid_services();