From: Mike Stepanek (mstepane) Date: Tue, 26 Feb 2019 18:12:52 +0000 (-0500) Subject: Merge pull request #1522 in SNORT/snort3 from ~SMINUT/snort3:appid_service_cache... X-Git-Tag: 3.0.0-251~36 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=7ddf6ddee5428456596bc2b073ccca273a480230;p=thirdparty%2Fsnort3.git Merge pull request #1522 in SNORT/snort3 from ~SMINUT/snort3:appid_service_cache to master Squashed commit of the following: commit 85be96aa1e48c63b2782c61f6d28bb15b11542c6 Author: Silviu Minut Date: Fri Feb 15 17:15:41 2019 -0500 appid: implement service discovery state queue to honor memcap. appid: the service queue should be of type AppIdServiceStateKey. appid: change the service queue to store map iterators rather than the actual keys, as (a) map iterators are stable and (b) sizeof(map::iterator)=8 while sizeof(key)=28. appid: compute the size of the memory used for a service cache entry only once, as it is constant, and make it global. appid: implement service cache touch(). Must figure out where to call it from. appid: fix double free in service_state_queue and address reviewers comments. appid: introduce min memcap of 1024 with a default of 1Mb and refactor AppIdServiceState::remove() to accept a ServiceCache_t::iterator rather than ip, proto, port and decrypted. appid: put the service_state_cache and the service_state_queue into a class in its own right and refactor the code. appid: unit test for service cache and call the touch function. appid: untabify service_state.h and test/service_state_test.cc. appid: remove forgotten WhereMacro. appid: introduce the do_touch flag to the add/get functions and call those functions with the appropriate flag. appid: update unit test file. --- diff --git a/src/network_inspectors/appid/appid_inspector.cc b/src/network_inspectors/appid/appid_inspector.cc index 1f9a2f203..808fd76c7 100644 --- a/src/network_inspectors/appid/appid_inspector.cc +++ b/src/network_inspectors/appid/appid_inspector.cc @@ -156,7 +156,7 @@ void AppIdInspector::tinit() AppIdStatistics::initialize_manager(*config); appid_forecast_tinit(); LuaDetectorManager::initialize(*active_config); - AppIdServiceState::initialize(); + AppIdServiceState::initialize(config->memcap); appidDebug = new AppIdDebug(); if (active_config->mod_config and active_config->mod_config->log_all_sessions) appidDebug->set_enabled(true); diff --git a/src/network_inspectors/appid/appid_module.cc b/src/network_inspectors/appid/appid_module.cc index 2ecad659f..7debbed92 100644 --- a/src/network_inspectors/appid/appid_module.cc +++ b/src/network_inspectors/appid/appid_module.cc @@ -60,8 +60,8 @@ static const Parameter s_params[] = { "first_decrypted_packet_debug", Parameter::PT_INT, "0:max32", "0", "the first packet of an already decrypted SSL flow (debug single session only)" }, #endif - { "memcap", Parameter::PT_INT, "0:maxSZ", "0", - "disregard - not implemented" }, // FIXIT-M implement or delete appid.memcap + { "memcap", Parameter::PT_INT, "1024:maxSZ", "1048576", + "max size of the service cache before we start pruning the cache" }, { "log_stats", Parameter::PT_BOOL, nullptr, "false", "enable logging of appid statistics" }, { "app_stats_period", Parameter::PT_INT, "1:max32", "300", diff --git a/src/network_inspectors/appid/service_plugins/service_discovery.cc b/src/network_inspectors/appid/service_plugins/service_discovery.cc index d19bac338..a5f57839e 100644 --- a/src/network_inspectors/appid/service_plugins/service_discovery.cc +++ b/src/network_inspectors/appid/service_plugins/service_discovery.cc @@ -454,7 +454,7 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, if ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::START ) { asd.service_search_state = SESSION_SERVICE_SEARCH_STATE::PORT; - sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted()); + sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true); sds->set_reset_time(0); SERVICE_ID_STATE sds_state = sds->get_state(); @@ -553,7 +553,7 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, !asd.service_detector and ( dir == APP_ID_FROM_RESPONDER ) ) ) { if (!sds) - sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted()); + sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true); // Don't log this if fail service is not due to empty list if (appidDebug->is_active() and !(got_fail_service and asd.service_detector)) LogMessage("AppIdDbg %s No service %s\n", appidDebug->get_debug_session(), @@ -574,7 +574,7 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, tmp_ip = p->ptrs.ip_api.get_src(); if (!sds) - sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted()); + sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true); if (got_incompatible_service) sds->update_service_incompatiable(tmp_ip); @@ -774,7 +774,7 @@ int ServiceDiscovery::incompatible_data(AppIdSession& asd, const Packet* pkt, Ap const SfIp* ip = pkt->ptrs.ip_api.get_src(); uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp; ServiceDiscoveryState* sds = AppIdServiceState::add(ip, asd.protocol, port, - asd.is_decrypted()); + asd.is_decrypted()); // do not touch here sds->set_service(service); sds->set_reset_time(0); if ( !asd.service_ip.is_set() ) @@ -835,5 +835,5 @@ int ServiceDiscovery::fail_service(AppIdSession& asd, const Packet* pkt, AppidSe void ServiceDiscovery::release_thread_resources() { for (auto detectors : service_detector_list) - detectors->release_thread_resources(); + detectors->release_thread_resources(); } diff --git a/src/network_inspectors/appid/service_state.cc b/src/network_inspectors/appid/service_state.cc index d08469468..a85cfe0d0 100644 --- a/src/network_inspectors/appid/service_state.cc +++ b/src/network_inspectors/appid/service_state.cc @@ -25,6 +25,7 @@ #include "service_state.h" +#include #include #include "log/messages.h" @@ -37,6 +38,12 @@ using namespace snort; +static THREAD_LOCAL MapList* service_state_cache = nullptr; + + +const size_t MapList::sz = sizeof(Val_t) + + sizeof(Map_t::value_type) + sizeof(Queue_t::value_type); + ServiceDiscoveryState::ServiceDiscoveryState() { state = SERVICE_ID_STATE::SEARCHING_PORT_PATTERN; @@ -186,129 +193,35 @@ void ServiceDiscoveryState::update_service_incompatiable(const SfIp* ip) } } -class AppIdServiceStateKey -{ -public: - AppIdServiceStateKey() - { - ip.clear(); - port = 0; - level = 0; - proto = IpProtocol::PROTO_NOT_SET; - padding[0] = padding[1] = padding[2] = 0; - } - - bool operator<(AppIdServiceStateKey right) const - { - if ( ip.less_than(right.ip) ) - return true; - else if ( right.ip.less_than(ip) ) - return false; - else - { - if ( port < right.port ) - return true; - else if ( right.port < port ) - return false; - else if ( proto < right.proto ) - return true; - else if ( right.proto < proto ) - return false; - else if ( level < right.level ) - return true; - else - return false; - } - } - - SfIp ip; - uint16_t port; - uint32_t level; - IpProtocol proto; - char padding[3]; -}; -static THREAD_LOCAL std::map* service_state_cache = - nullptr; - -void AppIdServiceState::initialize() +void AppIdServiceState::initialize(size_t memcap) { - service_state_cache = new std::map; + service_state_cache = new MapList(memcap); } void AppIdServiceState::clean() { - if ( service_state_cache ) - { - for ( auto& kv : *service_state_cache ) - delete kv.second; - - service_state_cache->clear(); - delete service_state_cache; - service_state_cache = nullptr; - } + delete service_state_cache; } ServiceDiscoveryState* AppIdServiceState::add(const SfIp* ip, IpProtocol proto, uint16_t port, - bool decrypted) + bool decrypted, bool do_touch) { - AppIdServiceStateKey ssk; - ServiceDiscoveryState* ss = nullptr; - - ssk.ip.set(*ip); - ssk.proto = proto; - ssk.port = port; - ssk.level = decrypted ? 1 : 0; - - std::map::iterator it; - it = service_state_cache->find(ssk); - if ( it == service_state_cache->end() ) - { - ss = new ServiceDiscoveryState; - (*service_state_cache)[ssk] = ss; - } - else - ss = it->second; - - return ss; + return service_state_cache->add( AppIdServiceStateKey(ip, proto, port, decrypted), do_touch ); } ServiceDiscoveryState* AppIdServiceState::get(const SfIp* ip, IpProtocol proto, uint16_t port, - bool decrypted) + bool decrypted, bool do_touch) { - AppIdServiceStateKey ssk; - ServiceDiscoveryState* ss = nullptr; - - ssk.ip.set(*ip); - ssk.proto = proto; - ssk.port = port; - ssk.level = decrypted ? 1 : 0; - - std::map::iterator it; - it = service_state_cache->find(ssk); - if ( it != service_state_cache->end() ) - ss = it->second; - - return ss; + return service_state_cache->get( AppIdServiceStateKey(ip, proto, port, decrypted), do_touch); } void AppIdServiceState::remove(const SfIp* ip, IpProtocol proto, uint16_t port, bool decrypted) { - AppIdServiceStateKey ssk; - - ssk.ip.set(*ip); - ssk.proto = proto; - ssk.port = port; - ssk.level = decrypted ? 1 : 0; + AppIdServiceStateKey ssk(ip, proto, port, decrypted); + Map_t::iterator it = service_state_cache->find(ssk); - std::map::iterator it; - it = service_state_cache->find(ssk); - if ( it != service_state_cache->end() ) - { - delete it->second; - service_state_cache->erase(it); - } - else + if ( !service_state_cache->remove(it) ) { char ipstr[INET6_ADDRSTRLEN]; @@ -354,4 +267,3 @@ void AppIdServiceState::dump_stats() } #endif } - diff --git a/src/network_inspectors/appid/service_state.h b/src/network_inspectors/appid/service_state.h index dc6c48f9c..40ee2ccff 100644 --- a/src/network_inspectors/appid/service_state.h +++ b/src/network_inspectors/appid/service_state.h @@ -22,6 +22,9 @@ #ifndef SERVICE_STATE_H #define SERVICE_STATE_H +#include +#include + #include "protocols/protocol_ids.h" #include "sfip/sf_ip.h" @@ -30,6 +33,16 @@ class ServiceDetector; +class AppIdServiceStateKey; +class ServiceDiscoveryState; + +typedef AppIdServiceStateKey Key_t; +typedef ServiceDiscoveryState Val_t; + +typedef std::map Map_t; +typedef std::list Queue_t; + + enum SERVICE_ID_STATE { SEARCHING_PORT_PATTERN = 0, @@ -110,6 +123,8 @@ public: reset_time = resetTime; } + Queue_t::iterator qptr; // Our place in service_state_queue + private: SERVICE_ID_STATE state; ServiceDetector* service = nullptr; @@ -133,15 +148,164 @@ private: class AppIdServiceState { public: - static void initialize(); + static void initialize(size_t memcap = 0); static void clean(); - static ServiceDiscoveryState* add(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted); - static ServiceDiscoveryState* get(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted); + static ServiceDiscoveryState* add(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted, bool do_touch = false); + static ServiceDiscoveryState* get(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted, bool do_touch = false); static void remove(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted); static void check_reset(AppIdSession& asd, const snort::SfIp* ip, uint16_t port); static void dump_stats(); }; -#endif +class AppIdServiceStateKey +{ +public: + AppIdServiceStateKey() + { + ip.clear(); + port = 0; + level = 0; + proto = IpProtocol::PROTO_NOT_SET; + padding[0] = padding[1] = padding[2] = 0; + } + + AppIdServiceStateKey(const snort::SfIp* ip_in, + IpProtocol proto_in, uint16_t port_in, bool decrypted) + { + ip.set(*ip_in); + port = port_in; + level = decrypted != 0; + proto = proto_in; + padding[0] = padding[1] = padding[2] = 0; + } + + bool operator<(AppIdServiceStateKey right) const + { + if ( ip.less_than(right.ip) ) + return true; + else if ( right.ip.less_than(ip) ) + return false; + else + { + if ( port < right.port ) + return true; + else if ( right.port < port ) + return false; + else if ( proto < right.proto ) + return true; + else if ( right.proto < proto ) + return false; + else if ( level < right.level ) + return true; + else + return false; + } + } + +private: + snort::SfIp ip; + uint16_t port; + uint32_t level; + IpProtocol proto; + char padding[3]; +}; + + +class MapList +{ +public: + + MapList(size_t cap) : memcap(cap), mem_used(0) {} + + ~MapList() + { + for ( auto& kv : m ) + delete kv.second; + } + + Val_t* add(const Key_t& k, bool do_touch = false) + { + Val_t* ss = nullptr; + + Map_t::iterator it = m.find(k); + if ( it == m.end() ) + { + // Prune the map to make room for the new sds if memcap is hit + if ( mem_used + sz > memcap ) + remove( q.front() ); + + ss = new Val_t; + + std::pair sit = m.emplace(std::make_pair(k,ss)); + q.emplace_back(sit.first); + mem_used += sz; + ss->qptr = --q.end(); // remember our place in the queue + } + else { + ss = it->second; + if ( do_touch ) + touch(ss->qptr); + } + + return ss; + } + + Val_t* get(const Key_t& k, bool do_touch = 0) + { + Map_t::const_iterator it = m.find(k); + if ( it != m.end() ) { + if ( do_touch ) + touch(it->second->qptr); + return it->second; + } + return nullptr; + } + + bool remove(Map_t::iterator it) + { + if ( it != m.end() ) + { + mem_used -= sz; + q.erase(it->second->qptr); // remove from queue + delete it->second; + m.erase(it); // then from cache + return true; + } + return false; + } + + Map_t::iterator find(const Key_t& k) + { + return m.find(k); + } + + void touch(Queue_t::iterator& qptr) + { + // If we don't already have the highest priority... + if ( *qptr != q.back() ) + { + q.emplace_back(*qptr); + q.erase(qptr); + qptr = --q.end(); + } + } + + size_t size() const { return m.size(); } + + Queue_t::iterator newest() { return --q.end(); } + Queue_t::iterator oldest() { return q.begin(); } + Queue_t::iterator end() { return q.end(); } + + // how much memory we add when we put an SDS in the cache: + static const size_t sz; + +private: + Map_t m; + Queue_t q; + size_t memcap; + size_t mem_used; +}; + +#endif diff --git a/src/network_inspectors/appid/test/appid_mock_definitions.h b/src/network_inspectors/appid/test/appid_mock_definitions.h index a6bc71120..6f2449701 100644 --- a/src/network_inspectors/appid/test/appid_mock_definitions.h +++ b/src/network_inspectors/appid/test/appid_mock_definitions.h @@ -78,12 +78,12 @@ bool AppInfoManager::configured() { return false; } // Stubs for service_state.h -ServiceDiscoveryState* AppIdServiceState::get(SfIp const*, IpProtocol, unsigned short, bool) +ServiceDiscoveryState* AppIdServiceState::get(SfIp const*, IpProtocol, unsigned short, bool, bool) { return nullptr; } -ServiceDiscoveryState* AppIdServiceState::add(SfIp const*, IpProtocol, unsigned short, bool) +ServiceDiscoveryState* AppIdServiceState::add(SfIp const*, IpProtocol, unsigned short, bool, bool) { return nullptr; } @@ -115,4 +115,3 @@ void mock_cleanup_appid_pegs() THREAD_LOCAL AppIdStats appid_stats; #endif - diff --git a/src/network_inspectors/appid/test/service_state_test.cc b/src/network_inspectors/appid/test/service_state_test.cc index 637e52e68..5790bb427 100644 --- a/src/network_inspectors/appid/test/service_state_test.cc +++ b/src/network_inspectors/appid/test/service_state_test.cc @@ -26,6 +26,8 @@ #include #include +#include + namespace snort { // Stubs for logs @@ -183,6 +185,44 @@ TEST(service_state_tests, set_service_id_failed_with_valid) CHECK_TRUE(sds.get_state() == SERVICE_ID_STATE::VALID); } +TEST(service_state_tests, service_cache) +{ + size_t num_entries = 10, max_entries = 3; + size_t memcap = max_entries*MapList::sz; + MapList ServiceCache(memcap); + + IpProtocol proto = IpProtocol::TCP; + uint16_t port = 3000; + SfIp ip; + ip.set("10.10.0.1"); + + Val_t* ss = nullptr; + std::vector ssvec; + + // Insert past the memcap, and check the memcap is not exceeded: + for( size_t i = 1; i <= num_entries; i++, port++ ) + { + ss = ServiceCache.add( Key_t(&ip, proto, port, 0) ); + CHECK_TRUE(ServiceCache.size() == ( i <= max_entries ? i : max_entries)); + ssvec.push_back(ss); + } + + // The cache should now be port 8, 9, 10 + Queue_t::iterator it = ServiceCache.newest(); + std::vector::iterator vit = --ssvec.end(); + for( size_t i=0; isecond == *vit ); + } + + // Now get an entry in the cache and check that it got touched: + port -= 1; + ss = ServiceCache.get( Key_t(&ip, proto, port, 0) ); + CHECK_TRUE( ss != nullptr ); + CHECK_TRUE( ss->qptr == ServiceCache.newest() ); +} + int main(int argc, char** argv) { int rc = CommandLineTestRunner::RunAllTests(argc, argv);