AppIdStatistics::initialize_manager(*config);
appid_forecast_tinit();
LuaDetectorManager::initialize(*active_config);
- AppIdServiceState::initialize();
+ AppIdServiceState::initialize(config->memcap);
appidDebug = new AppIdDebug();
if (active_config->mod_config and active_config->mod_config->log_all_sessions)
appidDebug->set_enabled(true);
{ "first_decrypted_packet_debug", Parameter::PT_INT, "0:max32", "0",
"the first packet of an already decrypted SSL flow (debug single session only)" },
#endif
- { "memcap", Parameter::PT_INT, "0:maxSZ", "0",
- "disregard - not implemented" }, // FIXIT-M implement or delete appid.memcap
+ { "memcap", Parameter::PT_INT, "1024:maxSZ", "1048576",
+ "max size of the service cache before we start pruning the cache" },
{ "log_stats", Parameter::PT_BOOL, nullptr, "false",
"enable logging of appid statistics" },
{ "app_stats_period", Parameter::PT_INT, "1:max32", "300",
if ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::START )
{
asd.service_search_state = SESSION_SERVICE_SEARCH_STATE::PORT;
- sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+ sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true);
sds->set_reset_time(0);
SERVICE_ID_STATE sds_state = sds->get_state();
!asd.service_detector and ( dir == APP_ID_FROM_RESPONDER ) ) )
{
if (!sds)
- sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+ sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true);
// Don't log this if fail service is not due to empty list
if (appidDebug->is_active() and !(got_fail_service and asd.service_detector))
LogMessage("AppIdDbg %s No service %s\n", appidDebug->get_debug_session(),
tmp_ip = p->ptrs.ip_api.get_src();
if (!sds)
- sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+ sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted(), true);
if (got_incompatible_service)
sds->update_service_incompatiable(tmp_ip);
const SfIp* ip = pkt->ptrs.ip_api.get_src();
uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp;
ServiceDiscoveryState* sds = AppIdServiceState::add(ip, asd.protocol, port,
- asd.is_decrypted());
+ asd.is_decrypted()); // do not touch here
sds->set_service(service);
sds->set_reset_time(0);
if ( !asd.service_ip.is_set() )
void ServiceDiscovery::release_thread_resources()
{
for (auto detectors : service_detector_list)
- detectors->release_thread_resources();
+ detectors->release_thread_resources();
}
#include "service_state.h"
+#include <list>
#include <map>
#include "log/messages.h"
using namespace snort;
+static THREAD_LOCAL MapList* service_state_cache = nullptr;
+
+
+const size_t MapList::sz = sizeof(Val_t) +
+ sizeof(Map_t::value_type) + sizeof(Queue_t::value_type);
+
ServiceDiscoveryState::ServiceDiscoveryState()
{
state = SERVICE_ID_STATE::SEARCHING_PORT_PATTERN;
}
}
-class AppIdServiceStateKey
-{
-public:
- AppIdServiceStateKey()
- {
- ip.clear();
- port = 0;
- level = 0;
- proto = IpProtocol::PROTO_NOT_SET;
- padding[0] = padding[1] = padding[2] = 0;
- }
-
- bool operator<(AppIdServiceStateKey right) const
- {
- if ( ip.less_than(right.ip) )
- return true;
- else if ( right.ip.less_than(ip) )
- return false;
- else
- {
- if ( port < right.port )
- return true;
- else if ( right.port < port )
- return false;
- else if ( proto < right.proto )
- return true;
- else if ( right.proto < proto )
- return false;
- else if ( level < right.level )
- return true;
- else
- return false;
- }
- }
-
- SfIp ip;
- uint16_t port;
- uint32_t level;
- IpProtocol proto;
- char padding[3];
-};
-static THREAD_LOCAL std::map<AppIdServiceStateKey, ServiceDiscoveryState*>* service_state_cache =
- nullptr;
-
-void AppIdServiceState::initialize()
+void AppIdServiceState::initialize(size_t memcap)
{
- service_state_cache = new std::map<AppIdServiceStateKey, ServiceDiscoveryState*>;
+ service_state_cache = new MapList(memcap);
}
void AppIdServiceState::clean()
{
- if ( service_state_cache )
- {
- for ( auto& kv : *service_state_cache )
- delete kv.second;
-
- service_state_cache->clear();
- delete service_state_cache;
- service_state_cache = nullptr;
- }
+ delete service_state_cache;
}
ServiceDiscoveryState* AppIdServiceState::add(const SfIp* ip, IpProtocol proto, uint16_t port,
- bool decrypted)
+ bool decrypted, bool do_touch)
{
- AppIdServiceStateKey ssk;
- ServiceDiscoveryState* ss = nullptr;
-
- ssk.ip.set(*ip);
- ssk.proto = proto;
- ssk.port = port;
- ssk.level = decrypted ? 1 : 0;
-
- std::map<AppIdServiceStateKey, ServiceDiscoveryState*>::iterator it;
- it = service_state_cache->find(ssk);
- if ( it == service_state_cache->end() )
- {
- ss = new ServiceDiscoveryState;
- (*service_state_cache)[ssk] = ss;
- }
- else
- ss = it->second;
-
- return ss;
+ return service_state_cache->add( AppIdServiceStateKey(ip, proto, port, decrypted), do_touch );
}
ServiceDiscoveryState* AppIdServiceState::get(const SfIp* ip, IpProtocol proto, uint16_t port,
- bool decrypted)
+ bool decrypted, bool do_touch)
{
- AppIdServiceStateKey ssk;
- ServiceDiscoveryState* ss = nullptr;
-
- ssk.ip.set(*ip);
- ssk.proto = proto;
- ssk.port = port;
- ssk.level = decrypted ? 1 : 0;
-
- std::map<AppIdServiceStateKey, ServiceDiscoveryState*>::iterator it;
- it = service_state_cache->find(ssk);
- if ( it != service_state_cache->end() )
- ss = it->second;
-
- return ss;
+ return service_state_cache->get( AppIdServiceStateKey(ip, proto, port, decrypted), do_touch);
}
void AppIdServiceState::remove(const SfIp* ip, IpProtocol proto, uint16_t port, bool decrypted)
{
- AppIdServiceStateKey ssk;
-
- ssk.ip.set(*ip);
- ssk.proto = proto;
- ssk.port = port;
- ssk.level = decrypted ? 1 : 0;
+ AppIdServiceStateKey ssk(ip, proto, port, decrypted);
+ Map_t::iterator it = service_state_cache->find(ssk);
- std::map<AppIdServiceStateKey, ServiceDiscoveryState*>::iterator it;
- it = service_state_cache->find(ssk);
- if ( it != service_state_cache->end() )
- {
- delete it->second;
- service_state_cache->erase(it);
- }
- else
+ if ( !service_state_cache->remove(it) )
{
char ipstr[INET6_ADDRSTRLEN];
}
#endif
}
-
#ifndef SERVICE_STATE_H
#define SERVICE_STATE_H
+#include <list>
+#include <map>
+
#include "protocols/protocol_ids.h"
#include "sfip/sf_ip.h"
class ServiceDetector;
+class AppIdServiceStateKey;
+class ServiceDiscoveryState;
+
+typedef AppIdServiceStateKey Key_t;
+typedef ServiceDiscoveryState Val_t;
+
+typedef std::map<Key_t, Val_t*> Map_t;
+typedef std::list<Map_t::iterator> Queue_t;
+
+
enum SERVICE_ID_STATE
{
SEARCHING_PORT_PATTERN = 0,
reset_time = resetTime;
}
+ Queue_t::iterator qptr; // Our place in service_state_queue
+
private:
SERVICE_ID_STATE state;
ServiceDetector* service = nullptr;
class AppIdServiceState
{
public:
- static void initialize();
+ static void initialize(size_t memcap = 0);
static void clean();
- static ServiceDiscoveryState* add(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted);
- static ServiceDiscoveryState* get(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted);
+ static ServiceDiscoveryState* add(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted, bool do_touch = false);
+ static ServiceDiscoveryState* get(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted, bool do_touch = false);
static void remove(const snort::SfIp*, IpProtocol, uint16_t port, bool decrypted);
static void check_reset(AppIdSession& asd, const snort::SfIp* ip, uint16_t port);
static void dump_stats();
};
-#endif
+class AppIdServiceStateKey
+{
+public:
+ AppIdServiceStateKey()
+ {
+ ip.clear();
+ port = 0;
+ level = 0;
+ proto = IpProtocol::PROTO_NOT_SET;
+ padding[0] = padding[1] = padding[2] = 0;
+ }
+
+ AppIdServiceStateKey(const snort::SfIp* ip_in,
+ IpProtocol proto_in, uint16_t port_in, bool decrypted)
+ {
+ ip.set(*ip_in);
+ port = port_in;
+ level = decrypted != 0;
+ proto = proto_in;
+ padding[0] = padding[1] = padding[2] = 0;
+ }
+
+ bool operator<(AppIdServiceStateKey right) const
+ {
+ if ( ip.less_than(right.ip) )
+ return true;
+ else if ( right.ip.less_than(ip) )
+ return false;
+ else
+ {
+ if ( port < right.port )
+ return true;
+ else if ( right.port < port )
+ return false;
+ else if ( proto < right.proto )
+ return true;
+ else if ( right.proto < proto )
+ return false;
+ else if ( level < right.level )
+ return true;
+ else
+ return false;
+ }
+ }
+
+private:
+ snort::SfIp ip;
+ uint16_t port;
+ uint32_t level;
+ IpProtocol proto;
+ char padding[3];
+};
+
+
+class MapList
+{
+public:
+
+ MapList(size_t cap) : memcap(cap), mem_used(0) {}
+
+ ~MapList()
+ {
+ for ( auto& kv : m )
+ delete kv.second;
+ }
+
+ Val_t* add(const Key_t& k, bool do_touch = false)
+ {
+ Val_t* ss = nullptr;
+
+ Map_t::iterator it = m.find(k);
+ if ( it == m.end() )
+ {
+ // Prune the map to make room for the new sds if memcap is hit
+ if ( mem_used + sz > memcap )
+ remove( q.front() );
+
+ ss = new Val_t;
+
+ std::pair<Map_t::iterator, bool> sit = m.emplace(std::make_pair(k,ss));
+ q.emplace_back(sit.first);
+ mem_used += sz;
+ ss->qptr = --q.end(); // remember our place in the queue
+ }
+ else {
+ ss = it->second;
+ if ( do_touch )
+ touch(ss->qptr);
+ }
+
+ return ss;
+ }
+
+ Val_t* get(const Key_t& k, bool do_touch = 0)
+ {
+ Map_t::const_iterator it = m.find(k);
+ if ( it != m.end() ) {
+ if ( do_touch )
+ touch(it->second->qptr);
+ return it->second;
+ }
+ return nullptr;
+ }
+
+ bool remove(Map_t::iterator it)
+ {
+ if ( it != m.end() )
+ {
+ mem_used -= sz;
+ q.erase(it->second->qptr); // remove from queue
+ delete it->second;
+ m.erase(it); // then from cache
+ return true;
+ }
+ return false;
+ }
+
+ Map_t::iterator find(const Key_t& k)
+ {
+ return m.find(k);
+ }
+
+ void touch(Queue_t::iterator& qptr)
+ {
+ // If we don't already have the highest priority...
+ if ( *qptr != q.back() )
+ {
+ q.emplace_back(*qptr);
+ q.erase(qptr);
+ qptr = --q.end();
+ }
+ }
+
+ size_t size() const { return m.size(); }
+
+ Queue_t::iterator newest() { return --q.end(); }
+ Queue_t::iterator oldest() { return q.begin(); }
+ Queue_t::iterator end() { return q.end(); }
+
+ // how much memory we add when we put an SDS in the cache:
+ static const size_t sz;
+
+private:
+ Map_t m;
+ Queue_t q;
+ size_t memcap;
+ size_t mem_used;
+};
+
+#endif
{ return false; }
// Stubs for service_state.h
-ServiceDiscoveryState* AppIdServiceState::get(SfIp const*, IpProtocol, unsigned short, bool)
+ServiceDiscoveryState* AppIdServiceState::get(SfIp const*, IpProtocol, unsigned short, bool, bool)
{
return nullptr;
}
-ServiceDiscoveryState* AppIdServiceState::add(SfIp const*, IpProtocol, unsigned short, bool)
+ServiceDiscoveryState* AppIdServiceState::add(SfIp const*, IpProtocol, unsigned short, bool, bool)
{
return nullptr;
}
THREAD_LOCAL AppIdStats appid_stats;
#endif
-
#include <CppUTest/CommandLineTestRunner.h>
#include <CppUTest/TestHarness.h>
+#include <vector>
+
namespace snort
{
// Stubs for logs
CHECK_TRUE(sds.get_state() == SERVICE_ID_STATE::VALID);
}
+TEST(service_state_tests, service_cache)
+{
+ size_t num_entries = 10, max_entries = 3;
+ size_t memcap = max_entries*MapList::sz;
+ MapList ServiceCache(memcap);
+
+ IpProtocol proto = IpProtocol::TCP;
+ uint16_t port = 3000;
+ SfIp ip;
+ ip.set("10.10.0.1");
+
+ Val_t* ss = nullptr;
+ std::vector<Val_t*> ssvec;
+
+ // Insert past the memcap, and check the memcap is not exceeded:
+ for( size_t i = 1; i <= num_entries; i++, port++ )
+ {
+ ss = ServiceCache.add( Key_t(&ip, proto, port, 0) );
+ CHECK_TRUE(ServiceCache.size() == ( i <= max_entries ? i : max_entries));
+ ssvec.push_back(ss);
+ }
+
+ // The cache should now be port 8, 9, 10
+ Queue_t::iterator it = ServiceCache.newest();
+ std::vector<Val_t*>::iterator vit = --ssvec.end();
+ for( size_t i=0; i<max_entries; i++, --it, --vit )
+ {
+ Map_t::iterator mit = *it;
+ CHECK_TRUE( mit->second == *vit );
+ }
+
+ // Now get an entry in the cache and check that it got touched:
+ port -= 1;
+ ss = ServiceCache.get( Key_t(&ip, proto, port, 0) );
+ CHECK_TRUE( ss != nullptr );
+ CHECK_TRUE( ss->qptr == ServiceCache.newest() );
+}
+
int main(int argc, char** argv)
{
int rc = CommandLineTestRunner::RunAllTests(argc, argv);