]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #1522 in SNORT/snort3 from ~SMINUT/snort3:appid_service_cache...
authorMike Stepanek (mstepane) <mstepane@cisco.com>
Tue, 26 Feb 2019 18:12:52 +0000 (13:12 -0500)
committerMike Stepanek (mstepane) <mstepane@cisco.com>
Tue, 26 Feb 2019 18:12:52 +0000 (13:12 -0500)
Squashed commit of the following:

commit 85be96aa1e48c63b2782c61f6d28bb15b11542c6
Author: Silviu Minut <sminut@cisco.com>
Date:   Fri Feb 15 17:15:41 2019 -0500

    appid: implement service discovery state queue to honor memcap.

    appid: the service queue should be of type AppIdServiceStateKey.

    appid: change the service queue to store map iterators rather than the actual keys, as (a) map iterators are stable and (b) sizeof(map::iterator)=8 while sizeof(key)=28.

    appid: compute the size of the memory used for a service cache entry only once, as it is constant, and make it global.

    appid: implement service cache touch(). Must figure out where to call it from.

    appid: fix double free in service_state_queue and address reviewers comments.

    appid: introduce min memcap of 1024 with a default of 1Mb and refactor AppIdServiceState::remove() to accept a ServiceCache_t::iterator rather than ip, proto, port and decrypted.

    appid: put the service_state_cache and the service_state_queue into a class in its own right and refactor the code.

    appid: unit test for service cache and call the touch function.

    appid: untabify service_state.h and test/service_state_test.cc.

    appid: remove forgotten WhereMacro.

    appid: introduce the do_touch flag to the add/get functions and call those functions with the appropriate flag.

    appid: update unit test file.

src/network_inspectors/appid/appid_inspector.cc
src/network_inspectors/appid/appid_module.cc
src/network_inspectors/appid/service_plugins/service_discovery.cc
src/network_inspectors/appid/service_state.cc
src/network_inspectors/appid/service_state.h
src/network_inspectors/appid/test/appid_mock_definitions.h
src/network_inspectors/appid/test/service_state_test.cc

index 1f9a2f203eb24dc04691d617304f7ad1f2402e2d..808fd76c70ebc29933a1f07062d6b398730f642e 100644 (file)
@@ -156,7 +156,7 @@ void AppIdInspector::tinit()
     AppIdStatistics::initialize_manager(*config);
     appid_forecast_tinit();
     LuaDetectorManager::initialize(*active_config);
-    AppIdServiceState::initialize();
+    AppIdServiceState::initialize(config->memcap);
     appidDebug = new AppIdDebug();
     if (active_config->mod_config and active_config->mod_config->log_all_sessions)
         appidDebug->set_enabled(true);
index 2ecad659f382d819a40df48e83fcbdcade727771..7debbed92ef18f35d3a5b0069b6ad18f26cc4004 100644 (file)
@@ -60,8 +60,8 @@ static const Parameter s_params[] =
     { "first_decrypted_packet_debug", Parameter::PT_INT, "0:max32", "0",
       "the first packet of an already decrypted SSL flow (debug single session only)" },
 #endif
-    { "memcap", Parameter::PT_INT, "0:maxSZ", "0",
-      "disregard - not implemented" },  // FIXIT-M implement or delete appid.memcap
+    { "memcap", Parameter::PT_INT, "1024:maxSZ", "1048576",
+      "max size of the service cache before we start pruning the cache" },
     { "log_stats", Parameter::PT_BOOL, nullptr, "false",
       "enable logging of appid statistics" },
     { "app_stats_period", Parameter::PT_INT, "1:max32", "300",
index d19bac3384a406e5c7b945422e331bae39b08563..a5f57839ebff1e0a4b6c5d78bc3b29da8524af1f 100644 (file)
@@ -454,7 +454,7 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p,
     if ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::START )
     {
         asd.service_search_state = SESSION_SERVICE_SEARCH_STATE::PORT;
-        sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+        sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true);
         sds->set_reset_time(0);
         SERVICE_ID_STATE sds_state = sds->get_state();
 
@@ -553,7 +553,7 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p,
          !asd.service_detector and ( dir == APP_ID_FROM_RESPONDER ) ) )
     {
         if (!sds)
-            sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+            sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true);
         // Don't log this if fail service is not due to empty list
         if (appidDebug->is_active() and !(got_fail_service and asd.service_detector))
             LogMessage("AppIdDbg %s No service %s\n", appidDebug->get_debug_session(),
@@ -574,7 +574,7 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p,
             tmp_ip = p->ptrs.ip_api.get_src();
 
         if (!sds)
-            sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+            sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true);
 
         if (got_incompatible_service)
             sds->update_service_incompatiable(tmp_ip);
@@ -774,7 +774,7 @@ int ServiceDiscovery::incompatible_data(AppIdSession& asd, const Packet* pkt, Ap
     const SfIp* ip = pkt->ptrs.ip_api.get_src();
     uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp;
     ServiceDiscoveryState* sds = AppIdServiceState::add(ip, asd.protocol, port,
-        asd.is_decrypted());
+        asd.is_decrypted());     // do not touch here
     sds->set_service(service);
     sds->set_reset_time(0);
     if ( !asd.service_ip.is_set() )
@@ -835,5 +835,5 @@ int ServiceDiscovery::fail_service(AppIdSession& asd, const Packet* pkt, AppidSe
 void ServiceDiscovery::release_thread_resources()
 {
     for (auto detectors : service_detector_list)
-           detectors->release_thread_resources();
+        detectors->release_thread_resources();
 }
index d08469468308ff34a7d3ebe477793dea320eba69..a85cfe0d07512f7bf7935b72456ee1f9dce67856 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "service_state.h"
 
+#include <list>
 #include <map>
 
 #include "log/messages.h"
 
 using namespace snort;
 
+static THREAD_LOCAL MapList* service_state_cache = nullptr;
+
+
+const size_t MapList::sz = sizeof(Val_t) +
+    sizeof(Map_t::value_type) + sizeof(Queue_t::value_type);
+
 ServiceDiscoveryState::ServiceDiscoveryState()
 {
     state = SERVICE_ID_STATE::SEARCHING_PORT_PATTERN;
@@ -186,129 +193,35 @@ void ServiceDiscoveryState::update_service_incompatiable(const SfIp* ip)
     }
 }
 
-class AppIdServiceStateKey
-{
-public:
-    AppIdServiceStateKey()
-    {
-        ip.clear();
-        port = 0;
-        level = 0;
-        proto = IpProtocol::PROTO_NOT_SET;
-        padding[0] = padding[1] = padding[2] = 0;
-    }
-
-    bool operator<(AppIdServiceStateKey right) const
-    {
-        if ( ip.less_than(right.ip) )
-            return true;
-        else if ( right.ip.less_than(ip) )
-            return false;
-        else
-        {
-            if ( port < right.port )
-                return true;
-            else if ( right.port < port )
-                return false;
-            else if ( proto < right.proto )
-                return true;
-            else if ( right.proto < proto )
-                return false;
-            else if ( level < right.level )
-                return true;
-            else
-                return false;
-        }
-    }
-
-    SfIp ip;
-    uint16_t port;
-    uint32_t level;
-    IpProtocol proto;
-    char padding[3];
-};
 
-static THREAD_LOCAL std::map<AppIdServiceStateKey, ServiceDiscoveryState*>* service_state_cache =
-    nullptr;
-
-void AppIdServiceState::initialize()
+void AppIdServiceState::initialize(size_t memcap)
 {
-    service_state_cache = new std::map<AppIdServiceStateKey, ServiceDiscoveryState*>;
+    service_state_cache = new MapList(memcap);
 }
 
 void AppIdServiceState::clean()
 {
-    if ( service_state_cache )
-    {
-        for ( auto& kv : *service_state_cache )
-            delete kv.second;
-
-        service_state_cache->clear();
-        delete service_state_cache;
-        service_state_cache = nullptr;
-    }
+    delete service_state_cache;
 }
 
 ServiceDiscoveryState* AppIdServiceState::add(const SfIp* ip, IpProtocol proto, uint16_t port,
-    bool decrypted)
+    bool decrypted, bool do_touch)
 {
-    AppIdServiceStateKey ssk;
-    ServiceDiscoveryState* ss = nullptr;
-
-    ssk.ip.set(*ip);
-    ssk.proto = proto;
-    ssk.port = port;
-    ssk.level = decrypted ? 1 : 0;
-
-    std::map<AppIdServiceStateKey, ServiceDiscoveryState*>::iterator it;
-    it = service_state_cache->find(ssk);
-    if ( it == service_state_cache->end() )
-    {
-        ss = new ServiceDiscoveryState;
-        (*service_state_cache)[ssk] = ss;
-    }
-    else
-        ss = it->second;
-
-    return ss;
+    return service_state_cache->add( AppIdServiceStateKey(ip, proto, port, decrypted), do_touch );
 }
 
 ServiceDiscoveryState* AppIdServiceState::get(const SfIp* ip, IpProtocol proto, uint16_t port,
-    bool decrypted)
+    bool decrypted, bool do_touch)
 {
-    AppIdServiceStateKey ssk;
-    ServiceDiscoveryState* ss = nullptr;
-
-    ssk.ip.set(*ip);
-    ssk.proto = proto;
-    ssk.port = port;
-    ssk.level = decrypted ? 1 : 0;
-
-    std::map<AppIdServiceStateKey, ServiceDiscoveryState*>::iterator it;
-    it = service_state_cache->find(ssk);
-    if ( it != service_state_cache->end() )
-        ss = it->second;
-
-    return ss;
+    return service_state_cache->get( AppIdServiceStateKey(ip, proto, port, decrypted), do_touch);
 }
 
 void AppIdServiceState::remove(const SfIp* ip, IpProtocol proto, uint16_t port, bool decrypted)
 {
-    AppIdServiceStateKey ssk;
-
-    ssk.ip.set(*ip);
-    ssk.proto = proto;
-    ssk.port = port;
-    ssk.level = decrypted ? 1 : 0;
+    AppIdServiceStateKey ssk(ip, proto, port, decrypted);
+    Map_t::iterator it = service_state_cache->find(ssk);
 
-    std::map<AppIdServiceStateKey, ServiceDiscoveryState*>::iterator it;
-    it = service_state_cache->find(ssk);
-    if ( it != service_state_cache->end() )
-    {
-        delete it->second;
-        service_state_cache->erase(it);
-    }
-    else
+    if ( !service_state_cache->remove(it) )
     {
         char ipstr[INET6_ADDRSTRLEN];
 
@@ -354,4 +267,3 @@ void AppIdServiceState::dump_stats()
     }
 #endif
 }
-
index dc6c48f9ca68e24a4b98da1b94e7c6f17ca21f6a..40ee2ccfff9f943eb74c8e81f4a527c91d12c546 100644 (file)
@@ -22,6 +22,9 @@
 #ifndef SERVICE_STATE_H
 #define SERVICE_STATE_H
 
+#include <list>
+#include <map>
+
 #include "protocols/protocol_ids.h"
 #include "sfip/sf_ip.h"
 
 
 class ServiceDetector;
 
+class AppIdServiceStateKey;
+class ServiceDiscoveryState;
+
+typedef AppIdServiceStateKey Key_t;
+typedef ServiceDiscoveryState Val_t;
+
+typedef std::map<Key_t, Val_t*> Map_t;
+typedef std::list<Map_t::iterator> Queue_t;
+
+
 enum SERVICE_ID_STATE
 {
     SEARCHING_PORT_PATTERN = 0,
@@ -110,6 +123,8 @@ public:
         reset_time = resetTime;
     }
 
+    Queue_t::iterator qptr; // Our place in service_state_queue
+
 private:
     SERVICE_ID_STATE state;
     ServiceDetector* service = nullptr;
@@ -133,15 +148,164 @@ private:
 class AppIdServiceState
 {
 public:
-    static void initialize();
+    static void initialize(size_t memcap = 0);
     static void clean();
-    static ServiceDiscoveryState* add(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted);
-    static ServiceDiscoveryState* get(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted);
+    static ServiceDiscoveryState* add(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted, bool do_touch = false);
+    static ServiceDiscoveryState* get(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted, bool do_touch = false);
     static void remove(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted);
     static void check_reset(AppIdSession& asd, const snort::SfIp* ip, uint16_t port);
 
     static void dump_stats();
 };
 
-#endif
 
+class AppIdServiceStateKey
+{
+public:
+    AppIdServiceStateKey()
+    {
+        ip.clear();
+        port = 0;
+        level = 0;
+        proto = IpProtocol::PROTO_NOT_SET;
+        padding[0] = padding[1] = padding[2] = 0;
+    }
+
+    AppIdServiceStateKey(const snort::SfIp* ip_in,
+        IpProtocol proto_in, uint16_t port_in, bool decrypted)
+    {
+        ip.set(*ip_in);
+        port = port_in;
+        level = decrypted != 0;
+        proto = proto_in;
+        padding[0] = padding[1] = padding[2] = 0;
+    }
+
+    bool operator<(AppIdServiceStateKey right) const
+    {
+        if ( ip.less_than(right.ip) )
+            return true;
+        else if ( right.ip.less_than(ip) )
+            return false;
+        else
+        {
+            if ( port < right.port )
+                return true;
+            else if ( right.port < port )
+                return false;
+            else if ( proto < right.proto )
+                return true;
+            else if ( right.proto < proto )
+                return false;
+            else if ( level < right.level )
+                return true;
+            else
+                return false;
+        }
+    }
+
+private:
+    snort::SfIp ip;
+    uint16_t port;
+    uint32_t level;
+    IpProtocol proto;
+    char padding[3];
+};
+
+
+class MapList
+{
+public:
+
+    MapList(size_t cap) : memcap(cap), mem_used(0) {}
+
+    ~MapList()
+    {
+        for ( auto& kv : m )
+            delete kv.second;
+    }
+
+    Val_t* add(const Key_t& k, bool do_touch = false)
+    {
+        Val_t* ss = nullptr;
+
+        Map_t::iterator it = m.find(k);
+        if ( it == m.end() )
+        {
+            // Prune the map to make room for the new sds if memcap is hit
+            if ( mem_used + sz > memcap )
+                remove( q.front() );
+
+            ss = new Val_t;
+
+            std::pair<Map_t::iterator, bool> sit = m.emplace(std::make_pair(k,ss));
+            q.emplace_back(sit.first);
+            mem_used += sz;
+            ss->qptr = --q.end(); // remember our place in the queue
+        }
+        else {
+            ss = it->second;
+            if ( do_touch )
+                touch(ss->qptr);
+        }
+        
+        return ss;
+    }
+
+    Val_t* get(const Key_t& k, bool do_touch = 0)
+    {
+        Map_t::const_iterator it = m.find(k);
+        if ( it != m.end() ) {
+            if ( do_touch )
+                touch(it->second->qptr);
+            return it->second;
+        }
+        return nullptr;
+    }
+
+    bool remove(Map_t::iterator it)
+    {
+        if ( it != m.end() )
+        {
+            mem_used -= sz;
+            q.erase(it->second->qptr);  // remove from queue
+            delete it->second;
+            m.erase(it);                // then from cache
+            return true;
+        }
+        return false;
+    }
+
+    Map_t::iterator find(const Key_t& k)
+    {
+        return m.find(k);
+    }
+
+    void touch(Queue_t::iterator& qptr)
+    {
+        // If we don't already have the highest priority...
+        if ( *qptr != q.back() )
+        {
+            q.emplace_back(*qptr);
+            q.erase(qptr);
+            qptr = --q.end();
+        }
+    }
+
+    size_t size() const { return m.size(); }
+
+    Queue_t::iterator newest() { return --q.end(); }
+    Queue_t::iterator oldest() { return q.begin(); }
+    Queue_t::iterator end() { return q.end(); }
+
+    // how much memory we add when we put an SDS in the cache:
+    static const size_t sz;
+
+private:
+    Map_t m;
+    Queue_t q;
+    size_t memcap;
+    size_t mem_used;
+};
+
+#endif
index a6bc711200f3092de9b9e1e4fa1449f770d412b6..6f24497016c7fbe512a762c4b7285ac67fdaaa81 100644 (file)
@@ -78,12 +78,12 @@ bool AppInfoManager::configured()
 { return false; }
 
 // Stubs for service_state.h
-ServiceDiscoveryState* AppIdServiceState::get(SfIp const*, IpProtocol, unsigned short, bool)
+ServiceDiscoveryState* AppIdServiceState::get(SfIp const*, IpProtocol, unsigned short, bool, bool)
 {
   return nullptr;
 }
 
-ServiceDiscoveryState* AppIdServiceState::add(SfIp const*, IpProtocol, unsigned short, bool)
+ServiceDiscoveryState* AppIdServiceState::add(SfIp const*, IpProtocol, unsigned short, bool, bool)
 {
   return nullptr;
 }
@@ -115,4 +115,3 @@ void mock_cleanup_appid_pegs()
 THREAD_LOCAL AppIdStats appid_stats;
 
 #endif
-
index 637e52e689b471eb6056ba3e0a8c91ec53a685c4..5790bb427d613e6ea2c5eab17a6f3f13952785ac 100644 (file)
@@ -26,6 +26,8 @@
 #include <CppUTest/CommandLineTestRunner.h>
 #include <CppUTest/TestHarness.h>
 
+#include <vector>
+
 namespace snort
 {
 // Stubs for logs
@@ -183,6 +185,44 @@ TEST(service_state_tests, set_service_id_failed_with_valid)
     CHECK_TRUE(sds.get_state() == SERVICE_ID_STATE::VALID);
 }
 
+TEST(service_state_tests, service_cache)
+{
+    size_t num_entries = 10, max_entries = 3;
+    size_t memcap = max_entries*MapList::sz;
+    MapList ServiceCache(memcap);
+
+    IpProtocol proto = IpProtocol::TCP;
+    uint16_t port = 3000;
+    SfIp ip;
+    ip.set("10.10.0.1");
+
+    Val_t* ss = nullptr;
+    std::vector<Val_t*> ssvec;
+        
+    // Insert past the memcap, and check the memcap is not exceeded:
+    for( size_t i = 1; i <= num_entries; i++, port++ )
+    {
+        ss = ServiceCache.add( Key_t(&ip, proto, port, 0) );
+        CHECK_TRUE(ServiceCache.size() == ( i <= max_entries ? i : max_entries));
+        ssvec.push_back(ss);
+    }
+
+    // The cache should now be port 8, 9, 10
+    Queue_t::iterator it = ServiceCache.newest();
+    std::vector<Val_t*>::iterator vit = --ssvec.end();
+    for( size_t i=0; i<max_entries; i++, --it, --vit )
+    {
+        Map_t::iterator mit = *it;
+        CHECK_TRUE( mit->second == *vit );
+    }
+        
+    // Now get an entry in the cache and check that it got touched:
+    port -= 1;
+    ss = ServiceCache.get( Key_t(&ip, proto, port, 0) );
+    CHECK_TRUE( ss != nullptr );
+    CHECK_TRUE( ss->qptr == ServiceCache.newest() );
+}
+
 int main(int argc, char** argv)
 {
     int rc = CommandLineTestRunner::RunAllTests(argc, argv);