From: Mike Stepanek (mstepane) Date: Mon, 29 Jul 2019 16:35:43 +0000 (-0400) Subject: Merge pull request #1691 in SNORT/snort3 from ~MASHASAN/snort3:refactor_host_cache... X-Git-Tag: 3.0.0-259~27 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=03a828841eab46148672bbe3a10a1a1576d0d99c;p=thirdparty%2Fsnort3.git Merge pull request #1691 in SNORT/snort3 from ~MASHASAN/snort3:refactor_host_cache to master Squashed commit of the following: commit 8226ab4c2662a508d291efb2527777364cbaac6b Author: Masud Hasan Date: Fri Jul 26 01:44:02 2019 -0400 host_cache: Refactoring code to fix multithreading issues and to remove redundancy --- diff --git a/src/hash/lru_cache_shared.cc b/src/hash/lru_cache_shared.cc index f458dee72..d56b1e832 100644 --- a/src/hash/lru_cache_shared.cc +++ b/src/hash/lru_cache_shared.cc @@ -27,12 +27,9 @@ const PegInfo lru_cache_shared_peg_names[] = { { CountType::SUM, "lru_cache_adds", "lru cache added new entry" }, - { CountType::SUM, "lru_cache_replaces", "lru cache replaced existing entry" }, { CountType::SUM, "lru_cache_prunes", "lru cache pruned entry to make space for new entry" }, { CountType::SUM, "lru_cache_find_hits", "lru cache found entry in cache" }, { CountType::SUM, "lru_cache_find_misses", "lru cache did not find entry in cache" }, - { CountType::SUM, "lru_cache_removes", "lru cache found entry and removed it" }, - { CountType::SUM, "lru_cache_clears", "lru cache clear API calls" }, { CountType::END, nullptr, nullptr }, }; diff --git a/src/hash/lru_cache_shared.h b/src/hash/lru_cache_shared.h index 097855616..6f2503e4a 100644 --- a/src/hash/lru_cache_shared.h +++ b/src/hash/lru_cache_shared.h @@ -25,9 +25,11 @@ // least-recently-used (LRU) entries are removed once a fixed size is hit. #include -#include -#include +#include #include +#include +#include +#include #include "framework/counts.h" @@ -36,16 +38,13 @@ extern const PegInfo lru_cache_shared_peg_names[]; struct LruCacheSharedStats { PegCount adds = 0; // An insert that added new entry. - PegCount replaces = 0; // An insert that replaced existing entry PegCount prunes = 0; // When an old entry is removed to make // room for a new entry. PegCount find_hits = 0; // Found entry in cache. PegCount find_misses = 0; // Did not find entry in cache. - PegCount removes = 0; // Found entry and removed it. - PegCount clears = 0; // Calls to clear API. }; -template +template class LruCacheShared { public: @@ -57,58 +56,29 @@ public: LruCacheShared& operator=(const LruCacheShared& arg) = delete; LruCacheShared(const size_t initial_size) : - max_size(initial_size), - current_size(0) - { - } + max_size(initial_size), current_size(0) { } - // Get current number of elements in the LruCache. - size_t size() - { - std::lock_guard cache_lock(cache_mutex); - return current_size; - } + using Data = std::shared_ptr; - size_t get_max_size() - { - std::lock_guard cache_lock(cache_mutex); - return max_size; - } + // Return data entry associated with key. If doesn't exist, return nullptr. + Data find(const Key& key); + + // Return data entry associated with key. If doesn't exist, create a new entry. + Data operator[](const Key& key); + + // Return all data from the LruCache in order (most recently used to least) + std::vector > get_all_data(); // Modify the maximum number of entries allowed in the cache. // If the size is reduced, the oldest entries are removed. bool set_max_size(size_t newsize); - // Add data to cache or replace data if it already exists. - void insert(const Key& key, const Data& data); - - // Find Data associated with Key. If update is true, mark entry as - // recently used. - // Returns true and copies data if the key is found. - bool find(const Key& key, Data& data, bool update=true); - - // Remove entry associated with Key. - // Returns true if entry existed, false otherwise. - bool remove(const Key& key); - - // Remove entry associated with key and return removed data. - // Returns true and copy of data if entry existed. Returns false if - // entry did not exist. - bool remove(const Key& key, Data& data); - - // Remove all elements from the LruCache - void clear(); - - // Return all data from the LruCache in order (most recently used to - // least). - std::vector > get_all_data(); - const PegInfo* get_pegs() const { return lru_cache_shared_peg_names; } - PegCount* get_counts() const + PegCount* get_counts() { return (PegCount*)&stats; } @@ -144,8 +114,8 @@ private: struct LruCacheSharedStats stats; }; -template -bool LruCacheShared::set_max_size(size_t newsize) +template +bool LruCacheShared::set_max_size(size_t newsize) { LruListIter list_iter; @@ -162,32 +132,50 @@ bool LruCacheShared::set_max_size(size_t newsize) current_size--; map.erase(list_iter->first); list.erase(list_iter); + stats.prunes++; } max_size = newsize; return true; } -template -void LruCacheShared::insert(const Key& key, const Data& data) +template +std::shared_ptr LruCacheShared::find(const Key& key) { LruMapIter map_iter; std::lock_guard cache_lock(cache_mutex); - // If key already exists, remove it. map_iter = map.find(key); - if (map_iter != map.end()) + if (map_iter == map.end()) { - current_size--; - list.erase(map_iter->second); - map.erase(map_iter); - stats.replaces++; + stats.find_misses++; + return nullptr; } - else + + // Move entry to front of LruList + list.splice(list.begin(), list, map_iter->second); + stats.find_hits++; + return map_iter->second->second; +} + +template +std::shared_ptr LruCacheShared::operator[](const Key& key) +{ + LruMapIter map_iter; + std::lock_guard cache_lock(cache_mutex); + + map_iter = map.find(key); + if (map_iter != map.end()) { - stats.adds++; + stats.find_hits++; + list.splice(list.begin(), list, map_iter->second); // update LRU + return map_iter->second->second; } + stats.find_misses++; + stats.adds++; + Data data = Data(new Value); + // Add key/data pair to front of list. list.emplace_front(std::make_pair(key, data)); @@ -208,87 +196,12 @@ void LruCacheShared::insert(const Key& key, const Data& data) { current_size++; } + return data; } -template -bool LruCacheShared::find(const Key& key, Data& data, bool update) -{ - LruMapIter map_iter; - std::lock_guard cache_lock(cache_mutex); - - map_iter = map.find(key); - if (map_iter == map.end()) - { - stats.find_misses++; - return false; // Key is not in LruCache. - } - - data = map_iter->second->second; - - // If needed, move entry to front of LruList - if (update) - list.splice(list.begin(), list, map_iter->second); - - stats.find_hits++; - return true; -} - -template -bool LruCacheShared::remove(const Key& key) -{ - LruMapIter map_iter; - std::lock_guard cache_lock(cache_mutex); - - map_iter = map.find(key); - if (map_iter == map.end()) - return false; // Key is not in LruCache. - - current_size--; - list.erase(map_iter->second); - map.erase(map_iter); - stats.removes++; - return(true); -} - -template -bool LruCacheShared::remove(const Key& key, Data& data) -{ - LruMapIter map_iter; - std::lock_guard cache_lock(cache_mutex); - - map_iter = map.find(key); - if (map_iter == map.end()) - return false; // Key is not in LruCache. - - data = map_iter->second->second; - - current_size--; - list.erase(map_iter->second); - map.erase(map_iter); - stats.removes++; - return(true); -} - -template -void LruCacheShared::clear() -{ - LruMapIter map_iter; - std::lock_guard cache_lock(cache_mutex); - - for (map_iter = map.begin(); map_iter != map.end(); /* No incr */) - { - list.erase(map_iter->second); - - // erase returns next iterator after erased element. - map_iter = map.erase(map_iter); - } - - current_size = 0; - stats.clears++; -} - -template -std::vector > LruCacheShared::get_all_data() +template +std::vector< std::pair> > +LruCacheShared::get_all_data() { std::vector > vec; std::lock_guard cache_lock(cache_mutex); diff --git a/src/hash/test/lru_cache_shared_test.cc b/src/hash/test/lru_cache_shared_test.cc index 3ca3185f5..89832a150 100644 --- a/src/hash/test/lru_cache_shared_test.cc +++ b/src/hash/test/lru_cache_shared_test.cc @@ -34,175 +34,73 @@ TEST_GROUP(lru_cache_shared) { }; -// Test LruCacheShared constructor and member access. -TEST(lru_cache_shared, constructor_test) -{ - LruCacheShared > lru_cache(5); - - CHECK(lru_cache.get_max_size() == 5); - CHECK(lru_cache.size() == 0); -} - -// Test LruCacheShared insert, find and get_all_data functions. +// Test LruCacheShared find, operator[], and get_all_data functions. TEST(lru_cache_shared, insert_test) { - std::string data; - LruCacheShared > lru_cache(5); + LruCacheShared > lru_cache(3); - lru_cache.insert(0, "zero"); - CHECK(true == lru_cache.find(0, data)); - CHECK("zero" == data); - - lru_cache.insert(1, "one"); - CHECK(true == lru_cache.find(1, data)); - CHECK("one" == data); - - lru_cache.insert(2, "two"); - CHECK(true == lru_cache.find(2, data)); - CHECK("two" == data); - - // Verify find fails for non-existent item. - CHECK(false == lru_cache.find(3, data)); - - // Verify that insert will replace data if key exists already. - lru_cache.insert(1, "new one"); - CHECK(true == lru_cache.find(1, data)); - CHECK("new one" == data); - - // Verify current number of entries in cache. - CHECK(3 == lru_cache.size()); - - // Verify that the data is in LRU order. - auto vec = lru_cache.get_all_data(); - CHECK(3 == vec.size()); - CHECK((vec[0] == std::make_pair(1, std::string("new one")))); - CHECK((vec[1] == std::make_pair(2, std::string("two")))); - CHECK((vec[2] == std::make_pair(0, std::string("zero")))); -} + auto data = lru_cache[0]; + CHECK(data == lru_cache.find(0)); + data->assign("zero"); -// Test that the least recently used items are removed when we exceed -// the capacity of the LruCache. -TEST(lru_cache_shared, lru_removal_test) -{ - LruCacheShared > lru_cache(5); + data = lru_cache[1]; + CHECK(data == lru_cache.find(1)); + data->assign("one"); - for (int i = 0; i < 10; i++) - { - lru_cache.insert(i, std::to_string(i)); - } - - CHECK(5 == lru_cache.size()); - - // Verify that the data is in LRU order and is correct. - auto vec = lru_cache.get_all_data(); - CHECK(5 == vec.size()); - CHECK((vec[0] == std::make_pair(9, std::string("9")))); - CHECK((vec[1] == std::make_pair(8, std::string("8")))); - CHECK((vec[2] == std::make_pair(7, std::string("7")))); - CHECK((vec[3] == std::make_pair(6, std::string("6")))); - CHECK((vec[4] == std::make_pair(5, std::string("5")))); -} + data = lru_cache[2]; + CHECK(data == lru_cache.find(2)); + data->assign("two"); -// Test the remove and clear functions. -TEST(lru_cache_shared, remove_test) -{ - std::string data; - LruCacheShared > lru_cache(5); + // Replace existing + data = lru_cache[0]; + data->assign("new_zero"); - for (int i = 0; i < 5; i++) - { - lru_cache.insert(i, std::to_string(i)); - CHECK(true == lru_cache.find(i, data)); - CHECK(data == std::to_string(i)); - - CHECK(true == lru_cache.remove(i)); - CHECK(false == lru_cache.find(i, data)); - } - - CHECK(0 == lru_cache.size()); - - // Test remove API that returns the removed data. - lru_cache.insert(1, "one"); - CHECK(1 == lru_cache.size()); - CHECK(true == lru_cache.remove(1, data)); - CHECK(data == "one"); - CHECK(0 == lru_cache.size()); - - lru_cache.insert(1, "one"); - lru_cache.insert(2, "two"); - CHECK(2 == lru_cache.size()); - - // Verify that removing an item that does not exist does not affect - // cache. - CHECK(false == lru_cache.remove(3)); - CHECK(false == lru_cache.remove(4, data)); - CHECK(2 == lru_cache.size()); - - auto vec = lru_cache.get_all_data(); - CHECK(2 == vec.size()); - CHECK((vec[0] == std::make_pair(2, std::string("two")))); - CHECK((vec[1] == std::make_pair(1, std::string("one")))); - - // Verify that clear() removes all entries. - lru_cache.clear(); - CHECK(0 == lru_cache.size()); - - vec = lru_cache.get_all_data(); - CHECK(vec.empty()); + // Verify find fails for non-existent item. + CHECK(nullptr == lru_cache.find(3)); + + // Replace least recently used when capacity exceeds + data = lru_cache[3]; + CHECK(data == lru_cache.find(3)); + data->assign("three"); + + // Verify current number of entries in cache and LRU order. + const auto&& vec = lru_cache.get_all_data(); + CHECK(vec.size() == 3); + CHECK(vec[0].second->compare("three") == 0); + CHECK(vec[1].second->compare("new_zero") == 0); + CHECK(vec[2].second->compare("two") == 0); } -// Test statistics counters. +// Test statistics counters and set_max_size. TEST(lru_cache_shared, stats_test) { - std::string data; LruCacheShared > lru_cache(5); for (int i = 0; i < 10; i++) - { - lru_cache.insert(i, std::to_string(i)); - } - - lru_cache.insert(8, "new-eight"); // Replace entries. - lru_cache.insert(9, "new-nine"); - - CHECK(5 == lru_cache.size()); - - lru_cache.find(7, data); // Hits - lru_cache.find(8, data); - lru_cache.find(9, data); - - lru_cache.remove(7); - lru_cache.remove(8); - lru_cache.remove(9, data); - CHECK("new-nine" == data); + lru_cache[i]; - lru_cache.find(8, data); // Misses now that they're removed. - lru_cache.find(9, data); + lru_cache.find(7); // Hits + lru_cache.find(8); + lru_cache.find(9); - lru_cache.remove(100); // Removing a non-existent entry does not - // increase remove count. + lru_cache.find(10); // Misses; in addition to previous 10 + lru_cache.find(11); - lru_cache.clear(); + CHECK(lru_cache.set_max_size(3) == true); // change size prunes; in addition to previous 5 PegCount* stats = lru_cache.get_counts(); CHECK(stats[0] == 10); // adds - CHECK(stats[1] == 2); // replaces - CHECK(stats[2] == 5); // prunes - CHECK(stats[3] == 3); // find hits - CHECK(stats[4] == 2); // find misses - CHECK(stats[5] == 3); // removes - CHECK(stats[6] == 1); // clears + CHECK(stats[1] == 7); // prunes + CHECK(stats[2] == 3); // find hits + CHECK(stats[3] == 12); // find misses // Check statistics names. const PegInfo* pegs = lru_cache.get_pegs(); CHECK(!strcmp(pegs[0].name, "lru_cache_adds")); - CHECK(!strcmp(pegs[1].name, "lru_cache_replaces")); - CHECK(!strcmp(pegs[2].name, "lru_cache_prunes")); - CHECK(!strcmp(pegs[3].name, "lru_cache_find_hits")); - CHECK(!strcmp(pegs[4].name, "lru_cache_find_misses")); - CHECK(!strcmp(pegs[5].name, "lru_cache_removes")); - CHECK(!strcmp(pegs[6].name, "lru_cache_clears")); + CHECK(!strcmp(pegs[1].name, "lru_cache_prunes")); + CHECK(!strcmp(pegs[2].name, "lru_cache_find_hits")); + CHECK(!strcmp(pegs[3].name, "lru_cache_find_misses")); } int main(int argc, char** argv) diff --git a/src/host_tracker/host_cache.cc b/src/host_tracker/host_cache.cc index 450d46eb5..9a55ab808 100644 --- a/src/host_tracker/host_cache.cc +++ b/src/host_tracker/host_cache.cc @@ -24,88 +24,8 @@ #include "host_cache.h" -#include "main/snort_config.h" -#include "target_based/snort_protocols.h" - using namespace snort; -using namespace std; #define LRU_CACHE_INITIAL_SIZE 65535 -LruCacheShared, HashHostIpKey> - host_cache(LRU_CACHE_INITIAL_SIZE); - -namespace snort -{ - -void host_cache_add_host_tracker(HostTracker* ht) -{ - std::shared_ptr sptr(ht); - host_cache.insert((const uint8_t*) ht->get_ip_addr().get_ip6_ptr(), sptr); -} - -bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, SnortProtocolId id) -{ - HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr()); - std::shared_ptr ht; - - if (!host_cache.find(ipkey, ht)) - { - // This host hasn't been seen. Add it. - ht = std::make_shared(ipaddr); - - if (ht == nullptr) - { - // FIXIT-L add error count - return false; - } - host_cache.insert(ipkey, ht); - } - - HostApplicationEntry app_entry(ipproto, port, id); - return ht->add_service(app_entry); -} - -bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, const char* service) -{ - return host_cache_add_service(ipaddr, ipproto, port, - SnortConfig::get_conf()->proto_ref->find(service)); -} - -bool host_cache_add_app_mapping(const SfIp& ipaddr, Port port, Protocol proto, AppId appId) -{ - HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr()); - std::shared_ptr ht; - - if (!host_cache.find(ipkey, ht)) - { - ht = std::make_shared (ipaddr); - - if (ht == nullptr) - { - return false; - } - ht->add_app_mapping(port, proto, appId); - host_cache.insert(ipkey, ht); - } - else - { - return ht->find_else_add_app_mapping(port, proto, appId); - } - - return true; -} - -AppId host_cache_find_app_mapping(const SfIp* ipaddr, Port port, Protocol proto) -{ - HostIpKey ipkey((const uint8_t*) ipaddr->get_ip6_ptr()); - std::shared_ptr ht; - - if (host_cache.find(ipkey, ht)) - { - return ht->find_app_mapping(port, proto); - } - - return APP_ID_NONE; -} -} +LruCacheShared host_cache(LRU_CACHE_INITIAL_SIZE); diff --git a/src/host_tracker/host_cache.h b/src/host_tracker/host_cache.h index d73290645..725146acd 100644 --- a/src/host_tracker/host_cache.h +++ b/src/host_tracker/host_cache.h @@ -24,56 +24,22 @@ // The host cache is used to cache information about hosts so that it can // be shared among threads. -#include - #include "hash/lru_cache_shared.h" #include "host_tracker/host_tracker.h" - -#define HOST_IP_KEY_SIZE 16 - -struct HostIpKey -{ - union host_ip_addr - { - uint8_t ip8[HOST_IP_KEY_SIZE]; - uint64_t ip64[HOST_IP_KEY_SIZE/8]; - } ip_addr = {{0}}; // Holds either IPv4 or IPv6 addr - - HostIpKey() = default; - - HostIpKey(const uint8_t ip[HOST_IP_KEY_SIZE]) - { - memcpy(&ip_addr, ip, HOST_IP_KEY_SIZE); - } - - inline bool operator==(const HostIpKey& rhs) const - { - return !memcmp(&ip_addr, &rhs.ip_addr, HOST_IP_KEY_SIZE); - } -}; +#include "sfip/sf_ip.h" // Used to create hash of key for indexing into cache. -struct HashHostIpKey +struct HashIp { - size_t operator()(const HostIpKey& ip) const + size_t operator()(const snort::SfIp& ip) const { - return std::hash() (ip.ip_addr.ip64[0]) ^ - std::hash() (ip.ip_addr.ip64[1]); + const uint64_t* ip64 = (const uint64_t*) ip.get_ip6_ptr(); + return std::hash() (ip64[0]) ^ + std::hash() (ip64[1]); } }; -extern LruCacheShared, HashHostIpKey> host_cache; - -namespace snort -{ -void host_cache_add_host_tracker(HostTracker*); - -// Insert a new service into host cache if it doesn't already exist. -SO_PUBLIC bool host_cache_add_service(const SfIp&, Protocol, Port, SnortProtocolId); -SO_PUBLIC bool host_cache_add_service(const SfIp&, Protocol, Port, const char*); +extern SO_PUBLIC LruCacheShared host_cache; -bool host_cache_add_app_mapping(const SfIp&, Port, Protocol, AppId); -AppId host_cache_find_app_mapping(const SfIp* , Port, Protocol ); -} #endif diff --git a/src/host_tracker/host_cache_module.cc b/src/host_tracker/host_cache_module.cc index a22c4b673..f253644f5 100644 --- a/src/host_tracker/host_cache_module.cc +++ b/src/host_tracker/host_cache_module.cc @@ -123,7 +123,6 @@ HostCacheModule::~HostCacheModule() log_host_cache(dump_file); snort_free((void*)dump_file); } - host_cache.clear(); } void HostCacheModule::log_host_cache(const char* file_name, bool verbose) @@ -153,9 +152,12 @@ void HostCacheModule::log_host_cache(const char* file_name, bool verbose) } string str; + SfIpString ip_str; const auto&& lru_data = host_cache.get_all_data(); for ( const auto& elem : lru_data ) { + str = "IP: "; + str += elem.first.ntop(ip_str); elem.second->stringify(str); out_stream << str << endl << endl; } diff --git a/src/host_tracker/host_tracker.cc b/src/host_tracker/host_tracker.cc index c117563fa..d0911ab4b 100644 --- a/src/host_tracker/host_tracker.cc +++ b/src/host_tracker/host_tracker.cc @@ -29,165 +29,58 @@ using namespace std; THREAD_LOCAL struct HostTrackerStats host_tracker_stats; -snort::SfIp HostTracker::get_ip_addr() -{ - std::lock_guard lck(host_tracker_lock); - return ip_addr; -} - -void HostTracker::set_ip_addr(const snort::SfIp& new_ip_addr) -{ - std::lock_guard lck(host_tracker_lock); - std::memcpy(&ip_addr, &new_ip_addr, sizeof(ip_addr)); -} - -Policy HostTracker::get_stream_policy() -{ - std::lock_guard lck(host_tracker_lock); - return stream_policy; -} - -void HostTracker::set_stream_policy(const Policy& policy) -{ - std::lock_guard lck(host_tracker_lock); - stream_policy = policy; -} - -Policy HostTracker::get_frag_policy() -{ - std::lock_guard lck(host_tracker_lock); - return frag_policy; -} - -void HostTracker::set_frag_policy(const Policy& policy) -{ - std::lock_guard lck(host_tracker_lock); - frag_policy = policy; -} - -void HostTracker::add_app_mapping(Port port, Protocol proto, AppId appid) -{ - std::lock_guard lck(host_tracker_lock); - AppMapping app_map = {port, proto, appid}; - - app_mappings.emplace_back(app_map); -} - -AppId HostTracker::find_app_mapping(Port port, Protocol proto) +bool HostTracker::add_service(Port port, IpProtocol proto, AppId appid, bool inferred_appid) { + host_tracker_stats.service_adds++; std::lock_guard lck(host_tracker_lock); - for (std::vector::iterator it=app_mappings.begin(); it!=app_mappings.end(); ++it) - { - if (it->port == port and it->proto ==proto) - { - return it->appid; - } - } - return APP_ID_NONE; -} -bool HostTracker::find_else_add_app_mapping(Port port, Protocol proto, AppId appid) -{ - std::lock_guard lck(host_tracker_lock); - for (std::vector::iterator it=app_mappings.begin(); it!=app_mappings.end(); ++it) + for ( auto& s : services ) { - if (it->port == port and it->proto ==proto) + if ( s.port == port and s.proto == proto ) { - return false; + if ( s.appid != appid and appid != APP_ID_NONE ) + { + s.appid = appid; + s.inferred_appid = inferred_appid; + } + return true; } } - AppMapping app_map = {port, proto, appid}; - - app_mappings.emplace_back(app_map); - return true; -} - -bool HostTracker::add_service(const HostApplicationEntry& app_entry) -{ - host_tracker_stats.service_adds++; - - std::lock_guard lck(host_tracker_lock); - - auto iter = std::find(services.begin(), services.end(), app_entry); - if (iter != services.end()) - return false; // Already exists. - services.emplace_front(app_entry); + services.emplace_back( HostApplication{port, proto, appid, inferred_appid} ); return true; } -void HostTracker::add_or_replace_service(const HostApplicationEntry& app_entry) +AppId HostTracker::get_appid(Port port, IpProtocol proto, bool inferred_only) { - host_tracker_stats.service_adds++; - - std::lock_guard lck(host_tracker_lock); - - auto iter = std::find(services.begin(), services.end(), app_entry); - if (iter != services.end()) - services.erase(iter); - - services.emplace_front(app_entry); -} - -bool HostTracker::find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry) -{ - HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID); host_tracker_stats.service_finds++; - std::lock_guard lck(host_tracker_lock); - auto iter = std::find(services.begin(), services.end(), tmp_entry); - if (iter != services.end()) + for ( const auto& s : services ) { - app_entry = *iter; - return true; + if ( s.port == port and s.proto == proto and + (!inferred_only or s.inferred_appid == inferred_only) ) + return s.appid; } - return false; -} - -bool HostTracker::remove_service(Protocol ipproto, Port port) -{ - HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID); - host_tracker_stats.service_removes++; - - std::lock_guard lck(host_tracker_lock); - - auto iter = std::find(services.begin(), services.end(), tmp_entry); - if (iter != services.end()) - { - services.erase(iter); - return true; // Assumes only one matching entry. - } - - return false; + return APP_ID_NONE; } void HostTracker::stringify(string& str) { - str = "IP: "; - SfIpString ip_str; - str += ip_addr.ntop(ip_str); - - if ( !app_mappings.empty() ) - { - str += "\napp_mappings size: " + to_string(app_mappings.size()); - for ( const auto& elem : app_mappings ) - str += "\n port: " + to_string(elem.port) - + ", proto: " + to_string(elem.proto) - + ", appid: " + to_string(elem.appid); - } - - if ( stream_policy or frag_policy ) - str += "\nstream policy: " + to_string(stream_policy) - + ", frag policy: " + to_string(frag_policy); - if ( !services.empty() ) { str += "\nservices size: " + to_string(services.size()); - for ( const auto& elem : services ) - str += "\n port: " + to_string(elem.port) - + ", proto: " + to_string(elem.ipproto) - + ", snort proto: " + to_string(elem.snort_protocol_id); - } + for ( const auto& s : services ) + { + str += "\n port: " + to_string(s.port) + + ", proto: " + to_string((uint8_t) s.proto); + if ( s.appid != APP_ID_NONE ) + { + str += ", appid: " + to_string(s.appid); + if ( s.inferred_appid ) + str += ", inferred"; + } + } + } } diff --git a/src/host_tracker/host_tracker.h b/src/host_tracker/host_tracker.h index eafc9fbdc..0c22509be 100644 --- a/src/host_tracker/host_tracker.h +++ b/src/host_tracker/host_tracker.h @@ -25,95 +25,40 @@ // configuration or dynamic discovery). It provides a thread-safe API to // set/get the host data. -#include -#include -#include #include +#include + #include "framework/counts.h" #include "main/snort_types.h" #include "main/thread.h" #include "network_inspectors/appid/application_ids.h" #include "protocols/protocol_ids.h" -#include "sfip/sf_ip.h" -#include "target_based/snort_protocols.h" - -// FIXIT-M For now this emulates the Snort++ attribute table. -// Need to add in host_tracker.h data eventually. - -typedef uint16_t Protocol; -typedef uint8_t Policy; struct HostTrackerStats { PegCount service_adds; PegCount service_finds; - PegCount service_removes; }; extern THREAD_LOCAL struct HostTrackerStats host_tracker_stats; -struct HostApplicationEntry -{ - Port port = 0; - Protocol ipproto = 0; - SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; - - HostApplicationEntry() = default; - - HostApplicationEntry(Protocol ipproto_param, Port port_param, SnortProtocolId protocol_param) : - port(port_param), - ipproto(ipproto_param), - snort_protocol_id(protocol_param) - { - } - - inline bool operator==(const HostApplicationEntry& rhs) const - { - return ipproto == rhs.ipproto and port == rhs.port; - } -}; - -struct AppMapping +struct HostApplication { Port port; - Protocol proto; + IpProtocol proto; AppId appid; + bool inferred_appid; }; -class HostTracker +class SO_PUBLIC HostTracker { public: - HostTracker() - { memset(&ip_addr, 0, sizeof(ip_addr)); } - - HostTracker(const snort::SfIp& new_ip_addr) - { std::memcpy(&ip_addr, &new_ip_addr, sizeof(ip_addr)); } - - snort::SfIp get_ip_addr(); - void set_ip_addr(const snort::SfIp& new_ip_addr); - Policy get_stream_policy(); - void set_stream_policy(const Policy& policy); - Policy get_frag_policy(); - void set_frag_policy(const Policy& policy); - void add_app_mapping(Port port, Protocol proto, AppId appid); - AppId find_app_mapping(Port port, Protocol proto); - bool find_else_add_app_mapping(Port port, Protocol proto, AppId appid); - - // Add host service data only if it doesn't already exist. Returns - // false if entry exists already, and true if entry was added. - bool add_service(const HostApplicationEntry& app_entry); + // Appid may not be identified always. Inferred means dynamic/runtime + // appid detected from one flow to another flow such as BitTorrent. + bool add_service(Port port, IpProtocol proto, + AppId appid = APP_ID_NONE, bool inferred_appid = false); - // Add host service data if it doesn't already exist. If it does exist - // replace the previous entry with the new entry. - void add_or_replace_service(const HostApplicationEntry& app_entry); - - // Returns true and fills in copy of HostApplicationEntry when found. - // Returns false when not found. - bool find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry); - - // Removes HostApplicationEntry object associated with ipproto and port. - // Returns true if entry existed. False otherwise. - bool remove_service(Protocol ipproto, Port port); + AppId get_appid(Port port, IpProtocol proto, bool inferred_only = false); // This should be updated whenever HostTracker data members are changed void stringify(std::string& str); @@ -122,15 +67,7 @@ private: // Ensure that updates to a shared object are safe std::mutex host_tracker_lock; - // FIXIT-M do we need to use a host_id instead of SfIp as in sfrna? - snort::SfIp ip_addr; - std::vector< AppMapping > app_mappings; - - // Policies to apply to this host. - Policy stream_policy = 0; - Policy frag_policy = 0; - - std::list services; + std::vector services; }; #endif diff --git a/src/host_tracker/host_tracker_module.cc b/src/host_tracker/host_tracker_module.cc index 07d74426a..f56c03ca4 100644 --- a/src/host_tracker/host_tracker_module.cc +++ b/src/host_tracker/host_tracker_module.cc @@ -24,11 +24,8 @@ #include "host_tracker_module.h" +#include "log/messages.h" #include "main/snort_config.h" -#include "stream/stream.h" -#include "target_based/snort_protocols.h" - -#include "host_cache.h" using namespace snort; @@ -36,34 +33,21 @@ const PegInfo host_tracker_pegs[] = { { CountType::SUM, "service_adds", "host service adds" }, { CountType::SUM, "service_finds", "host service finds" }, - { CountType::SUM, "service_removes", "host service removes" }, { CountType::END, nullptr, nullptr }, }; const Parameter HostTrackerModule::service_params[] = { - { "name", Parameter::PT_STRING, nullptr, nullptr, - "service identifier" }, - - { "proto", Parameter::PT_ENUM, "tcp | udp", "tcp", - "IP protocol" }, + { "port", Parameter::PT_PORT, nullptr, nullptr, "port number" }, - { "port", Parameter::PT_PORT, nullptr, nullptr, - "port number" }, + { "proto", Parameter::PT_ENUM, "ip | tcp | udp", nullptr, "IP protocol" }, { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } }; const Parameter HostTrackerModule::host_tracker_params[] = { - { "ip", Parameter::PT_ADDR, nullptr, "0.0.0.0/32", - "hosts address / cidr" }, - - { "frag_policy", Parameter::PT_ENUM, IP_POLICIES, nullptr, - "defragmentation policy" }, - - { "tcp_policy", Parameter::PT_ENUM, TCP_POLICIES, nullptr, - "TCP reassembly policy" }, + { "ip", Parameter::PT_ADDR, nullptr, nullptr, "hosts address / cidr" }, { "services", Parameter::PT_LIST, HostTrackerModule::service_params, nullptr, "list of service parameters" }, @@ -71,29 +55,21 @@ const Parameter HostTrackerModule::host_tracker_params[] = { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr } }; -bool HostTrackerModule::set(const char*, Value& v, SnortConfig* sc) +bool HostTrackerModule::set(const char*, Value& v, SnortConfig*) { - if ( host and v.is("ip") ) - { - SfIp addr; + if ( v.is("ip") ) v.get_addr(addr); - host->set_ip_addr(addr); - } - else if ( host and v.is("frag_policy") ) - host->set_frag_policy(v.get_uint8() + 1); - - else if ( host and v.is("tcp_policy") ) - host->set_stream_policy(v.get_uint8() + 1); - - else if ( v.is("name") ) - app.snort_protocol_id = sc->proto_ref->add(v.get_string()); - - else if ( v.is("proto") ) - app.ipproto = sc->proto_ref->add(v.get_string()); else if ( v.is("port") ) app.port = v.get_uint16(); + else if ( v.is("proto") ) + { + const IpProtocol mask[] = + { IpProtocol::IP, IpProtocol::TCP, IpProtocol::UDP }; + app.proto = mask[v.get_uint8()]; + } + else return false; @@ -103,8 +79,10 @@ bool HostTrackerModule::set(const char*, Value& v, SnortConfig* sc) bool HostTrackerModule::begin(const char* fqn, int idx, SnortConfig*) { if ( idx && !strcmp(fqn, "host_tracker") ) - host = new HostTracker; - + { + addr.clear(); + app = {}; + } return true; } @@ -112,13 +90,14 @@ bool HostTrackerModule::end(const char* fqn, int idx, SnortConfig*) { if ( idx && !strcmp(fqn, "host_tracker.services") ) { - host->add_service(app); + if ( addr.is_set() ) + host_cache[addr]->add_service(app.port, app.proto); app = {}; } - else if ( idx && !strcmp(fqn, "host_tracker") ) + else if ( idx && !strcmp(fqn, "host_tracker") && addr.is_set() ) { - host_cache_add_host_tracker(host); - host = nullptr; // Host cache is now responsible for freeing host + host_cache[addr]; + addr.clear(); } return true; diff --git a/src/host_tracker/host_tracker_module.h b/src/host_tracker/host_tracker_module.h index 6637d3d5d..00f400030 100644 --- a/src/host_tracker/host_tracker_module.h +++ b/src/host_tracker/host_tracker_module.h @@ -30,7 +30,7 @@ #include #include "framework/module.h" -#include "host_tracker/host_tracker.h" +#include "host_tracker/host_cache.h" #define host_tracker_help \ "configure hosts" @@ -38,11 +38,8 @@ class HostTrackerModule : public snort::Module { public: - HostTrackerModule() : snort::Module("host_tracker", host_tracker_help, host_tracker_params, true) - { host = nullptr; } - - ~HostTrackerModule() override - { assert(!host); } + HostTrackerModule() : + snort::Module("host_tracker", host_tracker_help, host_tracker_params, true) { } const PegInfo* get_pegs() const override; PegCount* get_counts() const override; @@ -58,8 +55,8 @@ private: static const snort::Parameter host_tracker_params[]; static const snort::Parameter service_params[]; - HostApplicationEntry app; - HostTracker* host; + HostApplication app; + snort::SfIp addr; }; #endif diff --git a/src/host_tracker/test/host_cache_module_test.cc b/src/host_tracker/test/host_cache_module_test.cc index f8f3b28e9..5266dbb6d 100644 --- a/src/host_tracker/test/host_cache_module_test.cc +++ b/src/host_tracker/test/host_cache_module_test.cc @@ -46,9 +46,6 @@ static char logged_message[LOG_MAX+1]; namespace snort { -// Fakes to avoid bringing in a ton of dependencies. -SnortProtocolId ProtocolReference::add(char const*) { return 0; } -SnortProtocolId ProtocolReference::find(char const*) { return 0; } SnortConfig* SnortConfig::get_conf() { return nullptr; } char* snort_strdup(const char* s) { return strdup(s); } Module* ModuleManager::get_module(const char*) { return nullptr; } @@ -73,9 +70,6 @@ void show_stats(PegCount*, const PegInfo*, unsigned, const char*) void show_stats(PegCount*, const PegInfo*, IndexVec&, const char*, FILE*) { } -#define FRAG_POLICY 33 -#define STREAM_POLICY 100 - TEST_GROUP(host_cache_module) { void setup() override @@ -98,21 +92,15 @@ TEST(host_cache_module, host_cache_module_test_values) const PegCount* ht_stats = module.get_counts(); CHECK(!strcmp(ht_pegs[0].name, "lru_cache_adds")); - CHECK(!strcmp(ht_pegs[1].name, "lru_cache_replaces")); - CHECK(!strcmp(ht_pegs[2].name, "lru_cache_prunes")); - CHECK(!strcmp(ht_pegs[3].name, "lru_cache_find_hits")); - CHECK(!strcmp(ht_pegs[4].name, "lru_cache_find_misses")); - CHECK(!strcmp(ht_pegs[5].name, "lru_cache_removes")); - CHECK(!strcmp(ht_pegs[6].name, "lru_cache_clears")); - CHECK(!ht_pegs[7].name); + CHECK(!strcmp(ht_pegs[1].name, "lru_cache_prunes")); + CHECK(!strcmp(ht_pegs[2].name, "lru_cache_find_hits")); + CHECK(!strcmp(ht_pegs[3].name, "lru_cache_find_misses")); + CHECK(!ht_pegs[4].name); CHECK(ht_stats[0] == 0); CHECK(ht_stats[1] == 0); CHECK(ht_stats[2] == 0); CHECK(ht_stats[3] == 0); - CHECK(ht_stats[4] == 0); - CHECK(ht_stats[5] == 0); - CHECK(ht_stats[6] == 0); size_val.set(&size_param); @@ -123,8 +111,6 @@ TEST(host_cache_module, host_cache_module_test_values) ht_stats = module.get_counts(); CHECK(ht_stats[0] == 0); - - CHECK(2112 == host_cache.get_max_size()); } TEST(host_cache_module, log_host_cache_messages) diff --git a/src/host_tracker/test/host_cache_test.cc b/src/host_tracker/test/host_cache_test.cc index db835d894..f177349ac 100644 --- a/src/host_tracker/test/host_cache_test.cc +++ b/src/host_tracker/test/host_cache_test.cc @@ -25,6 +25,8 @@ #include "host_tracker/host_cache.h" +#include + #include "main/snort_config.h" #include @@ -34,23 +36,6 @@ using namespace snort; namespace snort { -SnortConfig s_conf; -THREAD_LOCAL SnortConfig* snort_conf = &s_conf; -SnortConfig::SnortConfig(const SnortConfig* const) { } -SnortConfig::~SnortConfig() = default; -SnortConfig* SnortConfig::get_conf() -{ return snort_conf; } - -SnortProtocolId ProtocolReference::find(char const*) { return 0; } -SnortProtocolId ProtocolReference::add(const char* protocol) -{ - if (!strcmp("servicename", protocol)) - return 3; - if (!strcmp("tcp", protocol)) - return 2; - return 1; -} - char* snort_strdup(const char* str) { return strdup(str); @@ -61,23 +46,7 @@ TEST_GROUP(host_cache) { }; -// Test HostIpKey -TEST(host_cache, host_ip_key_test) -{ - HostIpKey zeroed_hk; - uint8_t expected_hk[16] = - { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef }; - - memset(&zeroed_hk.ip_addr, 0, sizeof(zeroed_hk.ip_addr)); - - HostIpKey hkey1; - CHECK(hkey1 == zeroed_hk); - - HostIpKey hkey2(expected_hk); - CHECK(hkey2 == expected_hk); -} - -// Test HashHostIpKey +// Test HashIp TEST(host_cache, hash_test) { size_t expected_hash_val = 4521729; @@ -85,174 +54,16 @@ TEST(host_cache, hash_test) uint8_t hk[16] = { 0x0a,0xff,0x12,0x00,0x00,0x00,0x00,0x00,0x0b,0x00,0x56,0x00,0x00,0x00,0x00,0x00 }; - HashHostIpKey hash_hk; + HashIp hash_hk; + SfIp ip; - actual_hash_val = hash_hk(hk); + ip.set(hk); + actual_hash_val = hash_hk(ip); CHECK(actual_hash_val == expected_hash_val); } -// Test host_cache_add_host_tracker -TEST(host_cache, host_cache_add_host_tracker_test) -{ - HostTracker* expected_ht = new HostTracker; - std::shared_ptr actual_ht; - uint8_t hk[16] = - { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef }; - SfIp ip_addr; - SfIp actual_ip_addr; - HostIpKey hkey(hk); - Port port = 443; - Protocol proto = 6; - HostApplicationEntry app_entry(proto, port, 2); - HostApplicationEntry actual_app_entry; - bool ret; - - ip_addr.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab"); - - expected_ht->set_ip_addr(ip_addr); - expected_ht->add_service(app_entry); - - host_cache_add_host_tracker(expected_ht); - - ret = host_cache.find((const uint8_t*) ip_addr.get_ip6_ptr(), actual_ht); - CHECK(true == ret); - - actual_ip_addr = actual_ht->get_ip_addr(); - CHECK(!memcmp(&ip_addr, &actual_ip_addr, sizeof(ip_addr))); - - ret = actual_ht->find_service(proto, port, actual_app_entry); - CHECK(true == ret); - CHECK(actual_app_entry == app_entry); - - host_cache.clear(); // Free HostTracker objects -} - -// Test host_cache_add_service -TEST(host_cache, host_cache_add_service_test) -{ - HostTracker* expected_ht = new HostTracker; - std::shared_ptr actual_ht; - uint8_t hk[16] = - { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef }; - SfIp ip_addr1; - SfIp ip_addr2; - HostIpKey hkey(hk); - Port port1 = 443; - Port port2 = 22; - Protocol proto1 = 17; - Protocol proto2 = 6; - HostApplicationEntry app_entry1(proto1, port1, 1); - HostApplicationEntry app_entry2(proto2, port2, 2); - HostApplicationEntry actual_app_entry; - bool ret; - - ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab"); - ip_addr2.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab"); - - // Initialize cache with a HostTracker. - host_cache_add_host_tracker(expected_ht); - - // Add a service to a HostTracker that already exists. - ret = host_cache_add_service(ip_addr1, proto1, port1, "udp"); - CHECK(true == ret); - - ret = host_cache.find((const uint8_t*) ip_addr1.get_ip6_ptr(), actual_ht); - CHECK(true == ret); - - ret = actual_ht->find_service(proto1, port1, actual_app_entry); - CHECK(true == ret); - CHECK(actual_app_entry == app_entry1); - - // Try adding service again (should fail since service exists). - ret = host_cache_add_service(ip_addr1, proto1, port1, "udp"); - CHECK(false == ret); - - // Add a service with a new IP. - ret = host_cache_add_service(ip_addr2, proto2, port2, "tcp"); - CHECK(true == ret); - - ret = host_cache.find((const uint8_t*) ip_addr1.get_ip6_ptr(), actual_ht); - CHECK(true == ret); - - ret = actual_ht->find_service(proto2, port2, actual_app_entry); - CHECK(true == ret); - CHECK(actual_app_entry == app_entry2); - - host_cache.clear(); // Free HostTracker objects -} - -TEST(host_cache, host_cache_app_mapping_test ) -{ - HostTracker* expected_ht = new HostTracker; - std::shared_ptr actual_ht; - uint8_t hk[16] = - { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef }; - SfIp ip_addr1; - SfIp ip_addr2; - HostIpKey hkey(hk); - Port port1 = 4123; - Port port2 = 1827; - Protocol proto1 = 6; - Protocol proto2 = 7; - AppId appid1 = 61; - AppId appid2 = 62; - AppId appid3 = 63; - AppId ret; - bool add_ret; - - ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab"); - ip_addr2.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:901b"); - - // Initialize cache with a HostTracker. - host_cache_add_host_tracker(expected_ht); - - add_ret = host_cache_add_app_mapping(ip_addr1, port1, proto1, appid1); - CHECK(true == add_ret); - - ret = host_cache_find_app_mapping(&ip_addr1, port1, proto1); - CHECK(61 == ret); - - ret = host_cache_find_app_mapping(&ip_addr1, port2, proto1); - CHECK(APP_ID_NONE == ret); - - ret = host_cache_find_app_mapping(&ip_addr2, port1, proto1); - CHECK(APP_ID_NONE == ret); - - add_ret = host_cache_add_app_mapping(ip_addr1, port2, proto1, appid2); - CHECK(true == add_ret); - ret = host_cache_find_app_mapping(&ip_addr1, port2, proto1); - CHECK(62 == ret); - - add_ret = host_cache_add_app_mapping(ip_addr1, port1, proto2, appid3); - CHECK(true == add_ret); - ret = host_cache_find_app_mapping(&ip_addr1, port1, proto2); - CHECK(63 == ret); - - host_cache.clear(); // Free HostTracker objects -} - int main(int argc, char** argv) { - SfIp ip_addr1; - Protocol proto1 = 17; - Port port1 = 443; - - ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab"); - - // This is necessary to prevent the cpputest memory leak - // detection from thinking there's a memory leak in the map - // object contained within the global host_cache. The map - // must have some data allocated when it is first created - // that doesn't go away until the global map object is - // deallocated. This pre-allocates the map so that initial - // allocation is done prior to starting the tests. The same - // is true for the list used when adding a service. - HostTracker* ht = new HostTracker; - host_cache_add_host_tracker(ht); - host_cache_add_service(ip_addr1, proto1, port1, "udp"); - - host_cache.clear(); - // Use this if you want to turn off memory checks entirely: // MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); diff --git a/src/host_tracker/test/host_tracker_module_test.cc b/src/host_tracker/test/host_tracker_module_test.cc index 2d5e36f7a..33e9da5a1 100644 --- a/src/host_tracker/test/host_tracker_module_test.cc +++ b/src/host_tracker/test/host_tracker_module_test.cc @@ -23,10 +23,12 @@ #include "config.h" #endif +#include + #include "host_tracker/host_cache.h" #include "host_tracker/host_tracker_module.h" -#include "target_based/snort_protocols.h" #include "main/snort_config.h" +#include "target_based/snort_protocols.h" #include #include @@ -35,17 +37,6 @@ using namespace snort; namespace snort { -SnortConfig* SnortConfig::get_conf() { return nullptr; } -SnortProtocolId ProtocolReference::find(char const*) { return 0; } -SnortProtocolId ProtocolReference::add(const char* protocol) -{ - if (!strcmp("servicename", protocol)) - return 3; - if (!strcmp("tcp", protocol)) - return 2; - return 1; -} - char* snort_strdup(const char* s) { return strdup(s); } } @@ -54,9 +45,6 @@ char* snort_strdup(const char* s) void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { } void show_stats(PegCount*, const PegInfo*, IndexVec&, const char*, FILE*) { } -#define FRAG_POLICY 33 -#define STREAM_POLICY 100 - SfIp expected_addr; TEST_GROUP(host_tracker_module) @@ -64,15 +52,9 @@ TEST_GROUP(host_tracker_module) void setup() override { Value ip_val("10.23.45.56"); - Value frag_val((double)FRAG_POLICY); - Value tcp_val((double)STREAM_POLICY); - Value name_val("servicename"); Value proto_val("udp"); Value port_val((double)2112); Parameter ip_param = { "ip", Parameter::PT_ADDR, nullptr, "0.0.0.0/32", "addr/cidr"}; - Parameter frag_param = { "frag_policy", Parameter::Parameter::PT_ENUM, "linux | bsd", nullptr, "frag policy"}; - Parameter tcp_param = { "tcp_policy", Parameter::PT_ENUM, "linux | bsd", nullptr, "tcp policy"}; - Parameter name_param = {"name", Parameter::PT_STRING, nullptr, nullptr, "name"}; Parameter proto_param = {"proto", Parameter::PT_ENUM, "tcp | udp", "tcp", "ip proto"}; Parameter port_param = {"port", Parameter::PT_PORT, nullptr, nullptr, "port num"}; HostTrackerModule module; @@ -81,17 +63,12 @@ TEST_GROUP(host_tracker_module) CHECK(!strcmp(ht_pegs[0].name, "service_adds")); CHECK(!strcmp(ht_pegs[1].name, "service_finds")); - CHECK(!strcmp(ht_pegs[2].name, "service_removes")); - CHECK(!ht_pegs[3].name); + CHECK(!ht_pegs[2].name); CHECK(ht_stats[0] == 0); CHECK(ht_stats[1] == 0); - CHECK(ht_stats[2] == 0); ip_val.set(&ip_param); - frag_val.set(&frag_param); - tcp_val.set(&tcp_param); - name_val.set(&name_param); proto_val.set(&proto_param); port_val.set(&port_param); @@ -101,12 +78,6 @@ TEST_GROUP(host_tracker_module) // Set up the module values and add a service. module.begin("host_tracker", 1, nullptr); module.set(nullptr, ip_val, nullptr); - module.set(nullptr, frag_val, nullptr); - module.set(nullptr, tcp_val, nullptr); - - // FIXIT-M see FIXIT-M below - //module.set(nullptr, name_val, nullptr); - //module.set(nullptr, proto_val, nullptr); module.set(nullptr, port_val, nullptr); module.end("host_tracker.services", 1, nullptr); @@ -118,7 +89,6 @@ TEST_GROUP(host_tracker_module) void teardown() override { memset(&host_tracker_stats, 0, sizeof(host_tracker_stats)); - host_cache.clear(); // Free HostTracker objects } }; @@ -127,70 +97,9 @@ TEST(host_tracker_module, host_tracker_module_test_basic) CHECK(true); } -#if 0 -// FIXIT-M the below are more functional in scope because they require host_cache -// services. need to stub this out better to focus on the module only. -// Test that HostTrackerModule variables are set correctly. -TEST(host_tracker_module, host_tracker_module_test_values) -{ - SfIp cached_addr; - - HostIpKey host_ip_key((const uint8_t*) expected_addr.get_ip6_ptr()); - std::shared_ptr ht; - - bool ret = host_cache.find(host_ip_key, ht); - CHECK(ret == true); - - cached_addr = ht->get_ip_addr(); - CHECK(cached_addr.fast_equals_raw(expected_addr) == true); - - Policy policy = ht->get_stream_policy(); - CHECK(policy == STREAM_POLICY + 1); - - policy = ht->get_frag_policy(); - CHECK(policy == FRAG_POLICY + 1); -} - - -// Test that HostTrackerModule statistics are correct. -TEST(host_tracker_module, host_tracker_module_test_stats) -{ - HostIpKey host_ip_key((const uint8_t*) expected_addr.get_ip6_ptr()); - std::shared_ptr ht; - - bool ret = host_cache.find(host_ip_key, ht); - CHECK(ret == true); - - HostApplicationEntry app; - ret = ht->find_service(1, 2112, app); - CHECK(ret == true); - CHECK(app.protocol == 3); - CHECK(app.ipproto == 1); - CHECK(app.port == 2112); - - ret = ht->remove_service(1, 2112); - CHECK(ret == true); - - // Verify counts are correct. The add was done during setup. - CHECK(host_tracker_stats.service_adds == 1); - CHECK(host_tracker_stats.service_finds == 1); - CHECK(host_tracker_stats.service_removes == 1); -} -#endif - int main(int argc, char** argv) { - // This is necessary to prevent the cpputest memory leak - // detection from thinking there's a memory leak in the map - // object contained within the global host_cache. The map - // must have some data allocated when it is first created - // that doesn't go away until the global map object is - // deallocated. This pre-allocates the map so that initial - // allocation is done prior to starting the tests. - HostTracker* ht = new HostTracker; - host_cache_add_host_tracker(ht); - host_cache.clear(); - + MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); return CommandLineTestRunner::RunAllTests(argc, argv); } diff --git a/src/host_tracker/test/host_tracker_test.cc b/src/host_tracker/test/host_tracker_test.cc index 5acd3b454..ae6eb4400 100644 --- a/src/host_tracker/test/host_tracker_test.cc +++ b/src/host_tracker/test/host_tracker_test.cc @@ -23,6 +23,8 @@ #include "config.h" #endif +#include + #include "host_tracker/host_tracker.h" #include @@ -42,159 +44,40 @@ TEST_GROUP(host_tracker) { }; -// Test HostTracker ipaddr get/set functions. -TEST(host_tracker, ipaddr_test) -{ - HostTracker ht; - SfIp zeroed_sfip; - SfIp expected_ip_addr; - SfIp actual_ip_addr; - - // Test IP prior to set. - memset(&zeroed_sfip, 0, sizeof(zeroed_sfip)); - actual_ip_addr = ht.get_ip_addr(); - CHECK(0 == memcmp(&zeroed_sfip, &actual_ip_addr, sizeof(zeroed_sfip))); - - expected_ip_addr.pton(AF_INET6, "beef:abcd:ef01:2300::"); - ht.set_ip_addr(expected_ip_addr); - actual_ip_addr = ht.get_ip_addr(); - CHECK(0 == memcmp(&expected_ip_addr, &actual_ip_addr, sizeof(expected_ip_addr))); -} - -// Test HostTracker policy get/set functions. -TEST(host_tracker, policy_test) -{ - HostTracker ht; - Policy expected_policy = 23; - Policy actual_policy; - - actual_policy = ht.get_stream_policy(); - CHECK(0 == actual_policy); - - actual_policy = ht.get_frag_policy(); - CHECK(0 == actual_policy); - - ht.set_stream_policy(expected_policy); - actual_policy = ht.get_stream_policy(); - CHECK(expected_policy == actual_policy); - - expected_policy = 77; - ht.set_frag_policy(expected_policy); - actual_policy = ht.get_frag_policy(); - CHECK(expected_policy == actual_policy); -} - -TEST(host_tracker, app_mapping_test) -{ - HostTracker ht; - const uint16_t expected_ports = 4123; - const uint16_t port1 = 4122; - AppId actual_appid; - Protocol expected_proto1 = 6; - Protocol expected_proto2 = 17; - AppId appid1 = 61; - AppId appid2 = 62; - bool ret; - - ht.add_app_mapping(expected_ports, expected_proto1, appid1); - - actual_appid = ht.find_app_mapping(expected_ports, expected_proto1); - CHECK(61 == actual_appid); - - actual_appid = ht.find_app_mapping(expected_ports, expected_proto2); - CHECK(APP_ID_NONE == actual_appid); - - actual_appid = ht.find_app_mapping(port1, expected_proto2); - CHECK(APP_ID_NONE == actual_appid); - - ret = ht.find_else_add_app_mapping(port1, expected_proto1, appid2); - CHECK(true == ret); - - ret = ht.find_else_add_app_mapping(port1, expected_proto1, appid2); - CHECK(false == ret); -} - -// Test HostTracker add and find service functions. +// Test HostTracker find appid and add service functions. TEST(host_tracker, add_find_service_test) { - bool ret; HostTracker ht; - HostApplicationEntry actual_entry; - HostApplicationEntry app_entry1(6, 2112, 3); - HostApplicationEntry app_entry2(17, 7777, 10); // Try a find on an empty list. - ret = ht.find_service(3,1000, actual_entry); - CHECK(false == ret); + CHECK(APP_ID_NONE == ht.get_appid(80, IpProtocol::TCP)); // Test add and find. - ret = ht.add_service(app_entry1); - CHECK(true == ret); - - ret = ht.find_service(6, 2112, actual_entry); - CHECK(true == ret); - CHECK(actual_entry.port == 2112); - CHECK(actual_entry.ipproto == 6); - CHECK(actual_entry.snort_protocol_id == 3); - - ht.add_service(app_entry2); - ret = ht.find_service(6, 2112, actual_entry); - CHECK(true == ret); - CHECK(actual_entry.port == 2112); - CHECK(actual_entry.ipproto == 6); - CHECK(actual_entry.snort_protocol_id == 3); - - ret = ht.find_service(17, 7777, actual_entry); - CHECK(true == ret); - CHECK(actual_entry.port == 7777); - CHECK(actual_entry.ipproto == 17); - CHECK(actual_entry.snort_protocol_id == 10); - - // Try adding an entry that exists already. - ret = ht.add_service(app_entry1); - CHECK(false == ret); - - // Try a find on a port that isn't in the list. - ret = ht.find_service(6, 100, actual_entry); - CHECK(false == ret); - - // Try a find on an ipproto that isn't in the list. - ret = ht.find_service(17, 2112, actual_entry); - CHECK(false == ret); - - // Try to remove an entry that's not in the list. - ret = ht.remove_service(6,100); - CHECK(false == ret); - - ret = ht.remove_service(17,2112); - CHECK(false == ret); - - // Actually remove an entry. - ret = ht.remove_service(6,2112); - CHECK(true == ret); + CHECK(true == ht.add_service(80, IpProtocol::TCP, 676, true)); + CHECK(true == ht.add_service(443, IpProtocol::TCP, 1122)); + CHECK(676 == ht.get_appid(80, IpProtocol::TCP)); + CHECK(1122 == ht.get_appid(443, IpProtocol::TCP)); + + // Try adding an entry that exists already and update appid + CHECK(true == ht.add_service(443, IpProtocol::TCP, 847)); + CHECK(847 == ht.get_appid(443, IpProtocol::TCP)); + + // Try a find appid on a port that isn't in the list. + CHECK(APP_ID_NONE == ht.get_appid(8080, IpProtocol::UDP)); } TEST(host_tracker, stringify) { - SfIp ip; - ip.pton(AF_INET6, "feed:dead:beef::"); - HostTracker ht(ip); - ht.add_app_mapping(80, 6, 676); - ht.add_app_mapping(443, 6, 1122); - ht.set_frag_policy(3); - HostApplicationEntry app_entry(6, 80, 10); - ht.add_service(app_entry); + HostTracker ht; + ht.add_service(80, IpProtocol::TCP, 676, true); + ht.add_service(443, IpProtocol::TCP, 1122); string host_tracker_string; ht.stringify(host_tracker_string); STRCMP_EQUAL(host_tracker_string.c_str(), - "IP: feed:dead:beef:0000:0000:0000:0000:0000\n" - "app_mappings size: 2\n" - " port: 80, proto: 6, appid: 676\n" - " port: 443, proto: 6, appid: 1122\n" - "stream policy: 0, frag policy: 3\n" - "services size: 1\n" - " port: 80, proto: 6, snort proto: 10"); + "\nservices size: 2" + "\n port: 80, proto: 6, appid: 676, inferred" + "\n port: 443, proto: 6, appid: 1122"); } int main(int argc, char** argv) diff --git a/src/network_inspectors/appid/appid_discovery.cc b/src/network_inspectors/appid/appid_discovery.cc index aae4e48f7..ced22d2f3 100644 --- a/src/network_inspectors/appid/appid_discovery.cc +++ b/src/network_inspectors/appid/appid_discovery.cc @@ -638,11 +638,15 @@ static void lookup_appid_by_host_port(AppIdSession& asd, Packet* p, IpProtocol p } else if (asd.config->mod_config->is_host_port_app_cache_runtime) { - AppId appid = snort::host_cache_find_app_mapping(ip, port, (Protocol)protocol); - if (appid > APP_ID_NONE) + auto ht = host_cache.find(*ip); + if ( ht ) { - asd.client.set_id(appid); - asd.client_disco_state = APPID_DISCO_STATE_FINISHED; + AppId appid = ht->get_appid(port, protocol, true); + if ( appid > APP_ID_NONE ) + { + asd.client.set_id(appid); + asd.client_disco_state = APPID_DISCO_STATE_FINISHED; + } } } } diff --git a/src/network_inspectors/appid/lua_detector_api.cc b/src/network_inspectors/appid/lua_detector_api.cc index ee1c869be..b45baa028 100644 --- a/src/network_inspectors/appid/lua_detector_api.cc +++ b/src/network_inspectors/appid/lua_detector_api.cc @@ -1181,15 +1181,15 @@ static int detector_add_host_port_dynamic(lua_State* L) } unsigned port = lua_tointeger(L, ++index); - unsigned proto = lua_tointeger(L, ++index); - if (proto > (unsigned)IpProtocol::RESERVED) + IpProtocol proto = (IpProtocol) lua_tointeger(L, ++index); + if (proto > IpProtocol::RESERVED) { - ErrorMessage("%s:Invalid protocol value %u\n",__func__, proto); + ErrorMessage("%s:Invalid protocol value %u\n",__func__, (unsigned) proto); return 0; } - if (!(snort::host_cache_add_app_mapping(ip_addr, port, proto, appid))) - ErrorMessage("%s:Failed to add app mapping\n",__func__); + if ( !host_cache[ip_addr]->add_service(port, proto, appid, true) ) + ErrorMessage("%s:Failed to add host tracker service\n",__func__); return 0; } diff --git a/src/network_inspectors/appid/test/appid_discovery_test.cc b/src/network_inspectors/appid/test/appid_discovery_test.cc index 8e843ab15..d893805d5 100644 --- a/src/network_inspectors/appid/test/appid_discovery_test.cc +++ b/src/network_inspectors/appid/test/appid_discovery_test.cc @@ -218,7 +218,13 @@ ServiceDiscovery& ServiceDiscovery::get_instance() s_discovery_manager = new ServiceDiscovery(); return *s_discovery_manager; } -AppId snort::host_cache_find_app_mapping(snort::SfIp const*, Port, Protocol){ return 0; } + +LruCacheShared host_cache(50); +AppId HostTracker::get_appid(Port, IpProtocol, bool) +{ + return APP_ID_NONE; +} + // Stubs for ClientDiscovery ClientDiscovery::ClientDiscovery(){} ClientDiscovery::~ClientDiscovery() {} diff --git a/src/network_inspectors/rna/rna_pnd.cc b/src/network_inspectors/rna/rna_pnd.cc index 3708456c8..10d004065 100644 --- a/src/network_inspectors/rna/rna_pnd.cc +++ b/src/network_inspectors/rna/rna_pnd.cc @@ -105,16 +105,16 @@ void RnaPnd::analyze_flow_udp(const Packet* p) void RnaPnd::discover_network_icmp(const Packet* p) { - if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(), - p->flow->client_port, SNORT_PROTO_ICMP) ) + if ( !(host_cache[p->flow->client_ip]-> + add_service(p->flow->client_port, p->get_ip_proto_next())) ) return; // process rna discovery for icmp } void RnaPnd::discover_network_ip(const Packet* p) { - if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(), - p->flow->client_port, SNORT_PROTO_IP) ) + if ( !(host_cache[p->flow->client_ip]-> + add_service(p->flow->client_port, p->get_ip_proto_next())) ) return; // process rna discovery for ip } @@ -128,8 +128,8 @@ void RnaPnd::discover_network_non_ip(const Packet* p) void RnaPnd::discover_network_tcp(const Packet* p) { // Track from initiator direction, if not already seen - if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(), - p->flow->client_port, SNORT_PROTO_TCP) ) + if ( !(host_cache[p->flow->client_ip]-> + add_service(p->flow->client_port, p->get_ip_proto_next())) ) return; // Add mac address to ht list, ttl, last_seen, etc. @@ -138,8 +138,8 @@ void RnaPnd::discover_network_tcp(const Packet* p) void RnaPnd::discover_network_udp(const Packet* p) { - if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(), - p->flow->client_port, SNORT_PROTO_UDP) ) + if ( !(host_cache[p->flow->client_ip]-> + add_service(p->flow->client_port, p->get_ip_proto_next())) ) return; // process rna discovery for udp } diff --git a/src/service_inspectors/wizard/wizard.cc b/src/service_inspectors/wizard/wizard.cc index c9502cedd..768adfb74 100644 --- a/src/service_inspectors/wizard/wizard.cc +++ b/src/service_inspectors/wizard/wizard.cc @@ -263,11 +263,9 @@ bool Wizard::spellbind( { f->service = m->book.find_spell(data, len, m); - if (f->service != nullptr) + if ( f->service != nullptr ) { - // FIXIT-H need to make sure Flow's ipproto and service - // correspond to HostApplicationEntry's ipproto and service - host_cache_add_service(f->server_ip, f->ip_proto, f->server_port, f->service); + host_cache[f->server_ip]->add_service(f->server_port, (IpProtocol)f->ip_proto); return true; } @@ -282,10 +280,11 @@ bool Wizard::cursebind(vector& curse_tracker, Flow* f, if (cst.curse->alg(data, len, cst.tracker)) { f->service = cst.curse->service.c_str(); - // FIXIT-H need to make sure Flow's ipproto and service - // correspond to HostApplicationEntry's ipproto and service - host_cache_add_service(f->server_ip, f->ip_proto, f->server_port, f->service); - return true; + if ( f->service != nullptr ) + { + host_cache[f->server_ip]->add_service(f->server_port, (IpProtocol)f->ip_proto); + return true; + } } } diff --git a/src/sfip/sf_ip.h b/src/sfip/sf_ip.h index b0fde364b..6eef4fa21 100644 --- a/src/sfip/sf_ip.h +++ b/src/sfip/sf_ip.h @@ -81,6 +81,7 @@ struct SO_PUBLIC SfIp bool fast_gt6(const SfIp& ip2) const; bool fast_eq6(const SfIp& ip2) const; bool fast_equals_raw(const SfIp& ip2) const; + bool operator==(const SfIp& ip2) const; /* * Miscellaneous @@ -463,6 +464,11 @@ inline bool SfIp::fast_equals_raw(const SfIp& ip2) const return false; } +inline bool SfIp::operator==(const SfIp& ip2) const +{ + return fast_equals_raw(ip2); +} + /* End of member function definitions */ SO_PUBLIC const char* sfip_ntop(const SfIp* ip, char* buf, int bufsize);