const PegInfo lru_cache_shared_peg_names[] =
{
{ CountType::SUM, "lru_cache_adds", "lru cache added new entry" },
- { CountType::SUM, "lru_cache_replaces", "lru cache replaced existing entry" },
{ CountType::SUM, "lru_cache_prunes", "lru cache pruned entry to make space for new entry" },
{ CountType::SUM, "lru_cache_find_hits", "lru cache found entry in cache" },
{ CountType::SUM, "lru_cache_find_misses", "lru cache did not find entry in cache" },
- { CountType::SUM, "lru_cache_removes", "lru cache found entry and removed it" },
- { CountType::SUM, "lru_cache_clears", "lru cache clear API calls" },
{ CountType::END, nullptr, nullptr },
};
// least-recently-used (LRU) entries are removed once a fixed size is hit.
#include <list>
-#include <vector>
-#include <unordered_map>
+#include <memory>
#include <mutex>
+#include <typeinfo>
+#include <unordered_map>
+#include <vector>
#include "framework/counts.h"
struct LruCacheSharedStats
{
PegCount adds = 0; // An insert that added new entry.
- PegCount replaces = 0; // An insert that replaced existing entry
PegCount prunes = 0; // When an old entry is removed to make
// room for a new entry.
PegCount find_hits = 0; // Found entry in cache.
PegCount find_misses = 0; // Did not find entry in cache.
- PegCount removes = 0; // Found entry and removed it.
- PegCount clears = 0; // Calls to clear API.
};
-template<typename Key, typename Data, typename Hash>
+template<typename Key, typename Value, typename Hash>
class LruCacheShared
{
public:
LruCacheShared& operator=(const LruCacheShared& arg) = delete;
LruCacheShared(const size_t initial_size) :
- max_size(initial_size),
- current_size(0)
- {
- }
+ max_size(initial_size), current_size(0) { }
- // Get current number of elements in the LruCache.
- size_t size()
- {
- std::lock_guard<std::mutex> cache_lock(cache_mutex);
- return current_size;
- }
+ using Data = std::shared_ptr<Value>;
- size_t get_max_size()
- {
- std::lock_guard<std::mutex> cache_lock(cache_mutex);
- return max_size;
- }
+ // Return data entry associated with key. If doesn't exist, return nullptr.
+ Data find(const Key& key);
+
+ // Return data entry associated with key. If doesn't exist, create a new entry.
+ Data operator[](const Key& key);
+
+ // Return all data from the LruCache in order (most recently used to least)
+ std::vector<std::pair<Key, Data> > get_all_data();
// Modify the maximum number of entries allowed in the cache.
// If the size is reduced, the oldest entries are removed.
bool set_max_size(size_t newsize);
- // Add data to cache or replace data if it already exists.
- void insert(const Key& key, const Data& data);
-
- // Find Data associated with Key. If update is true, mark entry as
- // recently used.
- // Returns true and copies data if the key is found.
- bool find(const Key& key, Data& data, bool update=true);
-
- // Remove entry associated with Key.
- // Returns true if entry existed, false otherwise.
- bool remove(const Key& key);
-
- // Remove entry associated with key and return removed data.
- // Returns true and copy of data if entry existed. Returns false if
- // entry did not exist.
- bool remove(const Key& key, Data& data);
-
- // Remove all elements from the LruCache
- void clear();
-
- // Return all data from the LruCache in order (most recently used to
- // least).
- std::vector<std::pair<Key, Data> > get_all_data();
-
const PegInfo* get_pegs() const
{
return lru_cache_shared_peg_names;
}
- PegCount* get_counts() const
+ PegCount* get_counts()
{
return (PegCount*)&stats;
}
struct LruCacheSharedStats stats;
};
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::set_max_size(size_t newsize)
+template<typename Key, typename Value, typename Hash>
+bool LruCacheShared<Key, Value, Hash>::set_max_size(size_t newsize)
{
LruListIter list_iter;
current_size--;
map.erase(list_iter->first);
list.erase(list_iter);
+ stats.prunes++;
}
max_size = newsize;
return true;
}
-template<typename Key, typename Data, typename Hash>
-void LruCacheShared<Key, Data, Hash>::insert(const Key& key, const Data& data)
+template<typename Key, typename Value, typename Hash>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash>::find(const Key& key)
{
LruMapIter map_iter;
std::lock_guard<std::mutex> cache_lock(cache_mutex);
- // If key already exists, remove it.
map_iter = map.find(key);
- if (map_iter != map.end())
+ if (map_iter == map.end())
{
- current_size--;
- list.erase(map_iter->second);
- map.erase(map_iter);
- stats.replaces++;
+ stats.find_misses++;
+ return nullptr;
}
- else
+
+ // Move entry to front of LruList
+ list.splice(list.begin(), list, map_iter->second);
+ stats.find_hits++;
+ return map_iter->second->second;
+}
+
+template<typename Key, typename Value, typename Hash>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash>::operator[](const Key& key)
+{
+ LruMapIter map_iter;
+ std::lock_guard<std::mutex> cache_lock(cache_mutex);
+
+ map_iter = map.find(key);
+ if (map_iter != map.end())
{
- stats.adds++;
+ stats.find_hits++;
+ list.splice(list.begin(), list, map_iter->second); // update LRU
+ return map_iter->second->second;
}
+ stats.find_misses++;
+ stats.adds++;
+ Data data = Data(new Value);
+
// Add key/data pair to front of list.
list.emplace_front(std::make_pair(key, data));
{
current_size++;
}
+ return data;
}
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::find(const Key& key, Data& data, bool update)
-{
- LruMapIter map_iter;
- std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
- map_iter = map.find(key);
- if (map_iter == map.end())
- {
- stats.find_misses++;
- return false; // Key is not in LruCache.
- }
-
- data = map_iter->second->second;
-
- // If needed, move entry to front of LruList
- if (update)
- list.splice(list.begin(), list, map_iter->second);
-
- stats.find_hits++;
- return true;
-}
-
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::remove(const Key& key)
-{
- LruMapIter map_iter;
- std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
- map_iter = map.find(key);
- if (map_iter == map.end())
- return false; // Key is not in LruCache.
-
- current_size--;
- list.erase(map_iter->second);
- map.erase(map_iter);
- stats.removes++;
- return(true);
-}
-
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::remove(const Key& key, Data& data)
-{
- LruMapIter map_iter;
- std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
- map_iter = map.find(key);
- if (map_iter == map.end())
- return false; // Key is not in LruCache.
-
- data = map_iter->second->second;
-
- current_size--;
- list.erase(map_iter->second);
- map.erase(map_iter);
- stats.removes++;
- return(true);
-}
-
-template<typename Key, typename Data, typename Hash>
-void LruCacheShared<Key, Data, Hash>::clear()
-{
- LruMapIter map_iter;
- std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
- for (map_iter = map.begin(); map_iter != map.end(); /* No incr */)
- {
- list.erase(map_iter->second);
-
- // erase returns next iterator after erased element.
- map_iter = map.erase(map_iter);
- }
-
- current_size = 0;
- stats.clears++;
-}
-
-template<typename Key, typename Data, typename Hash>
-std::vector<std::pair<Key, Data> > LruCacheShared<Key, Data, Hash>::get_all_data()
+template<typename Key, typename Value, typename Hash>
+std::vector< std::pair<Key, std::shared_ptr<Value>> >
+LruCacheShared<Key, Value, Hash>::get_all_data()
{
std::vector<std::pair<Key, Data> > vec;
std::lock_guard<std::mutex> cache_lock(cache_mutex);
{
};
-// Test LruCacheShared constructor and member access.
-TEST(lru_cache_shared, constructor_test)
-{
- LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
-
- CHECK(lru_cache.get_max_size() == 5);
- CHECK(lru_cache.size() == 0);
-}
-
-// Test LruCacheShared insert, find and get_all_data functions.
+// Test LruCacheShared find, operator[], and get_all_data functions.
TEST(lru_cache_shared, insert_test)
{
- std::string data;
- LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
+ LruCacheShared<int, std::string, std::hash<int> > lru_cache(3);
- lru_cache.insert(0, "zero");
- CHECK(true == lru_cache.find(0, data));
- CHECK("zero" == data);
-
- lru_cache.insert(1, "one");
- CHECK(true == lru_cache.find(1, data));
- CHECK("one" == data);
-
- lru_cache.insert(2, "two");
- CHECK(true == lru_cache.find(2, data));
- CHECK("two" == data);
-
- // Verify find fails for non-existent item.
- CHECK(false == lru_cache.find(3, data));
-
- // Verify that insert will replace data if key exists already.
- lru_cache.insert(1, "new one");
- CHECK(true == lru_cache.find(1, data));
- CHECK("new one" == data);
-
- // Verify current number of entries in cache.
- CHECK(3 == lru_cache.size());
-
- // Verify that the data is in LRU order.
- auto vec = lru_cache.get_all_data();
- CHECK(3 == vec.size());
- CHECK((vec[0] == std::make_pair(1, std::string("new one"))));
- CHECK((vec[1] == std::make_pair(2, std::string("two"))));
- CHECK((vec[2] == std::make_pair(0, std::string("zero"))));
-}
+ auto data = lru_cache[0];
+ CHECK(data == lru_cache.find(0));
+ data->assign("zero");
-// Test that the least recently used items are removed when we exceed
-// the capacity of the LruCache.
-TEST(lru_cache_shared, lru_removal_test)
-{
- LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
+ data = lru_cache[1];
+ CHECK(data == lru_cache.find(1));
+ data->assign("one");
- for (int i = 0; i < 10; i++)
- {
- lru_cache.insert(i, std::to_string(i));
- }
-
- CHECK(5 == lru_cache.size());
-
- // Verify that the data is in LRU order and is correct.
- auto vec = lru_cache.get_all_data();
- CHECK(5 == vec.size());
- CHECK((vec[0] == std::make_pair(9, std::string("9"))));
- CHECK((vec[1] == std::make_pair(8, std::string("8"))));
- CHECK((vec[2] == std::make_pair(7, std::string("7"))));
- CHECK((vec[3] == std::make_pair(6, std::string("6"))));
- CHECK((vec[4] == std::make_pair(5, std::string("5"))));
-}
+ data = lru_cache[2];
+ CHECK(data == lru_cache.find(2));
+ data->assign("two");
-// Test the remove and clear functions.
-TEST(lru_cache_shared, remove_test)
-{
- std::string data;
- LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
+ // Replace existing
+ data = lru_cache[0];
+ data->assign("new_zero");
- for (int i = 0; i < 5; i++)
- {
- lru_cache.insert(i, std::to_string(i));
- CHECK(true == lru_cache.find(i, data));
- CHECK(data == std::to_string(i));
-
- CHECK(true == lru_cache.remove(i));
- CHECK(false == lru_cache.find(i, data));
- }
-
- CHECK(0 == lru_cache.size());
-
- // Test remove API that returns the removed data.
- lru_cache.insert(1, "one");
- CHECK(1 == lru_cache.size());
- CHECK(true == lru_cache.remove(1, data));
- CHECK(data == "one");
- CHECK(0 == lru_cache.size());
-
- lru_cache.insert(1, "one");
- lru_cache.insert(2, "two");
- CHECK(2 == lru_cache.size());
-
- // Verify that removing an item that does not exist does not affect
- // cache.
- CHECK(false == lru_cache.remove(3));
- CHECK(false == lru_cache.remove(4, data));
- CHECK(2 == lru_cache.size());
-
- auto vec = lru_cache.get_all_data();
- CHECK(2 == vec.size());
- CHECK((vec[0] == std::make_pair(2, std::string("two"))));
- CHECK((vec[1] == std::make_pair(1, std::string("one"))));
-
- // Verify that clear() removes all entries.
- lru_cache.clear();
- CHECK(0 == lru_cache.size());
-
- vec = lru_cache.get_all_data();
- CHECK(vec.empty());
+ // Verify find fails for non-existent item.
+ CHECK(nullptr == lru_cache.find(3));
+
+ // Replace least recently used when capacity exceeds
+ data = lru_cache[3];
+ CHECK(data == lru_cache.find(3));
+ data->assign("three");
+
+ // Verify current number of entries in cache and LRU order.
+ const auto&& vec = lru_cache.get_all_data();
+ CHECK(vec.size() == 3);
+ CHECK(vec[0].second->compare("three") == 0);
+ CHECK(vec[1].second->compare("new_zero") == 0);
+ CHECK(vec[2].second->compare("two") == 0);
}
-// Test statistics counters.
+// Test statistics counters and set_max_size.
TEST(lru_cache_shared, stats_test)
{
- std::string data;
LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
for (int i = 0; i < 10; i++)
- {
- lru_cache.insert(i, std::to_string(i));
- }
-
- lru_cache.insert(8, "new-eight"); // Replace entries.
- lru_cache.insert(9, "new-nine");
-
- CHECK(5 == lru_cache.size());
-
- lru_cache.find(7, data); // Hits
- lru_cache.find(8, data);
- lru_cache.find(9, data);
-
- lru_cache.remove(7);
- lru_cache.remove(8);
- lru_cache.remove(9, data);
- CHECK("new-nine" == data);
+ lru_cache[i];
- lru_cache.find(8, data); // Misses now that they're removed.
- lru_cache.find(9, data);
+ lru_cache.find(7); // Hits
+ lru_cache.find(8);
+ lru_cache.find(9);
- lru_cache.remove(100); // Removing a non-existent entry does not
- // increase remove count.
+ lru_cache.find(10); // Misses; in addition to previous 10
+ lru_cache.find(11);
- lru_cache.clear();
+ CHECK(lru_cache.set_max_size(3) == true); // change size prunes; in addition to previous 5
PegCount* stats = lru_cache.get_counts();
CHECK(stats[0] == 10); // adds
- CHECK(stats[1] == 2); // replaces
- CHECK(stats[2] == 5); // prunes
- CHECK(stats[3] == 3); // find hits
- CHECK(stats[4] == 2); // find misses
- CHECK(stats[5] == 3); // removes
- CHECK(stats[6] == 1); // clears
+ CHECK(stats[1] == 7); // prunes
+ CHECK(stats[2] == 3); // find hits
+ CHECK(stats[3] == 12); // find misses
// Check statistics names.
const PegInfo* pegs = lru_cache.get_pegs();
CHECK(!strcmp(pegs[0].name, "lru_cache_adds"));
- CHECK(!strcmp(pegs[1].name, "lru_cache_replaces"));
- CHECK(!strcmp(pegs[2].name, "lru_cache_prunes"));
- CHECK(!strcmp(pegs[3].name, "lru_cache_find_hits"));
- CHECK(!strcmp(pegs[4].name, "lru_cache_find_misses"));
- CHECK(!strcmp(pegs[5].name, "lru_cache_removes"));
- CHECK(!strcmp(pegs[6].name, "lru_cache_clears"));
+ CHECK(!strcmp(pegs[1].name, "lru_cache_prunes"));
+ CHECK(!strcmp(pegs[2].name, "lru_cache_find_hits"));
+ CHECK(!strcmp(pegs[3].name, "lru_cache_find_misses"));
}
int main(int argc, char** argv)
#include "host_cache.h"
-#include "main/snort_config.h"
-#include "target_based/snort_protocols.h"
-
using namespace snort;
-using namespace std;
#define LRU_CACHE_INITIAL_SIZE 65535
-LruCacheShared<HostIpKey, std::shared_ptr<HostTracker>, HashHostIpKey>
- host_cache(LRU_CACHE_INITIAL_SIZE);
-
-namespace snort
-{
-
-void host_cache_add_host_tracker(HostTracker* ht)
-{
- std::shared_ptr<HostTracker> sptr(ht);
- host_cache.insert((const uint8_t*) ht->get_ip_addr().get_ip6_ptr(), sptr);
-}
-
-bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, SnortProtocolId id)
-{
- HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr());
- std::shared_ptr<HostTracker> ht;
-
- if (!host_cache.find(ipkey, ht))
- {
- // This host hasn't been seen. Add it.
- ht = std::make_shared<HostTracker>(ipaddr);
-
- if (ht == nullptr)
- {
- // FIXIT-L add error count
- return false;
- }
- host_cache.insert(ipkey, ht);
- }
-
- HostApplicationEntry app_entry(ipproto, port, id);
- return ht->add_service(app_entry);
-}
-
-bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, const char* service)
-{
- return host_cache_add_service(ipaddr, ipproto, port,
- SnortConfig::get_conf()->proto_ref->find(service));
-}
-
-bool host_cache_add_app_mapping(const SfIp& ipaddr, Port port, Protocol proto, AppId appId)
-{
- HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr());
- std::shared_ptr<HostTracker> ht;
-
- if (!host_cache.find(ipkey, ht))
- {
- ht = std::make_shared<HostTracker> (ipaddr);
-
- if (ht == nullptr)
- {
- return false;
- }
- ht->add_app_mapping(port, proto, appId);
- host_cache.insert(ipkey, ht);
- }
- else
- {
- return ht->find_else_add_app_mapping(port, proto, appId);
- }
-
- return true;
-}
-
-AppId host_cache_find_app_mapping(const SfIp* ipaddr, Port port, Protocol proto)
-{
- HostIpKey ipkey((const uint8_t*) ipaddr->get_ip6_ptr());
- std::shared_ptr<HostTracker> ht;
-
- if (host_cache.find(ipkey, ht))
- {
- return ht->find_app_mapping(port, proto);
- }
-
- return APP_ID_NONE;
-}
-}
+LruCacheShared<SfIp, HostTracker, HashIp> host_cache(LRU_CACHE_INITIAL_SIZE);
// The host cache is used to cache information about hosts so that it can
// be shared among threads.
-#include <memory>
-
#include "hash/lru_cache_shared.h"
#include "host_tracker/host_tracker.h"
-
-#define HOST_IP_KEY_SIZE 16
-
-struct HostIpKey
-{
- union host_ip_addr
- {
- uint8_t ip8[HOST_IP_KEY_SIZE];
- uint64_t ip64[HOST_IP_KEY_SIZE/8];
- } ip_addr = {{0}}; // Holds either IPv4 or IPv6 addr
-
- HostIpKey() = default;
-
- HostIpKey(const uint8_t ip[HOST_IP_KEY_SIZE])
- {
- memcpy(&ip_addr, ip, HOST_IP_KEY_SIZE);
- }
-
- inline bool operator==(const HostIpKey& rhs) const
- {
- return !memcmp(&ip_addr, &rhs.ip_addr, HOST_IP_KEY_SIZE);
- }
-};
+#include "sfip/sf_ip.h"
// Used to create hash of key for indexing into cache.
-struct HashHostIpKey
+struct HashIp
{
- size_t operator()(const HostIpKey& ip) const
+ size_t operator()(const snort::SfIp& ip) const
{
- return std::hash<uint64_t>() (ip.ip_addr.ip64[0]) ^
- std::hash<uint64_t>() (ip.ip_addr.ip64[1]);
+ const uint64_t* ip64 = (const uint64_t*) ip.get_ip6_ptr();
+ return std::hash<uint64_t>() (ip64[0]) ^
+ std::hash<uint64_t>() (ip64[1]);
}
};
-extern LruCacheShared<HostIpKey, std::shared_ptr<HostTracker>, HashHostIpKey> host_cache;
-
-namespace snort
-{
-void host_cache_add_host_tracker(HostTracker*);
-
-// Insert a new service into host cache if it doesn't already exist.
-SO_PUBLIC bool host_cache_add_service(const SfIp&, Protocol, Port, SnortProtocolId);
-SO_PUBLIC bool host_cache_add_service(const SfIp&, Protocol, Port, const char*);
+extern SO_PUBLIC LruCacheShared<snort::SfIp, HostTracker, HashIp> host_cache;
-bool host_cache_add_app_mapping(const SfIp&, Port, Protocol, AppId);
-AppId host_cache_find_app_mapping(const SfIp* , Port, Protocol );
-}
#endif
log_host_cache(dump_file);
snort_free((void*)dump_file);
}
- host_cache.clear();
}
void HostCacheModule::log_host_cache(const char* file_name, bool verbose)
}
string str;
+ SfIpString ip_str;
const auto&& lru_data = host_cache.get_all_data();
for ( const auto& elem : lru_data )
{
+ str = "IP: ";
+ str += elem.first.ntop(ip_str);
elem.second->stringify(str);
out_stream << str << endl << endl;
}
THREAD_LOCAL struct HostTrackerStats host_tracker_stats;
-snort::SfIp HostTracker::get_ip_addr()
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- return ip_addr;
-}
-
-void HostTracker::set_ip_addr(const snort::SfIp& new_ip_addr)
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- std::memcpy(&ip_addr, &new_ip_addr, sizeof(ip_addr));
-}
-
-Policy HostTracker::get_stream_policy()
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- return stream_policy;
-}
-
-void HostTracker::set_stream_policy(const Policy& policy)
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- stream_policy = policy;
-}
-
-Policy HostTracker::get_frag_policy()
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- return frag_policy;
-}
-
-void HostTracker::set_frag_policy(const Policy& policy)
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- frag_policy = policy;
-}
-
-void HostTracker::add_app_mapping(Port port, Protocol proto, AppId appid)
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- AppMapping app_map = {port, proto, appid};
-
- app_mappings.emplace_back(app_map);
-}
-
-AppId HostTracker::find_app_mapping(Port port, Protocol proto)
+bool HostTracker::add_service(Port port, IpProtocol proto, AppId appid, bool inferred_appid)
{
+ host_tracker_stats.service_adds++;
std::lock_guard<std::mutex> lck(host_tracker_lock);
- for (std::vector<AppMapping>::iterator it=app_mappings.begin(); it!=app_mappings.end(); ++it)
- {
- if (it->port == port and it->proto ==proto)
- {
- return it->appid;
- }
- }
- return APP_ID_NONE;
-}
-bool HostTracker::find_else_add_app_mapping(Port port, Protocol proto, AppId appid)
-{
- std::lock_guard<std::mutex> lck(host_tracker_lock);
- for (std::vector<AppMapping>::iterator it=app_mappings.begin(); it!=app_mappings.end(); ++it)
+ for ( auto& s : services )
{
- if (it->port == port and it->proto ==proto)
+ if ( s.port == port and s.proto == proto )
{
- return false;
+ if ( s.appid != appid and appid != APP_ID_NONE )
+ {
+ s.appid = appid;
+ s.inferred_appid = inferred_appid;
+ }
+ return true;
}
}
- AppMapping app_map = {port, proto, appid};
-
- app_mappings.emplace_back(app_map);
- return true;
-}
-
-bool HostTracker::add_service(const HostApplicationEntry& app_entry)
-{
- host_tracker_stats.service_adds++;
-
- std::lock_guard<std::mutex> lck(host_tracker_lock);
-
- auto iter = std::find(services.begin(), services.end(), app_entry);
- if (iter != services.end())
- return false; // Already exists.
- services.emplace_front(app_entry);
+ services.emplace_back( HostApplication{port, proto, appid, inferred_appid} );
return true;
}
-void HostTracker::add_or_replace_service(const HostApplicationEntry& app_entry)
+AppId HostTracker::get_appid(Port port, IpProtocol proto, bool inferred_only)
{
- host_tracker_stats.service_adds++;
-
- std::lock_guard<std::mutex> lck(host_tracker_lock);
-
- auto iter = std::find(services.begin(), services.end(), app_entry);
- if (iter != services.end())
- services.erase(iter);
-
- services.emplace_front(app_entry);
-}
-
-bool HostTracker::find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry)
-{
- HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID);
host_tracker_stats.service_finds++;
-
std::lock_guard<std::mutex> lck(host_tracker_lock);
- auto iter = std::find(services.begin(), services.end(), tmp_entry);
- if (iter != services.end())
+ for ( const auto& s : services )
{
- app_entry = *iter;
- return true;
+ if ( s.port == port and s.proto == proto and
+ (!inferred_only or s.inferred_appid == inferred_only) )
+ return s.appid;
}
- return false;
-}
-
-bool HostTracker::remove_service(Protocol ipproto, Port port)
-{
- HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID);
- host_tracker_stats.service_removes++;
-
- std::lock_guard<std::mutex> lck(host_tracker_lock);
-
- auto iter = std::find(services.begin(), services.end(), tmp_entry);
- if (iter != services.end())
- {
- services.erase(iter);
- return true; // Assumes only one matching entry.
- }
-
- return false;
+ return APP_ID_NONE;
}
void HostTracker::stringify(string& str)
{
- str = "IP: ";
- SfIpString ip_str;
- str += ip_addr.ntop(ip_str);
-
- if ( !app_mappings.empty() )
- {
- str += "\napp_mappings size: " + to_string(app_mappings.size());
- for ( const auto& elem : app_mappings )
- str += "\n port: " + to_string(elem.port)
- + ", proto: " + to_string(elem.proto)
- + ", appid: " + to_string(elem.appid);
- }
-
- if ( stream_policy or frag_policy )
- str += "\nstream policy: " + to_string(stream_policy)
- + ", frag policy: " + to_string(frag_policy);
-
if ( !services.empty() )
{
str += "\nservices size: " + to_string(services.size());
- for ( const auto& elem : services )
- str += "\n port: " + to_string(elem.port)
- + ", proto: " + to_string(elem.ipproto)
- + ", snort proto: " + to_string(elem.snort_protocol_id);
- }
+ for ( const auto& s : services )
+ {
+ str += "\n port: " + to_string(s.port)
+ + ", proto: " + to_string((uint8_t) s.proto);
+ if ( s.appid != APP_ID_NONE )
+ {
+ str += ", appid: " + to_string(s.appid);
+ if ( s.inferred_appid )
+ str += ", inferred";
+ }
+ }
+ }
}
// configuration or dynamic discovery). It provides a thread-safe API to
// set/get the host data.
-#include <algorithm>
-#include <cstring>
-#include <list>
#include <mutex>
+#include <vector>
+
#include "framework/counts.h"
#include "main/snort_types.h"
#include "main/thread.h"
#include "network_inspectors/appid/application_ids.h"
#include "protocols/protocol_ids.h"
-#include "sfip/sf_ip.h"
-#include "target_based/snort_protocols.h"
-
-// FIXIT-M For now this emulates the Snort++ attribute table.
-// Need to add in host_tracker.h data eventually.
-
-typedef uint16_t Protocol;
-typedef uint8_t Policy;
struct HostTrackerStats
{
PegCount service_adds;
PegCount service_finds;
- PegCount service_removes;
};
extern THREAD_LOCAL struct HostTrackerStats host_tracker_stats;
-struct HostApplicationEntry
-{
- Port port = 0;
- Protocol ipproto = 0;
- SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID;
-
- HostApplicationEntry() = default;
-
- HostApplicationEntry(Protocol ipproto_param, Port port_param, SnortProtocolId protocol_param) :
- port(port_param),
- ipproto(ipproto_param),
- snort_protocol_id(protocol_param)
- {
- }
-
- inline bool operator==(const HostApplicationEntry& rhs) const
- {
- return ipproto == rhs.ipproto and port == rhs.port;
- }
-};
-
-struct AppMapping
+struct HostApplication
{
Port port;
- Protocol proto;
+ IpProtocol proto;
AppId appid;
+ bool inferred_appid;
};
-class HostTracker
+class SO_PUBLIC HostTracker
{
public:
- HostTracker()
- { memset(&ip_addr, 0, sizeof(ip_addr)); }
-
- HostTracker(const snort::SfIp& new_ip_addr)
- { std::memcpy(&ip_addr, &new_ip_addr, sizeof(ip_addr)); }
-
- snort::SfIp get_ip_addr();
- void set_ip_addr(const snort::SfIp& new_ip_addr);
- Policy get_stream_policy();
- void set_stream_policy(const Policy& policy);
- Policy get_frag_policy();
- void set_frag_policy(const Policy& policy);
- void add_app_mapping(Port port, Protocol proto, AppId appid);
- AppId find_app_mapping(Port port, Protocol proto);
- bool find_else_add_app_mapping(Port port, Protocol proto, AppId appid);
-
- // Add host service data only if it doesn't already exist. Returns
- // false if entry exists already, and true if entry was added.
- bool add_service(const HostApplicationEntry& app_entry);
+ // Appid may not be identified always. Inferred means dynamic/runtime
+ // appid detected from one flow to another flow such as BitTorrent.
+ bool add_service(Port port, IpProtocol proto,
+ AppId appid = APP_ID_NONE, bool inferred_appid = false);
- // Add host service data if it doesn't already exist. If it does exist
- // replace the previous entry with the new entry.
- void add_or_replace_service(const HostApplicationEntry& app_entry);
-
- // Returns true and fills in copy of HostApplicationEntry when found.
- // Returns false when not found.
- bool find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry);
-
- // Removes HostApplicationEntry object associated with ipproto and port.
- // Returns true if entry existed. False otherwise.
- bool remove_service(Protocol ipproto, Port port);
+ AppId get_appid(Port port, IpProtocol proto, bool inferred_only = false);
// This should be updated whenever HostTracker data members are changed
void stringify(std::string& str);
// Ensure that updates to a shared object are safe
std::mutex host_tracker_lock;
- // FIXIT-M do we need to use a host_id instead of SfIp as in sfrna?
- snort::SfIp ip_addr;
- std::vector< AppMapping > app_mappings;
-
- // Policies to apply to this host.
- Policy stream_policy = 0;
- Policy frag_policy = 0;
-
- std::list<HostApplicationEntry> services;
+ std::vector<HostApplication> services;
};
#endif
#include "host_tracker_module.h"
+#include "log/messages.h"
#include "main/snort_config.h"
-#include "stream/stream.h"
-#include "target_based/snort_protocols.h"
-
-#include "host_cache.h"
using namespace snort;
{
{ CountType::SUM, "service_adds", "host service adds" },
{ CountType::SUM, "service_finds", "host service finds" },
- { CountType::SUM, "service_removes", "host service removes" },
{ CountType::END, nullptr, nullptr },
};
const Parameter HostTrackerModule::service_params[] =
{
- { "name", Parameter::PT_STRING, nullptr, nullptr,
- "service identifier" },
-
- { "proto", Parameter::PT_ENUM, "tcp | udp", "tcp",
- "IP protocol" },
+ { "port", Parameter::PT_PORT, nullptr, nullptr, "port number" },
- { "port", Parameter::PT_PORT, nullptr, nullptr,
- "port number" },
+ { "proto", Parameter::PT_ENUM, "ip | tcp | udp", nullptr, "IP protocol" },
{ nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
};
const Parameter HostTrackerModule::host_tracker_params[] =
{
- { "ip", Parameter::PT_ADDR, nullptr, "0.0.0.0/32",
- "hosts address / cidr" },
-
- { "frag_policy", Parameter::PT_ENUM, IP_POLICIES, nullptr,
- "defragmentation policy" },
-
- { "tcp_policy", Parameter::PT_ENUM, TCP_POLICIES, nullptr,
- "TCP reassembly policy" },
+ { "ip", Parameter::PT_ADDR, nullptr, nullptr, "hosts address / cidr" },
{ "services", Parameter::PT_LIST, HostTrackerModule::service_params, nullptr,
"list of service parameters" },
{ nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
};
-bool HostTrackerModule::set(const char*, Value& v, SnortConfig* sc)
+bool HostTrackerModule::set(const char*, Value& v, SnortConfig*)
{
- if ( host and v.is("ip") )
- {
- SfIp addr;
+ if ( v.is("ip") )
v.get_addr(addr);
- host->set_ip_addr(addr);
- }
- else if ( host and v.is("frag_policy") )
- host->set_frag_policy(v.get_uint8() + 1);
-
- else if ( host and v.is("tcp_policy") )
- host->set_stream_policy(v.get_uint8() + 1);
-
- else if ( v.is("name") )
- app.snort_protocol_id = sc->proto_ref->add(v.get_string());
-
- else if ( v.is("proto") )
- app.ipproto = sc->proto_ref->add(v.get_string());
else if ( v.is("port") )
app.port = v.get_uint16();
+ else if ( v.is("proto") )
+ {
+ const IpProtocol mask[] =
+ { IpProtocol::IP, IpProtocol::TCP, IpProtocol::UDP };
+ app.proto = mask[v.get_uint8()];
+ }
+
else
return false;
bool HostTrackerModule::begin(const char* fqn, int idx, SnortConfig*)
{
if ( idx && !strcmp(fqn, "host_tracker") )
- host = new HostTracker;
-
+ {
+ addr.clear();
+ app = {};
+ }
return true;
}
{
if ( idx && !strcmp(fqn, "host_tracker.services") )
{
- host->add_service(app);
+ if ( addr.is_set() )
+ host_cache[addr]->add_service(app.port, app.proto);
app = {};
}
- else if ( idx && !strcmp(fqn, "host_tracker") )
+ else if ( idx && !strcmp(fqn, "host_tracker") && addr.is_set() )
{
- host_cache_add_host_tracker(host);
- host = nullptr; // Host cache is now responsible for freeing host
+ host_cache[addr];
+ addr.clear();
}
return true;
#include <cassert>
#include "framework/module.h"
-#include "host_tracker/host_tracker.h"
+#include "host_tracker/host_cache.h"
#define host_tracker_help \
"configure hosts"
class HostTrackerModule : public snort::Module
{
public:
- HostTrackerModule() : snort::Module("host_tracker", host_tracker_help, host_tracker_params, true)
- { host = nullptr; }
-
- ~HostTrackerModule() override
- { assert(!host); }
+ HostTrackerModule() :
+ snort::Module("host_tracker", host_tracker_help, host_tracker_params, true) { }
const PegInfo* get_pegs() const override;
PegCount* get_counts() const override;
static const snort::Parameter host_tracker_params[];
static const snort::Parameter service_params[];
- HostApplicationEntry app;
- HostTracker* host;
+ HostApplication app;
+ snort::SfIp addr;
};
#endif
namespace snort
{
-// Fakes to avoid bringing in a ton of dependencies.
-SnortProtocolId ProtocolReference::add(char const*) { return 0; }
-SnortProtocolId ProtocolReference::find(char const*) { return 0; }
SnortConfig* SnortConfig::get_conf() { return nullptr; }
char* snort_strdup(const char* s) { return strdup(s); }
Module* ModuleManager::get_module(const char*) { return nullptr; }
void show_stats(PegCount*, const PegInfo*, IndexVec&, const char*, FILE*)
{ }
-#define FRAG_POLICY 33
-#define STREAM_POLICY 100
-
TEST_GROUP(host_cache_module)
{
void setup() override
const PegCount* ht_stats = module.get_counts();
CHECK(!strcmp(ht_pegs[0].name, "lru_cache_adds"));
- CHECK(!strcmp(ht_pegs[1].name, "lru_cache_replaces"));
- CHECK(!strcmp(ht_pegs[2].name, "lru_cache_prunes"));
- CHECK(!strcmp(ht_pegs[3].name, "lru_cache_find_hits"));
- CHECK(!strcmp(ht_pegs[4].name, "lru_cache_find_misses"));
- CHECK(!strcmp(ht_pegs[5].name, "lru_cache_removes"));
- CHECK(!strcmp(ht_pegs[6].name, "lru_cache_clears"));
- CHECK(!ht_pegs[7].name);
+ CHECK(!strcmp(ht_pegs[1].name, "lru_cache_prunes"));
+ CHECK(!strcmp(ht_pegs[2].name, "lru_cache_find_hits"));
+ CHECK(!strcmp(ht_pegs[3].name, "lru_cache_find_misses"));
+ CHECK(!ht_pegs[4].name);
CHECK(ht_stats[0] == 0);
CHECK(ht_stats[1] == 0);
CHECK(ht_stats[2] == 0);
CHECK(ht_stats[3] == 0);
- CHECK(ht_stats[4] == 0);
- CHECK(ht_stats[5] == 0);
- CHECK(ht_stats[6] == 0);
size_val.set(&size_param);
ht_stats = module.get_counts();
CHECK(ht_stats[0] == 0);
-
- CHECK(2112 == host_cache.get_max_size());
}
TEST(host_cache_module, log_host_cache_messages)
#include "host_tracker/host_cache.h"
+#include <cstring>
+
#include "main/snort_config.h"
#include <CppUTest/CommandLineTestRunner.h>
namespace snort
{
-SnortConfig s_conf;
-THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
-SnortConfig::SnortConfig(const SnortConfig* const) { }
-SnortConfig::~SnortConfig() = default;
-SnortConfig* SnortConfig::get_conf()
-{ return snort_conf; }
-
-SnortProtocolId ProtocolReference::find(char const*) { return 0; }
-SnortProtocolId ProtocolReference::add(const char* protocol)
-{
- if (!strcmp("servicename", protocol))
- return 3;
- if (!strcmp("tcp", protocol))
- return 2;
- return 1;
-}
-
char* snort_strdup(const char* str)
{
return strdup(str);
{
};
-// Test HostIpKey
-TEST(host_cache, host_ip_key_test)
-{
- HostIpKey zeroed_hk;
- uint8_t expected_hk[16] =
- { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
-
- memset(&zeroed_hk.ip_addr, 0, sizeof(zeroed_hk.ip_addr));
-
- HostIpKey hkey1;
- CHECK(hkey1 == zeroed_hk);
-
- HostIpKey hkey2(expected_hk);
- CHECK(hkey2 == expected_hk);
-}
-
-// Test HashHostIpKey
+// Test HashIp
TEST(host_cache, hash_test)
{
size_t expected_hash_val = 4521729;
uint8_t hk[16] =
{ 0x0a,0xff,0x12,0x00,0x00,0x00,0x00,0x00,0x0b,0x00,0x56,0x00,0x00,0x00,0x00,0x00 };
- HashHostIpKey hash_hk;
+ HashIp hash_hk;
+ SfIp ip;
- actual_hash_val = hash_hk(hk);
+ ip.set(hk);
+ actual_hash_val = hash_hk(ip);
CHECK(actual_hash_val == expected_hash_val);
}
-// Test host_cache_add_host_tracker
-TEST(host_cache, host_cache_add_host_tracker_test)
-{
- HostTracker* expected_ht = new HostTracker;
- std::shared_ptr<HostTracker> actual_ht;
- uint8_t hk[16] =
- { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
- SfIp ip_addr;
- SfIp actual_ip_addr;
- HostIpKey hkey(hk);
- Port port = 443;
- Protocol proto = 6;
- HostApplicationEntry app_entry(proto, port, 2);
- HostApplicationEntry actual_app_entry;
- bool ret;
-
- ip_addr.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-
- expected_ht->set_ip_addr(ip_addr);
- expected_ht->add_service(app_entry);
-
- host_cache_add_host_tracker(expected_ht);
-
- ret = host_cache.find((const uint8_t*) ip_addr.get_ip6_ptr(), actual_ht);
- CHECK(true == ret);
-
- actual_ip_addr = actual_ht->get_ip_addr();
- CHECK(!memcmp(&ip_addr, &actual_ip_addr, sizeof(ip_addr)));
-
- ret = actual_ht->find_service(proto, port, actual_app_entry);
- CHECK(true == ret);
- CHECK(actual_app_entry == app_entry);
-
- host_cache.clear(); // Free HostTracker objects
-}
-
-// Test host_cache_add_service
-TEST(host_cache, host_cache_add_service_test)
-{
- HostTracker* expected_ht = new HostTracker;
- std::shared_ptr<HostTracker> actual_ht;
- uint8_t hk[16] =
- { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
- SfIp ip_addr1;
- SfIp ip_addr2;
- HostIpKey hkey(hk);
- Port port1 = 443;
- Port port2 = 22;
- Protocol proto1 = 17;
- Protocol proto2 = 6;
- HostApplicationEntry app_entry1(proto1, port1, 1);
- HostApplicationEntry app_entry2(proto2, port2, 2);
- HostApplicationEntry actual_app_entry;
- bool ret;
-
- ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
- ip_addr2.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-
- // Initialize cache with a HostTracker.
- host_cache_add_host_tracker(expected_ht);
-
- // Add a service to a HostTracker that already exists.
- ret = host_cache_add_service(ip_addr1, proto1, port1, "udp");
- CHECK(true == ret);
-
- ret = host_cache.find((const uint8_t*) ip_addr1.get_ip6_ptr(), actual_ht);
- CHECK(true == ret);
-
- ret = actual_ht->find_service(proto1, port1, actual_app_entry);
- CHECK(true == ret);
- CHECK(actual_app_entry == app_entry1);
-
- // Try adding service again (should fail since service exists).
- ret = host_cache_add_service(ip_addr1, proto1, port1, "udp");
- CHECK(false == ret);
-
- // Add a service with a new IP.
- ret = host_cache_add_service(ip_addr2, proto2, port2, "tcp");
- CHECK(true == ret);
-
- ret = host_cache.find((const uint8_t*) ip_addr1.get_ip6_ptr(), actual_ht);
- CHECK(true == ret);
-
- ret = actual_ht->find_service(proto2, port2, actual_app_entry);
- CHECK(true == ret);
- CHECK(actual_app_entry == app_entry2);
-
- host_cache.clear(); // Free HostTracker objects
-}
-
-TEST(host_cache, host_cache_app_mapping_test )
-{
- HostTracker* expected_ht = new HostTracker;
- std::shared_ptr<HostTracker> actual_ht;
- uint8_t hk[16] =
- { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
- SfIp ip_addr1;
- SfIp ip_addr2;
- HostIpKey hkey(hk);
- Port port1 = 4123;
- Port port2 = 1827;
- Protocol proto1 = 6;
- Protocol proto2 = 7;
- AppId appid1 = 61;
- AppId appid2 = 62;
- AppId appid3 = 63;
- AppId ret;
- bool add_ret;
-
- ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
- ip_addr2.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:901b");
-
- // Initialize cache with a HostTracker.
- host_cache_add_host_tracker(expected_ht);
-
- add_ret = host_cache_add_app_mapping(ip_addr1, port1, proto1, appid1);
- CHECK(true == add_ret);
-
- ret = host_cache_find_app_mapping(&ip_addr1, port1, proto1);
- CHECK(61 == ret);
-
- ret = host_cache_find_app_mapping(&ip_addr1, port2, proto1);
- CHECK(APP_ID_NONE == ret);
-
- ret = host_cache_find_app_mapping(&ip_addr2, port1, proto1);
- CHECK(APP_ID_NONE == ret);
-
- add_ret = host_cache_add_app_mapping(ip_addr1, port2, proto1, appid2);
- CHECK(true == add_ret);
- ret = host_cache_find_app_mapping(&ip_addr1, port2, proto1);
- CHECK(62 == ret);
-
- add_ret = host_cache_add_app_mapping(ip_addr1, port1, proto2, appid3);
- CHECK(true == add_ret);
- ret = host_cache_find_app_mapping(&ip_addr1, port1, proto2);
- CHECK(63 == ret);
-
- host_cache.clear(); // Free HostTracker objects
-}
-
int main(int argc, char** argv)
{
- SfIp ip_addr1;
- Protocol proto1 = 17;
- Port port1 = 443;
-
- ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-
- // This is necessary to prevent the cpputest memory leak
- // detection from thinking there's a memory leak in the map
- // object contained within the global host_cache. The map
- // must have some data allocated when it is first created
- // that doesn't go away until the global map object is
- // deallocated. This pre-allocates the map so that initial
- // allocation is done prior to starting the tests. The same
- // is true for the list used when adding a service.
- HostTracker* ht = new HostTracker;
- host_cache_add_host_tracker(ht);
- host_cache_add_service(ip_addr1, proto1, port1, "udp");
-
- host_cache.clear();
-
// Use this if you want to turn off memory checks entirely:
// MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
#include "config.h"
#endif
+#include <cstring>
+
#include "host_tracker/host_cache.h"
#include "host_tracker/host_tracker_module.h"
-#include "target_based/snort_protocols.h"
#include "main/snort_config.h"
+#include "target_based/snort_protocols.h"
#include <CppUTest/CommandLineTestRunner.h>
#include <CppUTest/TestHarness.h>
namespace snort
{
-SnortConfig* SnortConfig::get_conf() { return nullptr; }
-SnortProtocolId ProtocolReference::find(char const*) { return 0; }
-SnortProtocolId ProtocolReference::add(const char* protocol)
-{
- if (!strcmp("servicename", protocol))
- return 3;
- if (!strcmp("tcp", protocol))
- return 2;
- return 1;
-}
-
char* snort_strdup(const char* s)
{ return strdup(s); }
}
void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
void show_stats(PegCount*, const PegInfo*, IndexVec&, const char*, FILE*) { }
-#define FRAG_POLICY 33
-#define STREAM_POLICY 100
-
SfIp expected_addr;
TEST_GROUP(host_tracker_module)
void setup() override
{
Value ip_val("10.23.45.56");
- Value frag_val((double)FRAG_POLICY);
- Value tcp_val((double)STREAM_POLICY);
- Value name_val("servicename");
Value proto_val("udp");
Value port_val((double)2112);
Parameter ip_param = { "ip", Parameter::PT_ADDR, nullptr, "0.0.0.0/32", "addr/cidr"};
- Parameter frag_param = { "frag_policy", Parameter::Parameter::PT_ENUM, "linux | bsd", nullptr, "frag policy"};
- Parameter tcp_param = { "tcp_policy", Parameter::PT_ENUM, "linux | bsd", nullptr, "tcp policy"};
- Parameter name_param = {"name", Parameter::PT_STRING, nullptr, nullptr, "name"};
Parameter proto_param = {"proto", Parameter::PT_ENUM, "tcp | udp", "tcp", "ip proto"};
Parameter port_param = {"port", Parameter::PT_PORT, nullptr, nullptr, "port num"};
HostTrackerModule module;
CHECK(!strcmp(ht_pegs[0].name, "service_adds"));
CHECK(!strcmp(ht_pegs[1].name, "service_finds"));
- CHECK(!strcmp(ht_pegs[2].name, "service_removes"));
- CHECK(!ht_pegs[3].name);
+ CHECK(!ht_pegs[2].name);
CHECK(ht_stats[0] == 0);
CHECK(ht_stats[1] == 0);
- CHECK(ht_stats[2] == 0);
ip_val.set(&ip_param);
- frag_val.set(&frag_param);
- tcp_val.set(&tcp_param);
- name_val.set(&name_param);
proto_val.set(&proto_param);
port_val.set(&port_param);
// Set up the module values and add a service.
module.begin("host_tracker", 1, nullptr);
module.set(nullptr, ip_val, nullptr);
- module.set(nullptr, frag_val, nullptr);
- module.set(nullptr, tcp_val, nullptr);
-
- // FIXIT-M see FIXIT-M below
- //module.set(nullptr, name_val, nullptr);
- //module.set(nullptr, proto_val, nullptr);
module.set(nullptr, port_val, nullptr);
module.end("host_tracker.services", 1, nullptr);
void teardown() override
{
memset(&host_tracker_stats, 0, sizeof(host_tracker_stats));
- host_cache.clear(); // Free HostTracker objects
}
};
CHECK(true);
}
-#if 0
-// FIXIT-M the below are more functional in scope because they require host_cache
-// services. need to stub this out better to focus on the module only.
-// Test that HostTrackerModule variables are set correctly.
-TEST(host_tracker_module, host_tracker_module_test_values)
-{
- SfIp cached_addr;
-
- HostIpKey host_ip_key((const uint8_t*) expected_addr.get_ip6_ptr());
- std::shared_ptr<HostTracker> ht;
-
- bool ret = host_cache.find(host_ip_key, ht);
- CHECK(ret == true);
-
- cached_addr = ht->get_ip_addr();
- CHECK(cached_addr.fast_equals_raw(expected_addr) == true);
-
- Policy policy = ht->get_stream_policy();
- CHECK(policy == STREAM_POLICY + 1);
-
- policy = ht->get_frag_policy();
- CHECK(policy == FRAG_POLICY + 1);
-}
-
-
-// Test that HostTrackerModule statistics are correct.
-TEST(host_tracker_module, host_tracker_module_test_stats)
-{
- HostIpKey host_ip_key((const uint8_t*) expected_addr.get_ip6_ptr());
- std::shared_ptr<HostTracker> ht;
-
- bool ret = host_cache.find(host_ip_key, ht);
- CHECK(ret == true);
-
- HostApplicationEntry app;
- ret = ht->find_service(1, 2112, app);
- CHECK(ret == true);
- CHECK(app.protocol == 3);
- CHECK(app.ipproto == 1);
- CHECK(app.port == 2112);
-
- ret = ht->remove_service(1, 2112);
- CHECK(ret == true);
-
- // Verify counts are correct. The add was done during setup.
- CHECK(host_tracker_stats.service_adds == 1);
- CHECK(host_tracker_stats.service_finds == 1);
- CHECK(host_tracker_stats.service_removes == 1);
-}
-#endif
-
int main(int argc, char** argv)
{
- // This is necessary to prevent the cpputest memory leak
- // detection from thinking there's a memory leak in the map
- // object contained within the global host_cache. The map
- // must have some data allocated when it is first created
- // that doesn't go away until the global map object is
- // deallocated. This pre-allocates the map so that initial
- // allocation is done prior to starting the tests.
- HostTracker* ht = new HostTracker;
- host_cache_add_host_tracker(ht);
- host_cache.clear();
-
+ MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
return CommandLineTestRunner::RunAllTests(argc, argv);
}
#include "config.h"
#endif
+#include <cstring>
+
#include "host_tracker/host_tracker.h"
#include <CppUTest/CommandLineTestRunner.h>
{
};
-// Test HostTracker ipaddr get/set functions.
-TEST(host_tracker, ipaddr_test)
-{
- HostTracker ht;
- SfIp zeroed_sfip;
- SfIp expected_ip_addr;
- SfIp actual_ip_addr;
-
- // Test IP prior to set.
- memset(&zeroed_sfip, 0, sizeof(zeroed_sfip));
- actual_ip_addr = ht.get_ip_addr();
- CHECK(0 == memcmp(&zeroed_sfip, &actual_ip_addr, sizeof(zeroed_sfip)));
-
- expected_ip_addr.pton(AF_INET6, "beef:abcd:ef01:2300::");
- ht.set_ip_addr(expected_ip_addr);
- actual_ip_addr = ht.get_ip_addr();
- CHECK(0 == memcmp(&expected_ip_addr, &actual_ip_addr, sizeof(expected_ip_addr)));
-}
-
-// Test HostTracker policy get/set functions.
-TEST(host_tracker, policy_test)
-{
- HostTracker ht;
- Policy expected_policy = 23;
- Policy actual_policy;
-
- actual_policy = ht.get_stream_policy();
- CHECK(0 == actual_policy);
-
- actual_policy = ht.get_frag_policy();
- CHECK(0 == actual_policy);
-
- ht.set_stream_policy(expected_policy);
- actual_policy = ht.get_stream_policy();
- CHECK(expected_policy == actual_policy);
-
- expected_policy = 77;
- ht.set_frag_policy(expected_policy);
- actual_policy = ht.get_frag_policy();
- CHECK(expected_policy == actual_policy);
-}
-
-TEST(host_tracker, app_mapping_test)
-{
- HostTracker ht;
- const uint16_t expected_ports = 4123;
- const uint16_t port1 = 4122;
- AppId actual_appid;
- Protocol expected_proto1 = 6;
- Protocol expected_proto2 = 17;
- AppId appid1 = 61;
- AppId appid2 = 62;
- bool ret;
-
- ht.add_app_mapping(expected_ports, expected_proto1, appid1);
-
- actual_appid = ht.find_app_mapping(expected_ports, expected_proto1);
- CHECK(61 == actual_appid);
-
- actual_appid = ht.find_app_mapping(expected_ports, expected_proto2);
- CHECK(APP_ID_NONE == actual_appid);
-
- actual_appid = ht.find_app_mapping(port1, expected_proto2);
- CHECK(APP_ID_NONE == actual_appid);
-
- ret = ht.find_else_add_app_mapping(port1, expected_proto1, appid2);
- CHECK(true == ret);
-
- ret = ht.find_else_add_app_mapping(port1, expected_proto1, appid2);
- CHECK(false == ret);
-}
-
-// Test HostTracker add and find service functions.
+// Test HostTracker find appid and add service functions.
TEST(host_tracker, add_find_service_test)
{
- bool ret;
HostTracker ht;
- HostApplicationEntry actual_entry;
- HostApplicationEntry app_entry1(6, 2112, 3);
- HostApplicationEntry app_entry2(17, 7777, 10);
// Try a find on an empty list.
- ret = ht.find_service(3,1000, actual_entry);
- CHECK(false == ret);
+ CHECK(APP_ID_NONE == ht.get_appid(80, IpProtocol::TCP));
// Test add and find.
- ret = ht.add_service(app_entry1);
- CHECK(true == ret);
-
- ret = ht.find_service(6, 2112, actual_entry);
- CHECK(true == ret);
- CHECK(actual_entry.port == 2112);
- CHECK(actual_entry.ipproto == 6);
- CHECK(actual_entry.snort_protocol_id == 3);
-
- ht.add_service(app_entry2);
- ret = ht.find_service(6, 2112, actual_entry);
- CHECK(true == ret);
- CHECK(actual_entry.port == 2112);
- CHECK(actual_entry.ipproto == 6);
- CHECK(actual_entry.snort_protocol_id == 3);
-
- ret = ht.find_service(17, 7777, actual_entry);
- CHECK(true == ret);
- CHECK(actual_entry.port == 7777);
- CHECK(actual_entry.ipproto == 17);
- CHECK(actual_entry.snort_protocol_id == 10);
-
- // Try adding an entry that exists already.
- ret = ht.add_service(app_entry1);
- CHECK(false == ret);
-
- // Try a find on a port that isn't in the list.
- ret = ht.find_service(6, 100, actual_entry);
- CHECK(false == ret);
-
- // Try a find on an ipproto that isn't in the list.
- ret = ht.find_service(17, 2112, actual_entry);
- CHECK(false == ret);
-
- // Try to remove an entry that's not in the list.
- ret = ht.remove_service(6,100);
- CHECK(false == ret);
-
- ret = ht.remove_service(17,2112);
- CHECK(false == ret);
-
- // Actually remove an entry.
- ret = ht.remove_service(6,2112);
- CHECK(true == ret);
+ CHECK(true == ht.add_service(80, IpProtocol::TCP, 676, true));
+ CHECK(true == ht.add_service(443, IpProtocol::TCP, 1122));
+ CHECK(676 == ht.get_appid(80, IpProtocol::TCP));
+ CHECK(1122 == ht.get_appid(443, IpProtocol::TCP));
+
+ // Try adding an entry that exists already and update appid
+ CHECK(true == ht.add_service(443, IpProtocol::TCP, 847));
+ CHECK(847 == ht.get_appid(443, IpProtocol::TCP));
+
+ // Try a find appid on a port that isn't in the list.
+ CHECK(APP_ID_NONE == ht.get_appid(8080, IpProtocol::UDP));
}
TEST(host_tracker, stringify)
{
- SfIp ip;
- ip.pton(AF_INET6, "feed:dead:beef::");
- HostTracker ht(ip);
- ht.add_app_mapping(80, 6, 676);
- ht.add_app_mapping(443, 6, 1122);
- ht.set_frag_policy(3);
- HostApplicationEntry app_entry(6, 80, 10);
- ht.add_service(app_entry);
+ HostTracker ht;
+ ht.add_service(80, IpProtocol::TCP, 676, true);
+ ht.add_service(443, IpProtocol::TCP, 1122);
string host_tracker_string;
ht.stringify(host_tracker_string);
STRCMP_EQUAL(host_tracker_string.c_str(),
- "IP: feed:dead:beef:0000:0000:0000:0000:0000\n"
- "app_mappings size: 2\n"
- " port: 80, proto: 6, appid: 676\n"
- " port: 443, proto: 6, appid: 1122\n"
- "stream policy: 0, frag policy: 3\n"
- "services size: 1\n"
- " port: 80, proto: 6, snort proto: 10");
+ "\nservices size: 2"
+ "\n port: 80, proto: 6, appid: 676, inferred"
+ "\n port: 443, proto: 6, appid: 1122");
}
int main(int argc, char** argv)
}
else if (asd.config->mod_config->is_host_port_app_cache_runtime)
{
- AppId appid = snort::host_cache_find_app_mapping(ip, port, (Protocol)protocol);
- if (appid > APP_ID_NONE)
+ auto ht = host_cache.find(*ip);
+ if ( ht )
{
- asd.client.set_id(appid);
- asd.client_disco_state = APPID_DISCO_STATE_FINISHED;
+ AppId appid = ht->get_appid(port, protocol, true);
+ if ( appid > APP_ID_NONE )
+ {
+ asd.client.set_id(appid);
+ asd.client_disco_state = APPID_DISCO_STATE_FINISHED;
+ }
}
}
}
}
unsigned port = lua_tointeger(L, ++index);
- unsigned proto = lua_tointeger(L, ++index);
- if (proto > (unsigned)IpProtocol::RESERVED)
+ IpProtocol proto = (IpProtocol) lua_tointeger(L, ++index);
+ if (proto > IpProtocol::RESERVED)
{
- ErrorMessage("%s:Invalid protocol value %u\n",__func__, proto);
+ ErrorMessage("%s:Invalid protocol value %u\n",__func__, (unsigned) proto);
return 0;
}
- if (!(snort::host_cache_add_app_mapping(ip_addr, port, proto, appid)))
- ErrorMessage("%s:Failed to add app mapping\n",__func__);
+ if ( !host_cache[ip_addr]->add_service(port, proto, appid, true) )
+ ErrorMessage("%s:Failed to add host tracker service\n",__func__);
return 0;
}
s_discovery_manager = new ServiceDiscovery();
return *s_discovery_manager;
}
-AppId snort::host_cache_find_app_mapping(snort::SfIp const*, Port, Protocol){ return 0; }
+
+LruCacheShared<SfIp, HostTracker, HashIp> host_cache(50);
+AppId HostTracker::get_appid(Port, IpProtocol, bool)
+{
+ return APP_ID_NONE;
+}
+
// Stubs for ClientDiscovery
ClientDiscovery::ClientDiscovery(){}
ClientDiscovery::~ClientDiscovery() {}
void RnaPnd::discover_network_icmp(const Packet* p)
{
- if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
- p->flow->client_port, SNORT_PROTO_ICMP) )
+ if ( !(host_cache[p->flow->client_ip]->
+ add_service(p->flow->client_port, p->get_ip_proto_next())) )
return;
// process rna discovery for icmp
}
void RnaPnd::discover_network_ip(const Packet* p)
{
- if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
- p->flow->client_port, SNORT_PROTO_IP) )
+ if ( !(host_cache[p->flow->client_ip]->
+ add_service(p->flow->client_port, p->get_ip_proto_next())) )
return;
// process rna discovery for ip
}
void RnaPnd::discover_network_tcp(const Packet* p)
{
// Track from initiator direction, if not already seen
- if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
- p->flow->client_port, SNORT_PROTO_TCP) )
+ if ( !(host_cache[p->flow->client_ip]->
+ add_service(p->flow->client_port, p->get_ip_proto_next())) )
return;
// Add mac address to ht list, ttl, last_seen, etc.
void RnaPnd::discover_network_udp(const Packet* p)
{
- if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
- p->flow->client_port, SNORT_PROTO_UDP) )
+ if ( !(host_cache[p->flow->client_ip]->
+ add_service(p->flow->client_port, p->get_ip_proto_next())) )
return;
// process rna discovery for udp
}
{
f->service = m->book.find_spell(data, len, m);
- if (f->service != nullptr)
+ if ( f->service != nullptr )
{
- // FIXIT-H need to make sure Flow's ipproto and service
- // correspond to HostApplicationEntry's ipproto and service
- host_cache_add_service(f->server_ip, f->ip_proto, f->server_port, f->service);
+ host_cache[f->server_ip]->add_service(f->server_port, (IpProtocol)f->ip_proto);
return true;
}
if (cst.curse->alg(data, len, cst.tracker))
{
f->service = cst.curse->service.c_str();
- // FIXIT-H need to make sure Flow's ipproto and service
- // correspond to HostApplicationEntry's ipproto and service
- host_cache_add_service(f->server_ip, f->ip_proto, f->server_port, f->service);
- return true;
+ if ( f->service != nullptr )
+ {
+ host_cache[f->server_ip]->add_service(f->server_port, (IpProtocol)f->ip_proto);
+ return true;
+ }
}
}
bool fast_gt6(const SfIp& ip2) const;
bool fast_eq6(const SfIp& ip2) const;
bool fast_equals_raw(const SfIp& ip2) const;
+ bool operator==(const SfIp& ip2) const;
/*
* Miscellaneous
return false;
}
+inline bool SfIp::operator==(const SfIp& ip2) const
+{
+ return fast_equals_raw(ip2);
+}
+
/* End of member function definitions */
SO_PUBLIC const char* sfip_ntop(const SfIp* ip, char* buf, int bufsize);