]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #1691 in SNORT/snort3 from ~MASHASAN/snort3:refactor_host_cache...
authorMike Stepanek (mstepane) <mstepane@cisco.com>
Mon, 29 Jul 2019 16:35:43 +0000 (12:35 -0400)
committerMike Stepanek (mstepane) <mstepane@cisco.com>
Mon, 29 Jul 2019 16:35:43 +0000 (12:35 -0400)
Squashed commit of the following:

commit 8226ab4c2662a508d291efb2527777364cbaac6b
Author: Masud Hasan <mashasan@cisco.com>
Date:   Fri Jul 26 01:44:02 2019 -0400

    host_cache: Refactoring code to fix  multithreading issues and to remove redundancy

20 files changed:
src/hash/lru_cache_shared.cc
src/hash/lru_cache_shared.h
src/hash/test/lru_cache_shared_test.cc
src/host_tracker/host_cache.cc
src/host_tracker/host_cache.h
src/host_tracker/host_cache_module.cc
src/host_tracker/host_tracker.cc
src/host_tracker/host_tracker.h
src/host_tracker/host_tracker_module.cc
src/host_tracker/host_tracker_module.h
src/host_tracker/test/host_cache_module_test.cc
src/host_tracker/test/host_cache_test.cc
src/host_tracker/test/host_tracker_module_test.cc
src/host_tracker/test/host_tracker_test.cc
src/network_inspectors/appid/appid_discovery.cc
src/network_inspectors/appid/lua_detector_api.cc
src/network_inspectors/appid/test/appid_discovery_test.cc
src/network_inspectors/rna/rna_pnd.cc
src/service_inspectors/wizard/wizard.cc
src/sfip/sf_ip.h

index f458dee722c902d5cd407b82a0ef6ad97fe1326e..d56b1e832d8a29da536dda05b9839d98954aff73 100644 (file)
 const PegInfo lru_cache_shared_peg_names[] =
 {
     { CountType::SUM, "lru_cache_adds", "lru cache added new entry" },
-    { CountType::SUM, "lru_cache_replaces", "lru cache replaced existing entry" },
     { CountType::SUM, "lru_cache_prunes", "lru cache pruned entry to make space for new entry" },
     { CountType::SUM, "lru_cache_find_hits", "lru cache found entry in cache" },
     { CountType::SUM, "lru_cache_find_misses", "lru cache did not find entry in cache" },
-    { CountType::SUM, "lru_cache_removes", "lru cache found entry and removed it" },
-    { CountType::SUM, "lru_cache_clears", "lru cache clear API calls" },
     { CountType::END, nullptr, nullptr },
 };
 
index 097855616b774ec364e252e9500c784aa18f3083..6f2503e4ad3f797d8358915a8959990edcd014ed 100644 (file)
 // least-recently-used (LRU) entries are removed once a fixed size is hit.
 
 #include <list>
-#include <vector>
-#include <unordered_map>
+#include <memory>
 #include <mutex>
+#include <typeinfo>
+#include <unordered_map>
+#include <vector>
 
 #include "framework/counts.h"
 
@@ -36,16 +38,13 @@ extern const PegInfo lru_cache_shared_peg_names[];
 struct LruCacheSharedStats
 {
     PegCount adds = 0;       //  An insert that added new entry.
-    PegCount replaces = 0;   //  An insert that replaced existing entry
     PegCount prunes = 0;     //  When an old entry is removed to make
                              //  room for a new entry.
     PegCount find_hits = 0;  //  Found entry in cache.
     PegCount find_misses = 0; //  Did not find entry in cache.
-    PegCount removes = 0;    //  Found entry and removed it.
-    PegCount clears = 0;     //  Calls to clear API.
 };
 
-template<typename Key, typename Data, typename Hash>
+template<typename Key, typename Value, typename Hash>
 class LruCacheShared
 {
 public:
@@ -57,58 +56,29 @@ public:
     LruCacheShared& operator=(const LruCacheShared& arg) = delete;
 
     LruCacheShared(const size_t initial_size) :
-        max_size(initial_size),
-        current_size(0)
-    {
-    }
+        max_size(initial_size), current_size(0) { }
 
-    //  Get current number of elements in the LruCache.
-    size_t size()
-    {
-        std::lock_guard<std::mutex> cache_lock(cache_mutex);
-        return current_size;
-    }
+    using Data = std::shared_ptr<Value>;
 
-    size_t get_max_size()
-    {
-        std::lock_guard<std::mutex> cache_lock(cache_mutex);
-        return max_size;
-    }
+    // Return data entry associated with key. If doesn't exist, return nullptr.
+    Data find(const Key& key);
+
+    // Return data entry associated with key. If doesn't exist, create a new entry.
+    Data operator[](const Key& key);
+
+    // Return all data from the LruCache in order (most recently used to least)
+    std::vector<std::pair<Key, Data> > get_all_data();
 
     //  Modify the maximum number of entries allowed in the cache.
     //  If the size is reduced, the oldest entries are removed.
     bool set_max_size(size_t newsize);
 
-    //  Add data to cache or replace data if it already exists.
-    void insert(const Key& key, const Data& data);
-
-    //  Find Data associated with Key.  If update is true, mark entry as
-    //  recently used.
-    //  Returns true and copies data if the key is found.
-    bool find(const Key& key, Data& data, bool update=true);
-
-    //  Remove entry associated with Key.
-    //  Returns true if entry existed, false otherwise.
-    bool remove(const Key& key);
-
-    //  Remove entry associated with key and return removed data.
-    //  Returns true and copy of data if entry existed.  Returns false if
-    //  entry did not exist.
-    bool remove(const Key& key, Data& data);
-
-    //  Remove all elements from the LruCache
-    void clear();
-
-    //  Return all data from the LruCache in order (most recently used to
-    //  least).
-    std::vector<std::pair<Key, Data> > get_all_data();
-
     const PegInfo* get_pegs() const
     {
         return lru_cache_shared_peg_names;
     }
 
-    PegCount* get_counts() const
+    PegCount* get_counts()
     {
         return (PegCount*)&stats;
     }
@@ -144,8 +114,8 @@ private:
     struct LruCacheSharedStats stats;
 };
 
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::set_max_size(size_t newsize)
+template<typename Key, typename Value, typename Hash>
+bool LruCacheShared<Key, Value, Hash>::set_max_size(size_t newsize)
 {
     LruListIter list_iter;
 
@@ -162,32 +132,50 @@ bool LruCacheShared<Key, Data, Hash>::set_max_size(size_t newsize)
         current_size--;
         map.erase(list_iter->first);
         list.erase(list_iter);
+        stats.prunes++;
     }
 
     max_size = newsize;
     return true;
 }
 
-template<typename Key, typename Data, typename Hash>
-void LruCacheShared<Key, Data, Hash>::insert(const Key& key, const Data& data)
+template<typename Key, typename Value, typename Hash>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash>::find(const Key& key)
 {
     LruMapIter map_iter;
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
 
-    //  If key already exists, remove it.
     map_iter = map.find(key);
-    if (map_iter != map.end())
+    if (map_iter == map.end())
     {
-        current_size--;
-        list.erase(map_iter->second);
-        map.erase(map_iter);
-        stats.replaces++;
+        stats.find_misses++;
+        return nullptr;
     }
-    else
+
+    //  Move entry to front of LruList
+    list.splice(list.begin(), list, map_iter->second);
+    stats.find_hits++;
+    return map_iter->second->second;
+}
+
+template<typename Key, typename Value, typename Hash>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash>::operator[](const Key& key)
+{
+    LruMapIter map_iter;
+    std::lock_guard<std::mutex> cache_lock(cache_mutex);
+
+    map_iter = map.find(key);
+    if (map_iter != map.end())
     {
-        stats.adds++;
+        stats.find_hits++;
+        list.splice(list.begin(), list, map_iter->second); // update LRU
+        return map_iter->second->second;
     }
 
+    stats.find_misses++;
+    stats.adds++;
+    Data data = Data(new Value);
+
     //  Add key/data pair to front of list.
     list.emplace_front(std::make_pair(key, data));
 
@@ -208,87 +196,12 @@ void LruCacheShared<Key, Data, Hash>::insert(const Key& key, const Data& data)
     {
         current_size++;
     }
+    return data;
 }
 
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::find(const Key& key, Data& data, bool update)
-{
-    LruMapIter map_iter;
-    std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
-    map_iter = map.find(key);
-    if (map_iter == map.end())
-    {
-        stats.find_misses++;
-        return false;   //  Key is not in LruCache.
-    }
-
-    data = map_iter->second->second;
-
-    //  If needed, move entry to front of LruList
-    if (update)
-        list.splice(list.begin(), list, map_iter->second);
-
-    stats.find_hits++;
-    return true;
-}
-
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::remove(const Key& key)
-{
-    LruMapIter map_iter;
-    std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
-    map_iter = map.find(key);
-    if (map_iter == map.end())
-        return false;   //  Key is not in LruCache.
-
-    current_size--;
-    list.erase(map_iter->second);
-    map.erase(map_iter);
-    stats.removes++;
-    return(true);
-}
-
-template<typename Key, typename Data, typename Hash>
-bool LruCacheShared<Key, Data, Hash>::remove(const Key& key, Data& data)
-{
-    LruMapIter map_iter;
-    std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
-    map_iter = map.find(key);
-    if (map_iter == map.end())
-        return false;   //  Key is not in LruCache.
-
-    data = map_iter->second->second;
-
-    current_size--;
-    list.erase(map_iter->second);
-    map.erase(map_iter);
-    stats.removes++;
-    return(true);
-}
-
-template<typename Key, typename Data, typename Hash>
-void LruCacheShared<Key, Data, Hash>::clear()
-{
-    LruMapIter map_iter;
-    std::lock_guard<std::mutex> cache_lock(cache_mutex);
-
-    for (map_iter = map.begin(); map_iter != map.end(); /* No incr */)
-    {
-        list.erase(map_iter->second);
-
-        //  erase returns next iterator after erased element.
-        map_iter = map.erase(map_iter);
-    }
-
-    current_size = 0;
-    stats.clears++;
-}
-
-template<typename Key, typename Data, typename Hash>
-std::vector<std::pair<Key, Data> > LruCacheShared<Key, Data, Hash>::get_all_data()
+template<typename Key, typename Value, typename Hash>
+std::vector< std::pair<Key, std::shared_ptr<Value>> >
+LruCacheShared<Key, Value, Hash>::get_all_data()
 {
     std::vector<std::pair<Key, Data> > vec;
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
index 3ca3185f521d014ded1603c0b0ec8e795a1e8b64..89832a1509a59236c6b6d1431a373d49e6c12a0a 100644 (file)
@@ -34,175 +34,73 @@ TEST_GROUP(lru_cache_shared)
 {
 };
 
-//  Test LruCacheShared constructor and member access.
-TEST(lru_cache_shared, constructor_test)
-{
-    LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
-
-    CHECK(lru_cache.get_max_size() == 5);
-    CHECK(lru_cache.size() == 0);
-}
-
-//  Test LruCacheShared insert, find and get_all_data functions.
+//  Test LruCacheShared find, operator[], and get_all_data functions.
 TEST(lru_cache_shared, insert_test)
 {
-    std::string data;
-    LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
+    LruCacheShared<int, std::string, std::hash<int> > lru_cache(3);
 
-    lru_cache.insert(0, "zero");
-    CHECK(true == lru_cache.find(0, data));
-    CHECK("zero" == data);
-
-    lru_cache.insert(1, "one");
-    CHECK(true == lru_cache.find(1, data));
-    CHECK("one" == data);
-
-    lru_cache.insert(2, "two");
-    CHECK(true == lru_cache.find(2, data));
-    CHECK("two" == data);
-
-    //  Verify find fails for non-existent item.
-    CHECK(false == lru_cache.find(3, data));
-
-    //  Verify that insert will replace data if key exists already.
-    lru_cache.insert(1, "new one");
-    CHECK(true == lru_cache.find(1, data));
-    CHECK("new one" == data);
-
-    //  Verify current number of entries in cache.
-    CHECK(3 == lru_cache.size());
-
-    //  Verify that the data is in LRU order.
-    auto vec = lru_cache.get_all_data();
-    CHECK(3 == vec.size());
-    CHECK((vec[0] == std::make_pair(1, std::string("new one"))));
-    CHECK((vec[1] == std::make_pair(2, std::string("two"))));
-    CHECK((vec[2] == std::make_pair(0, std::string("zero"))));
-}
+    auto data = lru_cache[0];
+    CHECK(data == lru_cache.find(0));
+    data->assign("zero");
 
-//  Test that the least recently used items are removed when we exceed
-//  the capacity of the LruCache.
-TEST(lru_cache_shared, lru_removal_test)
-{
-    LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
+    data = lru_cache[1];
+    CHECK(data == lru_cache.find(1));
+    data->assign("one");
 
-    for (int i = 0; i < 10; i++)
-    {
-        lru_cache.insert(i, std::to_string(i));
-    }
-
-    CHECK(5 == lru_cache.size());
-
-    //  Verify that the data is in LRU order and is correct.
-    auto vec = lru_cache.get_all_data();
-    CHECK(5 == vec.size());
-    CHECK((vec[0] == std::make_pair(9, std::string("9"))));
-    CHECK((vec[1] == std::make_pair(8, std::string("8"))));
-    CHECK((vec[2] == std::make_pair(7, std::string("7"))));
-    CHECK((vec[3] == std::make_pair(6, std::string("6"))));
-    CHECK((vec[4] == std::make_pair(5, std::string("5"))));
-}
+    data = lru_cache[2];
+    CHECK(data == lru_cache.find(2));
+    data->assign("two");
 
-//  Test the remove and clear functions.
-TEST(lru_cache_shared, remove_test)
-{
-    std::string data;
-    LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
+    // Replace existing
+    data = lru_cache[0];
+    data->assign("new_zero");
 
-    for (int i = 0; i < 5; i++)
-    {
-        lru_cache.insert(i, std::to_string(i));
-        CHECK(true == lru_cache.find(i, data));
-        CHECK(data == std::to_string(i));
-
-        CHECK(true == lru_cache.remove(i));
-        CHECK(false == lru_cache.find(i, data));
-    }
-
-    CHECK(0 == lru_cache.size());
-
-    //  Test remove API that returns the removed data.
-    lru_cache.insert(1, "one");
-    CHECK(1 == lru_cache.size());
-    CHECK(true == lru_cache.remove(1, data));
-    CHECK(data == "one");
-    CHECK(0 == lru_cache.size());
-
-    lru_cache.insert(1, "one");
-    lru_cache.insert(2, "two");
-    CHECK(2 == lru_cache.size());
-
-    //  Verify that removing an item that does not exist does not affect
-    //  cache.
-    CHECK(false == lru_cache.remove(3));
-    CHECK(false == lru_cache.remove(4, data));
-    CHECK(2 == lru_cache.size());
-
-    auto vec = lru_cache.get_all_data();
-    CHECK(2 == vec.size());
-    CHECK((vec[0] == std::make_pair(2, std::string("two"))));
-    CHECK((vec[1] == std::make_pair(1, std::string("one"))));
-
-    //  Verify that clear() removes all entries.
-    lru_cache.clear();
-    CHECK(0 == lru_cache.size());
-
-    vec = lru_cache.get_all_data();
-    CHECK(vec.empty());
+    //  Verify find fails for non-existent item.
+    CHECK(nullptr == lru_cache.find(3));
+
+    // Replace least recently used when capacity exceeds
+    data = lru_cache[3];
+    CHECK(data == lru_cache.find(3));
+    data->assign("three");
+
+    //  Verify current number of entries in cache and LRU order.
+    const auto&& vec = lru_cache.get_all_data();
+    CHECK(vec.size() == 3);
+    CHECK(vec[0].second->compare("three") == 0);
+    CHECK(vec[1].second->compare("new_zero") == 0);
+    CHECK(vec[2].second->compare("two") == 0);
 }
 
-//  Test statistics counters.
+//  Test statistics counters and set_max_size.
 TEST(lru_cache_shared, stats_test)
 {
-    std::string data;
     LruCacheShared<int, std::string, std::hash<int> > lru_cache(5);
 
     for (int i = 0; i < 10; i++)
-    {
-        lru_cache.insert(i, std::to_string(i));
-    }
-
-    lru_cache.insert(8, "new-eight");  //  Replace entries.
-    lru_cache.insert(9, "new-nine");
-
-    CHECK(5 == lru_cache.size());
-
-    lru_cache.find(7, data);     //  Hits
-    lru_cache.find(8, data);
-    lru_cache.find(9, data);
-
-    lru_cache.remove(7);
-    lru_cache.remove(8);
-    lru_cache.remove(9, data);
-    CHECK("new-nine" == data);
+        lru_cache[i];
 
-    lru_cache.find(8, data);    //  Misses now that they're removed.
-    lru_cache.find(9, data);
+    lru_cache.find(7);     //  Hits
+    lru_cache.find(8);
+    lru_cache.find(9);
 
-    lru_cache.remove(100);  //  Removing a non-existent entry does not
-                            //  increase remove count.
+    lru_cache.find(10);    //  Misses; in addition to previous 10
+    lru_cache.find(11);
 
-    lru_cache.clear();
+    CHECK(lru_cache.set_max_size(3) == true); // change size prunes; in addition to previous 5
 
     PegCount* stats = lru_cache.get_counts();
 
     CHECK(stats[0] == 10);  //  adds
-    CHECK(stats[1] == 2);   //  replaces
-    CHECK(stats[2] == 5);   //  prunes
-    CHECK(stats[3] == 3);   //  find hits
-    CHECK(stats[4] == 2);   //  find misses
-    CHECK(stats[5] == 3);   //  removes
-    CHECK(stats[6] == 1);   //  clears
+    CHECK(stats[1] == 7);   //  prunes
+    CHECK(stats[2] == 3);   //  find hits
+    CHECK(stats[3] == 12);  //  find misses
 
     // Check statistics names.
     const PegInfo* pegs = lru_cache.get_pegs();
     CHECK(!strcmp(pegs[0].name, "lru_cache_adds"));
-    CHECK(!strcmp(pegs[1].name, "lru_cache_replaces"));
-    CHECK(!strcmp(pegs[2].name, "lru_cache_prunes"));
-    CHECK(!strcmp(pegs[3].name, "lru_cache_find_hits"));
-    CHECK(!strcmp(pegs[4].name, "lru_cache_find_misses"));
-    CHECK(!strcmp(pegs[5].name, "lru_cache_removes"));
-    CHECK(!strcmp(pegs[6].name, "lru_cache_clears"));
+    CHECK(!strcmp(pegs[1].name, "lru_cache_prunes"));
+    CHECK(!strcmp(pegs[2].name, "lru_cache_find_hits"));
+    CHECK(!strcmp(pegs[3].name, "lru_cache_find_misses"));
 }
 
 int main(int argc, char** argv)
index 450d46eb5cbb7af404fd6b32ae638113c4006958..9a55ab808410157d5a0b6697237dc76f380b93c7 100644 (file)
 
 #include "host_cache.h"
 
-#include "main/snort_config.h"
-#include "target_based/snort_protocols.h"
-
 using namespace snort;
-using namespace std;
 
 #define LRU_CACHE_INITIAL_SIZE 65535
 
-LruCacheShared<HostIpKey, std::shared_ptr<HostTracker>, HashHostIpKey>
-    host_cache(LRU_CACHE_INITIAL_SIZE);
-
-namespace snort
-{
-
-void host_cache_add_host_tracker(HostTracker* ht)
-{
-    std::shared_ptr<HostTracker> sptr(ht);
-    host_cache.insert((const uint8_t*) ht->get_ip_addr().get_ip6_ptr(), sptr);
-}
-
-bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, SnortProtocolId id)
-{
-    HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr());
-    std::shared_ptr<HostTracker> ht;
-
-    if (!host_cache.find(ipkey, ht))
-    {
-        //  This host hasn't been seen.  Add it.
-        ht = std::make_shared<HostTracker>(ipaddr);
-
-        if (ht == nullptr)
-        {
-            //  FIXIT-L add error count
-            return false;
-        }
-        host_cache.insert(ipkey, ht);
-    }
-
-    HostApplicationEntry app_entry(ipproto, port, id);
-    return ht->add_service(app_entry);
-}
-
-bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, const char* service)
-{
-    return host_cache_add_service(ipaddr, ipproto, port,
-        SnortConfig::get_conf()->proto_ref->find(service));
-}
-
-bool host_cache_add_app_mapping(const SfIp& ipaddr, Port port, Protocol proto, AppId appId)
-{
-    HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr());
-    std::shared_ptr<HostTracker> ht;
-
-    if (!host_cache.find(ipkey, ht))
-    {
-        ht = std::make_shared<HostTracker> (ipaddr);
-
-        if (ht == nullptr)
-        {
-            return false;
-        }
-        ht->add_app_mapping(port, proto, appId);
-        host_cache.insert(ipkey, ht);
-    }
-    else
-    {
-        return ht->find_else_add_app_mapping(port, proto, appId);
-    }
-
-    return true;
-}
-
-AppId host_cache_find_app_mapping(const SfIp* ipaddr, Port port, Protocol proto)
-{
-    HostIpKey ipkey((const uint8_t*) ipaddr->get_ip6_ptr());
-    std::shared_ptr<HostTracker> ht;
-
-    if (host_cache.find(ipkey, ht))
-    {
-        return ht->find_app_mapping(port, proto);
-    }
-    
-    return APP_ID_NONE;
-}
-}
+LruCacheShared<SfIp, HostTracker, HashIp> host_cache(LRU_CACHE_INITIAL_SIZE);
index d73290645b06c14d81b3291a8b726259e9646f89..725146acde234b5bdba35a4ecc59a97fa759ccaf 100644 (file)
 // The host cache is used to cache information about hosts so that it can
 // be shared among threads.
 
-#include <memory>
-
 #include "hash/lru_cache_shared.h"
 #include "host_tracker/host_tracker.h"
-
-#define HOST_IP_KEY_SIZE 16
-
-struct HostIpKey
-{
-    union host_ip_addr
-    {
-        uint8_t ip8[HOST_IP_KEY_SIZE];
-        uint64_t ip64[HOST_IP_KEY_SIZE/8];
-    } ip_addr = {{0}}; //  Holds either IPv4 or IPv6 addr
-
-    HostIpKey() = default;
-
-    HostIpKey(const uint8_t ip[HOST_IP_KEY_SIZE])
-    {
-        memcpy(&ip_addr, ip, HOST_IP_KEY_SIZE);
-    }
-
-    inline bool operator==(const HostIpKey& rhs) const
-    {
-        return !memcmp(&ip_addr, &rhs.ip_addr, HOST_IP_KEY_SIZE);
-    }
-};
+#include "sfip/sf_ip.h"
 
 //  Used to create hash of key for indexing into cache.
-struct HashHostIpKey
+struct HashIp
 {
-    size_t operator()(const HostIpKey& ip) const
+    size_t operator()(const snort::SfIp& ip) const
     {
-        return std::hash<uint64_t>() (ip.ip_addr.ip64[0]) ^
-               std::hash<uint64_t>() (ip.ip_addr.ip64[1]);
+        const uint64_t* ip64 = (const uint64_t*) ip.get_ip6_ptr();
+        return std::hash<uint64_t>() (ip64[0]) ^
+               std::hash<uint64_t>() (ip64[1]);
     }
 };
 
-extern LruCacheShared<HostIpKey, std::shared_ptr<HostTracker>, HashHostIpKey> host_cache;
-
-namespace snort
-{
-void host_cache_add_host_tracker(HostTracker*);
-
-//  Insert a new service into host cache if it doesn't already exist.
-SO_PUBLIC bool host_cache_add_service(const SfIp&, Protocol, Port, SnortProtocolId);
-SO_PUBLIC bool host_cache_add_service(const SfIp&, Protocol, Port, const char*);
+extern SO_PUBLIC LruCacheShared<snort::SfIp, HostTracker, HashIp> host_cache;
 
-bool host_cache_add_app_mapping(const SfIp&, Port, Protocol, AppId);
-AppId host_cache_find_app_mapping(const SfIp* , Port, Protocol );
-}
 #endif
 
index a22c4b673f4df563116b550659dd2f4b304ae5f9..f253644f5a4b9e5e8219bae34ec9313c96cc588c 100644 (file)
@@ -123,7 +123,6 @@ HostCacheModule::~HostCacheModule()
         log_host_cache(dump_file);
         snort_free((void*)dump_file);
     }
-    host_cache.clear();
 }
 
 void HostCacheModule::log_host_cache(const char* file_name, bool verbose)
@@ -153,9 +152,12 @@ void HostCacheModule::log_host_cache(const char* file_name, bool verbose)
     }
 
     string str;
+    SfIpString ip_str;
     const auto&& lru_data = host_cache.get_all_data();
     for ( const auto& elem : lru_data )
     {
+        str = "IP: ";
+        str += elem.first.ntop(ip_str);
         elem.second->stringify(str);
         out_stream << str << endl << endl;
     }
index c117563faeeabd516791e0d1e8267ed6cee13834..d0911ab4be8a4b280772c4ba44bd39e8ba2d18f1 100644 (file)
@@ -29,165 +29,58 @@ using namespace std;
 
 THREAD_LOCAL struct HostTrackerStats host_tracker_stats;
 
-snort::SfIp HostTracker::get_ip_addr()
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    return ip_addr;
-}
-
-void HostTracker::set_ip_addr(const snort::SfIp& new_ip_addr)
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    std::memcpy(&ip_addr, &new_ip_addr, sizeof(ip_addr));
-}
-
-Policy HostTracker::get_stream_policy()
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    return stream_policy;
-}
-
-void HostTracker::set_stream_policy(const Policy& policy)
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    stream_policy = policy;
-}
-
-Policy HostTracker::get_frag_policy()
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    return frag_policy;
-}
-
-void HostTracker::set_frag_policy(const Policy& policy)
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    frag_policy = policy;
-}
-
-void HostTracker::add_app_mapping(Port port, Protocol proto, AppId appid)
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    AppMapping app_map = {port, proto, appid};
-
-    app_mappings.emplace_back(app_map);
-}
-
-AppId HostTracker::find_app_mapping(Port port, Protocol proto)
+bool HostTracker::add_service(Port port, IpProtocol proto, AppId appid, bool inferred_appid)
 {
+    host_tracker_stats.service_adds++;
     std::lock_guard<std::mutex> lck(host_tracker_lock);
-    for (std::vector<AppMapping>::iterator it=app_mappings.begin(); it!=app_mappings.end(); ++it)
-    {
-        if (it->port == port and it->proto ==proto)
-        {
-            return it->appid;
-        }
-    }
-    return APP_ID_NONE;
-}
 
-bool HostTracker::find_else_add_app_mapping(Port port, Protocol proto, AppId appid)
-{
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-    for (std::vector<AppMapping>::iterator it=app_mappings.begin(); it!=app_mappings.end(); ++it)
+    for ( auto& s : services )
     {
-        if (it->port == port and it->proto ==proto)
+        if ( s.port == port and s.proto == proto )
         {
-            return false;
+            if ( s.appid != appid and appid != APP_ID_NONE )
+            {
+                s.appid = appid;
+                s.inferred_appid = inferred_appid;
+            }
+            return true;
         }
     }
-    AppMapping app_map = {port, proto, appid};
-
-    app_mappings.emplace_back(app_map);
-    return true;
-}
-
-bool HostTracker::add_service(const HostApplicationEntry& app_entry)
-{
-    host_tracker_stats.service_adds++;
-
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-
-    auto iter = std::find(services.begin(), services.end(), app_entry);
-    if (iter != services.end())
-        return false;   //  Already exists.
 
-    services.emplace_front(app_entry);
+    services.emplace_back( HostApplication{port, proto, appid, inferred_appid} );
     return true;
 }
 
-void HostTracker::add_or_replace_service(const HostApplicationEntry& app_entry)
+AppId HostTracker::get_appid(Port port, IpProtocol proto, bool inferred_only)
 {
-    host_tracker_stats.service_adds++;
-
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-
-    auto iter = std::find(services.begin(), services.end(), app_entry);
-    if (iter != services.end())
-        services.erase(iter);
-
-    services.emplace_front(app_entry);
-}
-
-bool HostTracker::find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry)
-{
-    HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID);
     host_tracker_stats.service_finds++;
-
     std::lock_guard<std::mutex> lck(host_tracker_lock);
 
-    auto iter = std::find(services.begin(), services.end(), tmp_entry);
-    if (iter != services.end())
+    for ( const auto& s : services )
     {
-        app_entry = *iter;
-        return true;
+        if ( s.port == port and s.proto == proto and
+            (!inferred_only or s.inferred_appid == inferred_only) )
+            return s.appid;
     }
 
-    return false;
-}
-
-bool HostTracker::remove_service(Protocol ipproto, Port port)
-{
-    HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID);
-    host_tracker_stats.service_removes++;
-
-    std::lock_guard<std::mutex> lck(host_tracker_lock);
-
-    auto iter = std::find(services.begin(), services.end(), tmp_entry);
-    if (iter != services.end())
-    {
-        services.erase(iter);
-        return true;   //  Assumes only one matching entry.
-    }
-
-    return false;
+    return APP_ID_NONE;
 }
 
 void HostTracker::stringify(string& str)
 {
-    str = "IP: ";
-    SfIpString ip_str;
-    str += ip_addr.ntop(ip_str);
-
-    if ( !app_mappings.empty() )
-    {
-        str += "\napp_mappings size: " + to_string(app_mappings.size());
-        for ( const auto& elem : app_mappings )
-            str += "\n    port: " + to_string(elem.port)
-                + ", proto: " + to_string(elem.proto)
-                + ", appid: " + to_string(elem.appid);
-    }
-
-    if ( stream_policy or frag_policy )
-        str += "\nstream policy: " + to_string(stream_policy)
-            + ", frag policy: " + to_string(frag_policy);
-
     if ( !services.empty() )
     {
         str += "\nservices size: " + to_string(services.size());
-        for ( const auto& elem : services )
-            str += "\n    port: " + to_string(elem.port)
-                + ", proto: " + to_string(elem.ipproto)
-                + ", snort proto: " + to_string(elem.snort_protocol_id);
-    }
+        for ( const auto& s : services )
+        {
+            str += "\n    port: " + to_string(s.port)
+                + ", proto: " + to_string((uint8_t) s.proto);
+            if ( s.appid != APP_ID_NONE )
+            {
+                str += ", appid: " + to_string(s.appid);
+                if ( s.inferred_appid )
+                    str += ", inferred";
+            }
+        }
+   }
 }
index eafc9fbdcd0e7982effab7d6df7aab7e224d4412..0c22509bec9f3c7d05bf35601d62e751eb71388e 100644 (file)
 // configuration or dynamic discovery).  It provides a thread-safe API to
 // set/get the host data.
 
-#include <algorithm>
-#include <cstring>
-#include <list>
 #include <mutex>
+#include <vector>
+
 #include "framework/counts.h"
 #include "main/snort_types.h"
 #include "main/thread.h"
 #include "network_inspectors/appid/application_ids.h"
 #include "protocols/protocol_ids.h"
-#include "sfip/sf_ip.h"
-#include "target_based/snort_protocols.h"
-
-//  FIXIT-M For now this emulates the Snort++ attribute table.
-//  Need to add in host_tracker.h data eventually.
-
-typedef uint16_t Protocol;
-typedef uint8_t Policy;
 
 struct HostTrackerStats
 {
     PegCount service_adds;
     PegCount service_finds;
-    PegCount service_removes;
 };
 
 extern THREAD_LOCAL struct HostTrackerStats host_tracker_stats;
 
-struct HostApplicationEntry
-{
-    Port port = 0;
-    Protocol ipproto = 0;
-    SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID;
-
-    HostApplicationEntry() = default;
-
-    HostApplicationEntry(Protocol ipproto_param, Port port_param, SnortProtocolId protocol_param) :
-        port(port_param),
-        ipproto(ipproto_param),
-        snort_protocol_id(protocol_param)
-    {
-    }
-
-    inline bool operator==(const HostApplicationEntry& rhs) const
-    {
-        return ipproto == rhs.ipproto and port == rhs.port;
-    }
-};
-
-struct AppMapping
+struct HostApplication
 {
     Port port;
-    Protocol proto;
+    IpProtocol proto;
     AppId appid;
+    bool inferred_appid;
 };
 
-class HostTracker
+class SO_PUBLIC HostTracker
 {
 public:
-    HostTracker()
-    { memset(&ip_addr, 0, sizeof(ip_addr)); }
-
-    HostTracker(const snort::SfIp& new_ip_addr)
-    { std::memcpy(&ip_addr, &new_ip_addr, sizeof(ip_addr)); }
-
-    snort::SfIp get_ip_addr();
-    void set_ip_addr(const snort::SfIp& new_ip_addr);
-    Policy get_stream_policy();
-    void set_stream_policy(const Policy& policy);
-    Policy get_frag_policy();
-    void set_frag_policy(const Policy& policy);
-    void add_app_mapping(Port port, Protocol proto, AppId appid);
-    AppId find_app_mapping(Port port, Protocol proto);
-    bool find_else_add_app_mapping(Port port, Protocol proto, AppId appid);
-
-    //  Add host service data only if it doesn't already exist.  Returns
-    //  false if entry exists already, and true if entry was added.
-    bool add_service(const HostApplicationEntry& app_entry);
+    // Appid may not be identified always. Inferred means dynamic/runtime
+    // appid detected from one flow to another flow such as BitTorrent.
+    bool add_service(Port port, IpProtocol proto,
+        AppId appid = APP_ID_NONE, bool inferred_appid = false);
 
-    //  Add host service data if it doesn't already exist.  If it does exist
-    //  replace the previous entry with the new entry.
-    void add_or_replace_service(const HostApplicationEntry& app_entry);
-
-    //  Returns true and fills in copy of HostApplicationEntry when found.
-    //  Returns false when not found.
-    bool find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry);
-
-    //  Removes HostApplicationEntry object associated with ipproto and port.
-    //  Returns true if entry existed.  False otherwise.
-    bool remove_service(Protocol ipproto, Port port);
+    AppId get_appid(Port port, IpProtocol proto, bool inferred_only = false);
 
     //  This should be updated whenever HostTracker data members are changed
     void stringify(std::string& str);
@@ -122,15 +67,7 @@ private:
     //  Ensure that updates to a shared object are safe
     std::mutex host_tracker_lock;
 
-    //  FIXIT-M do we need to use a host_id instead of SfIp as in sfrna?
-    snort::SfIp ip_addr;
-    std::vector< AppMapping > app_mappings;
-
-    //  Policies to apply to this host.
-    Policy stream_policy = 0;
-    Policy frag_policy = 0;
-
-    std::list<HostApplicationEntry> services;
+    std::vector<HostApplication> services;
 };
 
 #endif
index 07d74426ab40ea58e7d9a60bc67eacc523c0c6b2..f56c03ca40c4db7e2ba8b24d2cc1be8640fe3b11 100644 (file)
 
 #include "host_tracker_module.h"
 
+#include "log/messages.h"
 #include "main/snort_config.h"
-#include "stream/stream.h"
-#include "target_based/snort_protocols.h"
-
-#include "host_cache.h"
 
 using namespace snort;
 
@@ -36,34 +33,21 @@ const PegInfo host_tracker_pegs[] =
 {
     { CountType::SUM, "service_adds", "host service adds" },
     { CountType::SUM, "service_finds", "host service finds" },
-    { CountType::SUM, "service_removes", "host service removes" },
     { CountType::END, nullptr, nullptr },
 };
 
 const Parameter HostTrackerModule::service_params[] =
 {
-    { "name", Parameter::PT_STRING, nullptr, nullptr,
-      "service identifier" },
-
-    { "proto", Parameter::PT_ENUM, "tcp | udp", "tcp",
-      "IP protocol" },
+    { "port", Parameter::PT_PORT, nullptr, nullptr, "port number" },
 
-    { "port", Parameter::PT_PORT, nullptr, nullptr,
-      "port number" },
+    { "proto", Parameter::PT_ENUM, "ip | tcp | udp", nullptr, "IP protocol" },
 
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
 const Parameter HostTrackerModule::host_tracker_params[] =
 {
-    { "ip", Parameter::PT_ADDR, nullptr, "0.0.0.0/32",
-      "hosts address / cidr" },
-
-    { "frag_policy", Parameter::PT_ENUM, IP_POLICIES, nullptr,
-      "defragmentation policy" },
-
-    { "tcp_policy", Parameter::PT_ENUM, TCP_POLICIES, nullptr,
-      "TCP reassembly policy" },
+    { "ip", Parameter::PT_ADDR, nullptr, nullptr, "hosts address / cidr" },
 
     { "services", Parameter::PT_LIST, HostTrackerModule::service_params, nullptr,
       "list of service parameters" },
@@ -71,29 +55,21 @@ const Parameter HostTrackerModule::host_tracker_params[] =
     { nullptr, Parameter::PT_MAX, nullptr, nullptr, nullptr }
 };
 
-bool HostTrackerModule::set(const char*, Value& v, SnortConfig* sc)
+bool HostTrackerModule::set(const char*, Value& v, SnortConfig*)
 {
-    if ( host and v.is("ip") )
-    {
-        SfIp addr;
+    if ( v.is("ip") )
         v.get_addr(addr);
-        host->set_ip_addr(addr);
-    }
-    else if ( host and v.is("frag_policy") )
-        host->set_frag_policy(v.get_uint8() + 1);
-
-    else if ( host and v.is("tcp_policy") )
-        host->set_stream_policy(v.get_uint8() + 1);
-
-    else if ( v.is("name") )
-        app.snort_protocol_id = sc->proto_ref->add(v.get_string());
-
-    else if ( v.is("proto") )
-        app.ipproto = sc->proto_ref->add(v.get_string());
 
     else if ( v.is("port") )
         app.port = v.get_uint16();
 
+    else if ( v.is("proto") )
+    {
+        const IpProtocol mask[] =
+        { IpProtocol::IP, IpProtocol::TCP, IpProtocol::UDP };
+        app.proto = mask[v.get_uint8()];
+    }
+
     else
         return false;
 
@@ -103,8 +79,10 @@ bool HostTrackerModule::set(const char*, Value& v, SnortConfig* sc)
 bool HostTrackerModule::begin(const char* fqn, int idx, SnortConfig*)
 {
     if ( idx && !strcmp(fqn, "host_tracker") )
-        host = new HostTracker;
-
+    {
+        addr.clear();
+        app = {};
+    }
     return true;
 }
 
@@ -112,13 +90,14 @@ bool HostTrackerModule::end(const char* fqn, int idx, SnortConfig*)
 {
     if ( idx && !strcmp(fqn, "host_tracker.services") )
     {
-        host->add_service(app);
+        if ( addr.is_set() )
+            host_cache[addr]->add_service(app.port, app.proto);
         app = {};
     }
-    else if ( idx && !strcmp(fqn, "host_tracker") )
+    else if ( idx && !strcmp(fqn, "host_tracker") && addr.is_set() )
     {
-        host_cache_add_host_tracker(host);
-        host = nullptr;  //  Host cache is now responsible for freeing host
+        host_cache[addr];
+        addr.clear();
     }
 
     return true;
index 6637d3d5df2bcbff19e6b74025e225aca6117e03..00f400030a54fc0cebe60cdfd1f623d3296e3ea4 100644 (file)
@@ -30,7 +30,7 @@
 #include <cassert>
 
 #include "framework/module.h"
-#include "host_tracker/host_tracker.h"
+#include "host_tracker/host_cache.h"
 
 #define host_tracker_help \
     "configure hosts"
 class HostTrackerModule : public snort::Module
 {
 public:
-    HostTrackerModule() : snort::Module("host_tracker", host_tracker_help, host_tracker_params, true)
-    { host = nullptr; }
-
-    ~HostTrackerModule() override
-    { assert(!host); }
+    HostTrackerModule() :
+        snort::Module("host_tracker", host_tracker_help, host_tracker_params, true) { }
 
     const PegInfo* get_pegs() const override;
     PegCount* get_counts() const override;
@@ -58,8 +55,8 @@ private:
     static const snort::Parameter host_tracker_params[];
     static const snort::Parameter service_params[];
 
-    HostApplicationEntry app;
-    HostTracker* host;
+    HostApplication app;
+    snort::SfIp addr;
 };
 
 #endif
index f8f3b28e9fbd42423fdf0f210e855605d9c6756d..5266dbb6df2e97bde04096afbd748a846645db6e 100644 (file)
@@ -46,9 +46,6 @@ static char logged_message[LOG_MAX+1];
 
 namespace snort
 {
-//  Fakes to avoid bringing in a ton of dependencies.
-SnortProtocolId ProtocolReference::add(char const*) { return 0; }
-SnortProtocolId ProtocolReference::find(char const*) { return 0; }
 SnortConfig* SnortConfig::get_conf() { return nullptr; }
 char* snort_strdup(const char* s) { return strdup(s); }
 Module* ModuleManager::get_module(const char*) { return nullptr; }
@@ -73,9 +70,6 @@ void show_stats(PegCount*, const PegInfo*, unsigned, const char*)
 void show_stats(PegCount*, const PegInfo*, IndexVec&, const char*, FILE*)
 { }
 
-#define FRAG_POLICY 33
-#define STREAM_POLICY 100
-
 TEST_GROUP(host_cache_module)
 {
     void setup() override
@@ -98,21 +92,15 @@ TEST(host_cache_module, host_cache_module_test_values)
     const PegCount* ht_stats = module.get_counts();
 
     CHECK(!strcmp(ht_pegs[0].name, "lru_cache_adds"));
-    CHECK(!strcmp(ht_pegs[1].name, "lru_cache_replaces"));
-    CHECK(!strcmp(ht_pegs[2].name, "lru_cache_prunes"));
-    CHECK(!strcmp(ht_pegs[3].name, "lru_cache_find_hits"));
-    CHECK(!strcmp(ht_pegs[4].name, "lru_cache_find_misses"));
-    CHECK(!strcmp(ht_pegs[5].name, "lru_cache_removes"));
-    CHECK(!strcmp(ht_pegs[6].name, "lru_cache_clears"));
-    CHECK(!ht_pegs[7].name);
+    CHECK(!strcmp(ht_pegs[1].name, "lru_cache_prunes"));
+    CHECK(!strcmp(ht_pegs[2].name, "lru_cache_find_hits"));
+    CHECK(!strcmp(ht_pegs[3].name, "lru_cache_find_misses"));
+    CHECK(!ht_pegs[4].name);
 
     CHECK(ht_stats[0] == 0);
     CHECK(ht_stats[1] == 0);
     CHECK(ht_stats[2] == 0);
     CHECK(ht_stats[3] == 0);
-    CHECK(ht_stats[4] == 0);
-    CHECK(ht_stats[5] == 0);
-    CHECK(ht_stats[6] == 0);
 
     size_val.set(&size_param);
 
@@ -123,8 +111,6 @@ TEST(host_cache_module, host_cache_module_test_values)
 
     ht_stats = module.get_counts();
     CHECK(ht_stats[0] == 0);
-
-    CHECK(2112 == host_cache.get_max_size());
 }
 
 TEST(host_cache_module, log_host_cache_messages)
index db835d894a6269cb0793640d8cd242b551e2f801..f177349ac694fc228dfa3d8e0cb22b279ec959ea 100644 (file)
@@ -25,6 +25,8 @@
 
 #include "host_tracker/host_cache.h"
 
+#include <cstring>
+
 #include "main/snort_config.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
@@ -34,23 +36,6 @@ using namespace snort;
 
 namespace snort
 {
-SnortConfig s_conf;
-THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
-SnortConfig::SnortConfig(const SnortConfig* const) { }
-SnortConfig::~SnortConfig() = default;
-SnortConfig* SnortConfig::get_conf()
-{ return snort_conf; }
-
-SnortProtocolId ProtocolReference::find(char const*) { return 0; }
-SnortProtocolId ProtocolReference::add(const char* protocol)
-{
-    if (!strcmp("servicename", protocol))
-        return 3;
-    if (!strcmp("tcp", protocol))
-        return 2;
-    return 1;
-}
-
 char* snort_strdup(const char* str)
 {
     return strdup(str);
@@ -61,23 +46,7 @@ TEST_GROUP(host_cache)
 {
 };
 
-//  Test HostIpKey
-TEST(host_cache, host_ip_key_test)
-{
-    HostIpKey zeroed_hk;
-    uint8_t expected_hk[16] =
-    { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
-
-    memset(&zeroed_hk.ip_addr, 0, sizeof(zeroed_hk.ip_addr));
-
-    HostIpKey hkey1;
-    CHECK(hkey1 == zeroed_hk);
-
-    HostIpKey hkey2(expected_hk);
-    CHECK(hkey2 == expected_hk);
-}
-
-//  Test HashHostIpKey
+//  Test HashIp
 TEST(host_cache, hash_test)
 {
     size_t expected_hash_val = 4521729;
@@ -85,174 +54,16 @@ TEST(host_cache, hash_test)
     uint8_t hk[16] =
     { 0x0a,0xff,0x12,0x00,0x00,0x00,0x00,0x00,0x0b,0x00,0x56,0x00,0x00,0x00,0x00,0x00 };
 
-    HashHostIpKey hash_hk;
+    HashIp hash_hk;
+    SfIp ip;
 
-    actual_hash_val = hash_hk(hk);
+    ip.set(hk);
+    actual_hash_val = hash_hk(ip);
     CHECK(actual_hash_val == expected_hash_val);
 }
 
-//  Test host_cache_add_host_tracker
-TEST(host_cache, host_cache_add_host_tracker_test)
-{
-    HostTracker* expected_ht = new HostTracker;
-    std::shared_ptr<HostTracker> actual_ht;
-    uint8_t hk[16] =
-    { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
-    SfIp ip_addr;
-    SfIp actual_ip_addr;
-    HostIpKey hkey(hk);
-    Port port = 443;
-    Protocol proto = 6;
-    HostApplicationEntry app_entry(proto, port, 2);
-    HostApplicationEntry actual_app_entry;
-    bool ret;
-
-    ip_addr.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-
-    expected_ht->set_ip_addr(ip_addr);
-    expected_ht->add_service(app_entry);
-
-    host_cache_add_host_tracker(expected_ht);
-
-    ret = host_cache.find((const uint8_t*) ip_addr.get_ip6_ptr(), actual_ht);
-    CHECK(true == ret);
-
-    actual_ip_addr = actual_ht->get_ip_addr();
-    CHECK(!memcmp(&ip_addr, &actual_ip_addr, sizeof(ip_addr)));
-
-    ret = actual_ht->find_service(proto, port, actual_app_entry);
-    CHECK(true == ret);
-    CHECK(actual_app_entry == app_entry);
-
-    host_cache.clear();     //  Free HostTracker objects
-}
-
-//  Test host_cache_add_service
-TEST(host_cache, host_cache_add_service_test)
-{
-    HostTracker* expected_ht = new HostTracker;
-    std::shared_ptr<HostTracker> actual_ht;
-    uint8_t hk[16] =
-    { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
-    SfIp ip_addr1;
-    SfIp ip_addr2;
-    HostIpKey hkey(hk);
-    Port port1 = 443;
-    Port port2 = 22;
-    Protocol proto1 = 17;
-    Protocol proto2 = 6;
-    HostApplicationEntry app_entry1(proto1, port1, 1);
-    HostApplicationEntry app_entry2(proto2, port2, 2);
-    HostApplicationEntry actual_app_entry;
-    bool ret;
-
-    ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-    ip_addr2.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-
-    //  Initialize cache with a HostTracker.
-    host_cache_add_host_tracker(expected_ht);
-
-    //  Add a service to a HostTracker that already exists.
-    ret = host_cache_add_service(ip_addr1, proto1, port1, "udp");
-    CHECK(true == ret);
-
-    ret = host_cache.find((const uint8_t*) ip_addr1.get_ip6_ptr(), actual_ht);
-    CHECK(true == ret);
-
-    ret = actual_ht->find_service(proto1, port1, actual_app_entry);
-    CHECK(true == ret);
-    CHECK(actual_app_entry == app_entry1);
-
-    //  Try adding service again (should fail since service exists).
-    ret = host_cache_add_service(ip_addr1, proto1, port1, "udp");
-    CHECK(false == ret);
-
-    //  Add a service with a new IP.
-    ret = host_cache_add_service(ip_addr2, proto2, port2, "tcp");
-    CHECK(true == ret);
-
-    ret = host_cache.find((const uint8_t*) ip_addr1.get_ip6_ptr(), actual_ht);
-    CHECK(true == ret);
-
-    ret = actual_ht->find_service(proto2, port2, actual_app_entry);
-    CHECK(true == ret);
-    CHECK(actual_app_entry == app_entry2);
-
-    host_cache.clear();     //  Free HostTracker objects
-}
-
-TEST(host_cache, host_cache_app_mapping_test )
-{
-    HostTracker* expected_ht = new HostTracker;
-    std::shared_ptr<HostTracker> actual_ht;
-    uint8_t hk[16] =
-        { 0xde,0xad,0xbe,0xef,0xab,0xcd,0xef,0x01,0x23,0x34,0x56,0x78,0x90,0xab,0xcd,0xef };
-    SfIp ip_addr1;
-    SfIp ip_addr2;
-    HostIpKey hkey(hk);
-    Port port1 = 4123;
-    Port port2 = 1827;
-    Protocol proto1 = 6;
-    Protocol proto2 = 7;
-    AppId appid1 = 61;
-    AppId appid2 = 62;
-    AppId appid3 = 63;
-    AppId ret;
-    bool add_ret;
-
-    ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-    ip_addr2.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:901b");
-
-    // Initialize cache with a HostTracker.
-    host_cache_add_host_tracker(expected_ht);
-
-    add_ret = host_cache_add_app_mapping(ip_addr1, port1, proto1, appid1);
-    CHECK(true == add_ret);
-
-    ret = host_cache_find_app_mapping(&ip_addr1, port1, proto1);
-    CHECK(61 == ret);
-
-    ret = host_cache_find_app_mapping(&ip_addr1, port2, proto1);
-    CHECK(APP_ID_NONE == ret);
-
-    ret = host_cache_find_app_mapping(&ip_addr2, port1, proto1);
-    CHECK(APP_ID_NONE == ret);
-
-    add_ret = host_cache_add_app_mapping(ip_addr1, port2, proto1, appid2);
-    CHECK(true == add_ret);
-    ret = host_cache_find_app_mapping(&ip_addr1, port2, proto1);
-    CHECK(62 == ret);
-    
-    add_ret = host_cache_add_app_mapping(ip_addr1, port1, proto2, appid3);
-    CHECK(true == add_ret);
-    ret = host_cache_find_app_mapping(&ip_addr1, port1, proto2);
-    CHECK(63 == ret);
-
-    host_cache.clear();     // Free HostTracker objects
-}
-
 int main(int argc, char** argv)
 {
-    SfIp ip_addr1;
-    Protocol proto1 = 17;
-    Port port1 = 443;
-
-    ip_addr1.pton(AF_INET6, "beef:dead:beef:abcd:ef01:2334:5678:90ab");
-
-    //  This is necessary to prevent the cpputest memory leak
-    //  detection from thinking there's a memory leak in the map
-    //  object contained within the global host_cache.  The map
-    //  must have some data allocated when it is first created
-    //  that doesn't go away until the global map object is
-    //  deallocated. This pre-allocates the map so that initial
-    //  allocation is done prior to starting the tests.  The same
-    //  is true for the list used when adding a service.
-    HostTracker* ht = new HostTracker;
-    host_cache_add_host_tracker(ht);
-    host_cache_add_service(ip_addr1, proto1, port1, "udp");
-
-    host_cache.clear();
-
     //  Use this if you want to turn off memory checks entirely:
     // MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
 
index 2d5e36f7ac3d1402fdb1d87ff8695c3d41dc8c79..33e9da5a1ba5ab3b82045a5d7cd329e6011be11d 100644 (file)
 #include "config.h"
 #endif
 
+#include <cstring>
+
 #include "host_tracker/host_cache.h"
 #include "host_tracker/host_tracker_module.h"
-#include "target_based/snort_protocols.h"
 #include "main/snort_config.h"
+#include "target_based/snort_protocols.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
 #include <CppUTest/TestHarness.h>
@@ -35,17 +37,6 @@ using namespace snort;
 
 namespace snort
 {
-SnortConfig* SnortConfig::get_conf() { return nullptr; }
-SnortProtocolId ProtocolReference::find(char const*) { return 0; }
-SnortProtocolId ProtocolReference::add(const char* protocol)
-{
-    if (!strcmp("servicename", protocol))
-        return 3;
-    if (!strcmp("tcp", protocol))
-        return 2;
-    return 1;
-}
-
 char* snort_strdup(const char* s)
 { return strdup(s); }
 }
@@ -54,9 +45,6 @@ char* snort_strdup(const char* s)
 void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
 void show_stats(PegCount*, const PegInfo*, IndexVec&, const char*, FILE*) { }
 
-#define FRAG_POLICY 33
-#define STREAM_POLICY 100
-
 SfIp expected_addr;
 
 TEST_GROUP(host_tracker_module)
@@ -64,15 +52,9 @@ TEST_GROUP(host_tracker_module)
     void setup() override
     {
         Value ip_val("10.23.45.56");
-        Value frag_val((double)FRAG_POLICY);
-        Value tcp_val((double)STREAM_POLICY);
-        Value name_val("servicename");
         Value proto_val("udp");
         Value port_val((double)2112);
         Parameter ip_param = { "ip", Parameter::PT_ADDR, nullptr, "0.0.0.0/32", "addr/cidr"};
-        Parameter frag_param = { "frag_policy", Parameter::Parameter::PT_ENUM, "linux | bsd", nullptr, "frag policy"};
-        Parameter tcp_param = { "tcp_policy", Parameter::PT_ENUM, "linux | bsd", nullptr, "tcp policy"};
-        Parameter name_param = {"name", Parameter::PT_STRING, nullptr, nullptr, "name"};
         Parameter proto_param = {"proto", Parameter::PT_ENUM, "tcp | udp", "tcp", "ip proto"};
         Parameter port_param = {"port", Parameter::PT_PORT, nullptr, nullptr, "port num"};
         HostTrackerModule module;
@@ -81,17 +63,12 @@ TEST_GROUP(host_tracker_module)
 
         CHECK(!strcmp(ht_pegs[0].name, "service_adds"));
         CHECK(!strcmp(ht_pegs[1].name, "service_finds"));
-        CHECK(!strcmp(ht_pegs[2].name, "service_removes"));
-        CHECK(!ht_pegs[3].name);
+        CHECK(!ht_pegs[2].name);
 
         CHECK(ht_stats[0] == 0);
         CHECK(ht_stats[1] == 0);
-        CHECK(ht_stats[2] == 0);
 
         ip_val.set(&ip_param);
-        frag_val.set(&frag_param);
-        tcp_val.set(&tcp_param);
-        name_val.set(&name_param);
         proto_val.set(&proto_param);
         port_val.set(&port_param);
 
@@ -101,12 +78,6 @@ TEST_GROUP(host_tracker_module)
         // Set up the module values and add a service.
         module.begin("host_tracker", 1, nullptr);
         module.set(nullptr, ip_val, nullptr);
-        module.set(nullptr, frag_val, nullptr);
-        module.set(nullptr, tcp_val, nullptr);
-
-        // FIXIT-M see FIXIT-M below
-        //module.set(nullptr, name_val, nullptr);
-        //module.set(nullptr, proto_val, nullptr);
 
         module.set(nullptr, port_val, nullptr);
         module.end("host_tracker.services", 1, nullptr);
@@ -118,7 +89,6 @@ TEST_GROUP(host_tracker_module)
     void teardown() override
     {
         memset(&host_tracker_stats, 0, sizeof(host_tracker_stats));
-        host_cache.clear();    //  Free HostTracker objects
     }
 };
 
@@ -127,70 +97,9 @@ TEST(host_tracker_module, host_tracker_module_test_basic)
     CHECK(true);
 }
 
-#if 0
-// FIXIT-M the below are more functional in scope because they require host_cache
-// services.  need to stub this out better to focus on the module only.
-//  Test that HostTrackerModule variables are set correctly.
-TEST(host_tracker_module, host_tracker_module_test_values)
-{
-    SfIp cached_addr;
-
-    HostIpKey host_ip_key((const uint8_t*) expected_addr.get_ip6_ptr());
-    std::shared_ptr<HostTracker> ht;
-
-    bool ret = host_cache.find(host_ip_key, ht);
-    CHECK(ret == true);
-
-    cached_addr = ht->get_ip_addr();
-    CHECK(cached_addr.fast_equals_raw(expected_addr) == true);
-
-    Policy policy = ht->get_stream_policy();
-    CHECK(policy == STREAM_POLICY + 1);
-
-    policy = ht->get_frag_policy();
-    CHECK(policy == FRAG_POLICY + 1);
-}
-
-
-//  Test that HostTrackerModule statistics are correct.
-TEST(host_tracker_module, host_tracker_module_test_stats)
-{
-    HostIpKey host_ip_key((const uint8_t*) expected_addr.get_ip6_ptr());
-    std::shared_ptr<HostTracker> ht;
-
-    bool ret = host_cache.find(host_ip_key, ht);
-    CHECK(ret == true);
-
-    HostApplicationEntry app;
-    ret = ht->find_service(1, 2112, app);
-    CHECK(ret == true);
-    CHECK(app.protocol == 3);
-    CHECK(app.ipproto == 1);
-    CHECK(app.port == 2112);
-
-    ret = ht->remove_service(1, 2112);
-    CHECK(ret == true);
-
-    //  Verify counts are correct.  The add was done during setup.
-    CHECK(host_tracker_stats.service_adds == 1);
-    CHECK(host_tracker_stats.service_finds == 1);
-    CHECK(host_tracker_stats.service_removes == 1);
-}
-#endif
-
 int main(int argc, char** argv)
 {
-    //  This is necessary to prevent the cpputest memory leak
-    //  detection from thinking there's a memory leak in the map
-    //  object contained within the global host_cache.  The map
-    //  must have some data allocated when it is first created
-    //  that doesn't go away until the global map object is
-    //  deallocated. This pre-allocates the map so that initial
-    //  allocation is done prior to starting the tests.
-    HostTracker* ht = new HostTracker;
-    host_cache_add_host_tracker(ht);
-    host_cache.clear();
-
+    MemoryLeakWarningPlugin::turnOffNewDeleteOverloads();
     return CommandLineTestRunner::RunAllTests(argc, argv);
 }
 
index 5acd3b4548009e9f63a5aef1bdbaaa0f57e56906..ae6eb4400c18af6799a2f0620523bacb21484802 100644 (file)
@@ -23,6 +23,8 @@
 #include "config.h"
 #endif
 
+#include <cstring>
+
 #include "host_tracker/host_tracker.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
@@ -42,159 +44,40 @@ TEST_GROUP(host_tracker)
 {
 };
 
-//  Test HostTracker ipaddr get/set functions.
-TEST(host_tracker, ipaddr_test)
-{
-    HostTracker ht;
-    SfIp zeroed_sfip;
-    SfIp expected_ip_addr;
-    SfIp actual_ip_addr;
-
-    //  Test IP prior to set.
-    memset(&zeroed_sfip, 0, sizeof(zeroed_sfip));
-    actual_ip_addr = ht.get_ip_addr();
-    CHECK(0 == memcmp(&zeroed_sfip, &actual_ip_addr, sizeof(zeroed_sfip)));
-
-    expected_ip_addr.pton(AF_INET6, "beef:abcd:ef01:2300::");
-    ht.set_ip_addr(expected_ip_addr);
-    actual_ip_addr = ht.get_ip_addr();
-    CHECK(0 == memcmp(&expected_ip_addr, &actual_ip_addr, sizeof(expected_ip_addr)));
-}
-
-//  Test HostTracker policy get/set functions.
-TEST(host_tracker, policy_test)
-{
-    HostTracker ht;
-    Policy expected_policy = 23;
-    Policy actual_policy;
-
-    actual_policy = ht.get_stream_policy();
-    CHECK(0 == actual_policy);
-
-    actual_policy = ht.get_frag_policy();
-    CHECK(0 == actual_policy);
-
-    ht.set_stream_policy(expected_policy);
-    actual_policy = ht.get_stream_policy();
-    CHECK(expected_policy == actual_policy);
-
-    expected_policy = 77;
-    ht.set_frag_policy(expected_policy);
-    actual_policy = ht.get_frag_policy();
-    CHECK(expected_policy == actual_policy);
-}
-
-TEST(host_tracker, app_mapping_test)
-{
-    HostTracker ht;
-    const uint16_t expected_ports = 4123;
-    const uint16_t port1 = 4122;
-    AppId actual_appid;
-    Protocol expected_proto1 = 6; 
-    Protocol expected_proto2 = 17; 
-    AppId appid1 = 61;
-    AppId appid2 = 62;
-    bool ret;
-
-    ht.add_app_mapping(expected_ports, expected_proto1, appid1);
-
-    actual_appid = ht.find_app_mapping(expected_ports, expected_proto1);
-    CHECK(61 == actual_appid);
-
-    actual_appid = ht.find_app_mapping(expected_ports, expected_proto2);
-    CHECK(APP_ID_NONE == actual_appid);
-
-    actual_appid = ht.find_app_mapping(port1, expected_proto2);
-    CHECK(APP_ID_NONE == actual_appid);
-
-    ret = ht.find_else_add_app_mapping(port1, expected_proto1, appid2);
-    CHECK(true == ret);
-
-    ret = ht.find_else_add_app_mapping(port1, expected_proto1, appid2);
-    CHECK(false == ret);
-}
-
-//  Test HostTracker add and find service functions.
+//  Test HostTracker find appid and add service functions.
 TEST(host_tracker, add_find_service_test)
 {
-    bool ret;
     HostTracker ht;
-    HostApplicationEntry actual_entry;
-    HostApplicationEntry app_entry1(6, 2112, 3);
-    HostApplicationEntry app_entry2(17, 7777, 10);
 
     //  Try a find on an empty list.
-    ret = ht.find_service(3,1000, actual_entry);
-    CHECK(false == ret);
+    CHECK(APP_ID_NONE == ht.get_appid(80, IpProtocol::TCP));
 
     //  Test add and find.
-    ret = ht.add_service(app_entry1);
-    CHECK(true == ret);
-
-    ret = ht.find_service(6, 2112, actual_entry);
-    CHECK(true == ret);
-    CHECK(actual_entry.port == 2112);
-    CHECK(actual_entry.ipproto == 6);
-    CHECK(actual_entry.snort_protocol_id == 3);
-
-    ht.add_service(app_entry2);
-    ret = ht.find_service(6, 2112, actual_entry);
-    CHECK(true == ret);
-    CHECK(actual_entry.port == 2112);
-    CHECK(actual_entry.ipproto == 6);
-    CHECK(actual_entry.snort_protocol_id == 3);
-
-    ret = ht.find_service(17, 7777, actual_entry);
-    CHECK(true == ret);
-    CHECK(actual_entry.port == 7777);
-    CHECK(actual_entry.ipproto == 17);
-    CHECK(actual_entry.snort_protocol_id == 10);
-
-    //  Try adding an entry that exists already.
-    ret = ht.add_service(app_entry1);
-    CHECK(false == ret);
-
-    // Try a find on a port that isn't in the list.
-    ret = ht.find_service(6, 100, actual_entry);
-    CHECK(false == ret);
-
-    // Try a find on an ipproto that isn't in the list.
-    ret = ht.find_service(17, 2112, actual_entry);
-    CHECK(false == ret);
-
-    //  Try to remove an entry that's not in the list.
-    ret = ht.remove_service(6,100);
-    CHECK(false == ret);
-
-    ret = ht.remove_service(17,2112);
-    CHECK(false == ret);
-
-    //  Actually remove an entry.
-    ret = ht.remove_service(6,2112);
-    CHECK(true == ret);
+    CHECK(true == ht.add_service(80, IpProtocol::TCP, 676, true));
+    CHECK(true == ht.add_service(443, IpProtocol::TCP, 1122));
+    CHECK(676 == ht.get_appid(80, IpProtocol::TCP));
+    CHECK(1122 == ht.get_appid(443, IpProtocol::TCP));
+
+    //  Try adding an entry that exists already and update appid
+    CHECK(true == ht.add_service(443, IpProtocol::TCP, 847));
+    CHECK(847 == ht.get_appid(443, IpProtocol::TCP));
+
+    // Try a find appid on a port that isn't in the list.
+    CHECK(APP_ID_NONE == ht.get_appid(8080, IpProtocol::UDP));
 }
 
 TEST(host_tracker, stringify)
 {
-    SfIp ip;
-    ip.pton(AF_INET6, "feed:dead:beef::");
-    HostTracker ht(ip);
-    ht.add_app_mapping(80, 6, 676);
-    ht.add_app_mapping(443, 6, 1122);
-    ht.set_frag_policy(3);
-    HostApplicationEntry app_entry(6, 80, 10);
-    ht.add_service(app_entry);
+    HostTracker ht;
+    ht.add_service(80, IpProtocol::TCP, 676, true);
+    ht.add_service(443, IpProtocol::TCP, 1122);
     string host_tracker_string;
 
     ht.stringify(host_tracker_string);
     STRCMP_EQUAL(host_tracker_string.c_str(),
-        "IP: feed:dead:beef:0000:0000:0000:0000:0000\n"
-        "app_mappings size: 2\n"
-        "    port: 80, proto: 6, appid: 676\n"
-        "    port: 443, proto: 6, appid: 1122\n"
-        "stream policy: 0, frag policy: 3\n"
-        "services size: 1\n"
-        "    port: 80, proto: 6, snort proto: 10");
+        "\nservices size: 2"
+        "\n    port: 80, proto: 6, appid: 676, inferred"
+        "\n    port: 443, proto: 6, appid: 1122");
 }
 
 int main(int argc, char** argv)
index aae4e48f7432e03a6d80031ad5bab791e054275c..ced22d2f31696af1d0a49d42be07a96a13225acd 100644 (file)
@@ -638,11 +638,15 @@ static void lookup_appid_by_host_port(AppIdSession& asd, Packet* p, IpProtocol p
     }
     else if (asd.config->mod_config->is_host_port_app_cache_runtime)
     {
-        AppId appid = snort::host_cache_find_app_mapping(ip, port, (Protocol)protocol);
-        if (appid > APP_ID_NONE)
+        auto ht = host_cache.find(*ip);
+        if ( ht )
         {
-            asd.client.set_id(appid);
-            asd.client_disco_state = APPID_DISCO_STATE_FINISHED;
+            AppId appid = ht->get_appid(port, protocol, true);
+            if ( appid > APP_ID_NONE )
+            {
+                asd.client.set_id(appid);
+                asd.client_disco_state = APPID_DISCO_STATE_FINISHED;
+            }
         }
     }
 }
index ee1c869becc125fcd04dfd29c66252acf7a5f3be..b45baa02821d51f12f91c0372dbc13ea69818b0e 100644 (file)
@@ -1181,15 +1181,15 @@ static int detector_add_host_port_dynamic(lua_State* L)
     }
 
     unsigned port = lua_tointeger(L, ++index);
-    unsigned proto = lua_tointeger(L, ++index);
-    if (proto > (unsigned)IpProtocol::RESERVED)
+    IpProtocol proto = (IpProtocol) lua_tointeger(L, ++index);
+    if (proto > IpProtocol::RESERVED)
     {
-        ErrorMessage("%s:Invalid protocol value %u\n",__func__, proto);
+        ErrorMessage("%s:Invalid protocol value %u\n",__func__, (unsigned) proto);
         return 0;
     }
 
-    if (!(snort::host_cache_add_app_mapping(ip_addr, port, proto, appid)))
-        ErrorMessage("%s:Failed to add app mapping\n",__func__);
+    if ( !host_cache[ip_addr]->add_service(port, proto, appid, true) )
+        ErrorMessage("%s:Failed to add host tracker service\n",__func__);
 
     return 0;
 }
index 8e843ab156420d58e8d3ab80b21b5bf88bdc27f3..d893805d562f3dc4187f9158b2f201277e7b100f 100644 (file)
@@ -218,7 +218,13 @@ ServiceDiscovery& ServiceDiscovery::get_instance()
         s_discovery_manager = new ServiceDiscovery();
     return *s_discovery_manager;
 }
-AppId snort::host_cache_find_app_mapping(snort::SfIp const*, Port, Protocol){ return 0; }
+
+LruCacheShared<SfIp, HostTracker, HashIp> host_cache(50);
+AppId HostTracker::get_appid(Port, IpProtocol, bool)
+{
+    return APP_ID_NONE;
+}
+
 // Stubs for ClientDiscovery
 ClientDiscovery::ClientDiscovery(){}
 ClientDiscovery::~ClientDiscovery() {}
index 3708456c89e51b5fdb9bb097059900684ebe50a3..10d004065256b4aaf3ae078a737342f4f3c4675d 100644 (file)
@@ -105,16 +105,16 @@ void RnaPnd::analyze_flow_udp(const Packet* p)
 
 void RnaPnd::discover_network_icmp(const Packet* p)
 {
-    if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
-        p->flow->client_port, SNORT_PROTO_ICMP) )
+    if ( !(host_cache[p->flow->client_ip]->
+        add_service(p->flow->client_port, p->get_ip_proto_next())) )
         return;
     // process rna discovery for icmp
 }
 
 void RnaPnd::discover_network_ip(const Packet* p)
 {
-    if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
-        p->flow->client_port, SNORT_PROTO_IP) )
+    if ( !(host_cache[p->flow->client_ip]->
+        add_service(p->flow->client_port, p->get_ip_proto_next())) )
         return;
     // process rna discovery for ip
 }
@@ -128,8 +128,8 @@ void RnaPnd::discover_network_non_ip(const Packet* p)
 void RnaPnd::discover_network_tcp(const Packet* p)
 {
     // Track from initiator direction, if not already seen
-    if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
-        p->flow->client_port, SNORT_PROTO_TCP) )
+    if ( !(host_cache[p->flow->client_ip]->
+        add_service(p->flow->client_port, p->get_ip_proto_next())) )
         return;
 
     // Add mac address to ht list, ttl, last_seen, etc.
@@ -138,8 +138,8 @@ void RnaPnd::discover_network_tcp(const Packet* p)
 
 void RnaPnd::discover_network_udp(const Packet* p)
 {
-    if ( !host_cache_add_service(p->flow->client_ip, (uint8_t)p->get_ip_proto_next(),
-        p->flow->client_port, SNORT_PROTO_UDP) )
+    if ( !(host_cache[p->flow->client_ip]->
+        add_service(p->flow->client_port, p->get_ip_proto_next())) )
         return;
     // process rna discovery for udp
 }
index c9502ceddb43ed3d3aa52d5727fd2fed7109abbe..768adfb74bfd151af53a2be8e21b5957aec0e40c 100644 (file)
@@ -263,11 +263,9 @@ bool Wizard::spellbind(
 {
     f->service = m->book.find_spell(data, len, m);
 
-    if (f->service != nullptr)
+    if ( f->service != nullptr )
     {
-        // FIXIT-H need to make sure Flow's ipproto and service
-        // correspond to HostApplicationEntry's ipproto and service
-        host_cache_add_service(f->server_ip, f->ip_proto, f->server_port, f->service);
+        host_cache[f->server_ip]->add_service(f->server_port, (IpProtocol)f->ip_proto);
         return true;
     }
 
@@ -282,10 +280,11 @@ bool Wizard::cursebind(vector<CurseServiceTracker>& curse_tracker, Flow* f,
         if (cst.curse->alg(data, len, cst.tracker))
         {
             f->service = cst.curse->service.c_str();
-            // FIXIT-H need to make sure Flow's ipproto and service
-            // correspond to HostApplicationEntry's ipproto and service
-            host_cache_add_service(f->server_ip, f->ip_proto, f->server_port, f->service);
-            return true;
+            if ( f->service != nullptr )
+            {
+                host_cache[f->server_ip]->add_service(f->server_port, (IpProtocol)f->ip_proto);
+                return true;
+            }
         }
     }
 
index b0fde364bcb26f71bafee057f7e6b3c81352e16b..6eef4fa210c300012bbf9e12935a88ea9a99ff19 100644 (file)
@@ -81,6 +81,7 @@ struct SO_PUBLIC SfIp
     bool fast_gt6(const SfIp& ip2) const;
     bool fast_eq6(const SfIp& ip2) const;
     bool fast_equals_raw(const SfIp& ip2) const;
+    bool operator==(const SfIp& ip2) const;
 
     /*
      * Miscellaneous
@@ -463,6 +464,11 @@ inline bool SfIp::fast_equals_raw(const SfIp& ip2) const
     return false;
 }
 
+inline bool SfIp::operator==(const SfIp& ip2) const
+{
+    return fast_equals_raw(ip2);
+}
+
 /* End of member function definitions */
 
 SO_PUBLIC const char* sfip_ntop(const SfIp* ip, char* buf, int bufsize);