]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2693 in SNORT/snort3 from ~SMINUT/snort3:host_cache_rna to master
authorMasud Hasan (mashasan) <mashasan@cisco.com>
Tue, 26 Jan 2021 21:40:51 +0000 (21:40 +0000)
committerMasud Hasan (mashasan) <mashasan@cisco.com>
Tue, 26 Jan 2021 21:40:51 +0000 (21:40 +0000)
Squashed commit of the following:

commit ec7f9504910ba29d2899c7669f833195b29fd6dd
Author: Silviu Minut <sminut@cisco.com>
Date:   Fri Jan 8 10:55:59 2021 -0500

    rna: Minimize synchronization overhead

    Avoid some locks during network discovery in order to increase speed,
    by caching the host trackers locally in the RNAFlow, in a way in which
    the cached host trackers do not spill memory into the host cache during
    pruning.

21 files changed:
src/hash/lru_cache_shared.h
src/host_tracker/host_cache.h
src/host_tracker/host_tracker.cc
src/host_tracker/host_tracker.h
src/host_tracker/test/CMakeLists.txt
src/host_tracker/test/host_cache_allocator_ht_test.cc
src/host_tracker/test/host_cache_allocator_test.cc
src/host_tracker/test/host_tracker_test.cc
src/network_inspectors/appid/test/CMakeLists.txt
src/network_inspectors/appid/test/appid_discovery_test.cc
src/network_inspectors/rna/CMakeLists.txt
src/network_inspectors/rna/rna_app_discovery.cc
src/network_inspectors/rna/rna_fingerprint_tcp.cc
src/network_inspectors/rna/rna_fingerprint_tcp.h
src/network_inspectors/rna/rna_flow.cc [new file with mode: 0644]
src/network_inspectors/rna/rna_flow.h [new file with mode: 0644]
src/network_inspectors/rna/rna_inspector.cc
src/network_inspectors/rna/rna_logger.h
src/network_inspectors/rna/rna_pnd.cc
src/network_inspectors/rna/test/rna_flow_mock.cc [new file with mode: 0644]
src/network_inspectors/rna/test/rna_module_mock.h

index 8e7c03c6e1e9815066694e6bac38885f1118a163..905da594c1f35b230e4a6ffddf2c57e0829cc8a7 100644 (file)
@@ -48,7 +48,8 @@ struct LruCacheSharedStats
     PegCount replaced = 0;      // found entry and replaced it
 };
 
-template<typename Key, typename Value, typename Hash, typename Eq = std::equal_to<Key>>
+template<typename Key, typename Value, typename Hash, typename Eq = std::equal_to<Key>,
+    typename Purgatory = std::vector<std::shared_ptr<Value>>>
 class LruCacheShared
 {
 public:
@@ -67,6 +68,7 @@ public:
 
     using Data = std::shared_ptr<Value>;
     using ValueType = Value;
+    using KeyType = Key;
 
     // Return data entry associated with key. If doesn't exist, return nullptr.
     Data find(const Key& key);
@@ -108,12 +110,12 @@ public:
 
     //  Remove entry associated with Key.
     //  Returns true if entry existed, false otherwise.
-    bool remove(const Key& key);
+    virtual bool remove(const Key& key);
 
     //  Remove entry associated with key and return removed data.
     //  Returns true and copy of data if entry existed.  Returns false if
     //  entry did not exist.
-    bool remove(const Key& key, Data& data);
+    virtual bool remove(const Key& key, Data& data);
 
     const PegInfo* get_pegs() const
     { return lru_cache_shared_peg_names; }
@@ -166,7 +168,7 @@ protected:
 
     // Caller must lock and unlock. Don't use this during snort reload for which
     // we need gradual pruning and size reduction via reload resource tuner.
-    void prune(std::vector<Data>& data)
+    void prune(Purgatory& data)
     {
         LruListIter list_iter;
         assert(data.empty());
@@ -182,8 +184,8 @@ protected:
     }
 };
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-bool LruCacheShared<Key, Value, Hash, Eq>::set_max_size(size_t newsize)
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+bool LruCacheShared<Key, Value, Hash, Eq, Purgatory>::set_max_size(size_t newsize)
 {
     if (newsize == 0)
         return false;   //  Not allowed to set size to zero.
@@ -191,7 +193,7 @@ bool LruCacheShared<Key, Value, Hash, Eq>::set_max_size(size_t newsize)
     // Like with remove(), we need local temporary references to data being
     // deleted, to avoid race condition. This data needs to self-destruct
     // after the cache_lock does.
-    std::vector<Data> data;
+    Purgatory data;
 
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
 
@@ -203,8 +205,8 @@ bool LruCacheShared<Key, Value, Hash, Eq>::set_max_size(size_t newsize)
     return true;
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq>::find(const Key& key)
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq, Purgatory>::find(const Key& key)
 {
     LruMapIter map_iter;
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
@@ -222,14 +224,14 @@ std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq>::find(const Key& key
     return map_iter->second->second;
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq>::operator[](const Key& key)
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq, Purgatory>::operator[](const Key& key)
 {
     return find_else_create(key, nullptr);
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq>::
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+std::shared_ptr<Value> LruCacheShared<Key, Value, Hash, Eq, Purgatory>::
 find_else_create(const Key& key, bool* new_data)
 {
     LruMapIter map_iter;
@@ -240,7 +242,7 @@ find_else_create(const Key& key, bool* new_data)
     // unlocking the cache_mutex, because the cache must be locked when we
     // return the data pointer (below), or else, some other thread might
     // delete it before we got a chance to return it.
-    std::vector<Data> tmp_data;
+    Purgatory tmp_data;
 
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
 
@@ -270,13 +272,13 @@ find_else_create(const Key& key, bool* new_data)
     return data;
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-bool LruCacheShared<Key, Value, Hash, Eq>::
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+bool LruCacheShared<Key, Value, Hash, Eq, Purgatory>::
 find_else_insert(const Key& key, std::shared_ptr<Value>& data, bool replace)
 {
     LruMapIter map_iter;
 
-    std::vector<Data> tmp_data;
+    Purgatory tmp_data;
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
 
     map_iter = map.find(key);
@@ -311,9 +313,9 @@ find_else_insert(const Key& key, std::shared_ptr<Value>& data, bool replace)
     return false;
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
 std::vector< std::pair<Key, std::shared_ptr<Value>> >
-LruCacheShared<Key, Value, Hash, Eq>::get_all_data()
+LruCacheShared<Key, Value, Hash, Eq, Purgatory>::get_all_data()
 {
     std::vector<std::pair<Key, Data> > vec;
     std::lock_guard<std::mutex> cache_lock(cache_mutex);
@@ -326,8 +328,8 @@ LruCacheShared<Key, Value, Hash, Eq>::get_all_data()
     return vec;
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-bool LruCacheShared<Key, Value, Hash, Eq>::remove(const Key& key)
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+bool LruCacheShared<Key, Value, Hash, Eq, Purgatory>::remove(const Key& key)
 {
     LruMapIter map_iter;
 
@@ -367,8 +369,9 @@ bool LruCacheShared<Key, Value, Hash, Eq>::remove(const Key& key)
     return true;
 }
 
-template<typename Key, typename Value, typename Hash, typename Eq>
-bool LruCacheShared<Key, Value, Hash, Eq>::remove(const Key& key, std::shared_ptr<Value>& data)
+template<typename Key, typename Value, typename Hash, typename Eq, typename Purgatory>
+bool LruCacheShared<Key, Value, Hash, Eq, Purgatory>::remove(const Key& key,
+    std::shared_ptr<Value>& data)
 {
     LruMapIter map_iter;
 
index 9f4541cc7b247864e64e2a5bbc1951d7a64b8559..4bc1508da555677e5fa9fe358037766aa063dda2 100644 (file)
@@ -58,11 +58,13 @@ struct IpEqualTo
     }
 };
 
-template<typename Key, typename Value, typename Hash, typename Eq = std::equal_to<Key>>
-class LruCacheSharedMemcap : public LruCacheShared<Key, Value, Hash, Eq>, public HostCacheInterface
+template<typename Key, typename Value, typename Hash, typename Eq = std::equal_to<Key>,
+    typename Purgatory = std::vector<std::shared_ptr<Value>>>
+class LruCacheSharedMemcap : public LruCacheShared<Key, Value, Hash, Eq, Purgatory>,
+    public HostCacheInterface
 {
 public:
-    using LruBase = LruCacheShared<Key, Value, Hash, Eq>;
+    using LruBase = LruCacheShared<Key, Value, Hash, Eq, Purgatory>;
     using LruBase::cache_mutex;
     using LruBase::current_size;
     using LruBase::list;
@@ -78,7 +80,7 @@ public:
     LruCacheSharedMemcap(const LruCacheSharedMemcap& arg) = delete;
     LruCacheSharedMemcap& operator=(const LruCacheSharedMemcap& arg) = delete;
 
-    LruCacheSharedMemcap(const size_t initial_size) : LruCacheShared<Key, Value, Hash, Eq>(initial_size),
+    LruCacheSharedMemcap(const size_t sz) : LruCacheShared<Key, Value, Hash, Eq, Purgatory>(sz),
         valid_id(invalid_id+1) {}
 
     size_t mem_size() override
@@ -203,7 +205,7 @@ private:
             // to hold the pruned data until after the cache is unlocked.
             // Do not change the order of data and cache_lock, as the data must
             // self destruct after cache_lock.
-            std::vector<Data> data;
+            Purgatory data;
             std::lock_guard<std::mutex> cache_lock(cache_mutex);
             LruBase::prune(data);
         }
@@ -228,7 +230,55 @@ private:
     friend class TEST_host_cache_module_misc_Test; // for unit test
 };
 
-typedef LruCacheSharedMemcap<snort::SfIp, snort::HostTracker, HashIp, IpEqualTo> HostCacheIp;
+
+class HTPurgatory
+{
+public:
+
+    ~HTPurgatory()
+    {
+        for (auto& ht : data)
+        {
+            ht->remove_flows();
+        }
+    }
+
+    bool empty() const {
+        return data.empty();
+    }
+
+    void emplace_back(std::shared_ptr<snort::HostTracker>& ht)
+    {
+        data.emplace_back(ht);
+    }
+
+    std::vector<std::shared_ptr<snort::HostTracker>> data;
+};
+
+typedef LruCacheSharedMemcap<snort::SfIp, snort::HostTracker, HashIp, IpEqualTo, HTPurgatory>
+    HostCacheIpSpec;
+
+// Since the LruCacheShared and LruCacheSharedMemcap templates make no
+// assumptions about the item, we have to derive our host cache
+// from the specialization, if we want to make use of things within the item.
+class HostCacheIp : public HostCacheIpSpec
+{
+public:
+    HostCacheIp(const size_t initial_size) : HostCacheIpSpec(initial_size) { }
+
+    bool remove(const KeyType& key)
+    {
+        LruBase::Data data;
+        return remove(key, data);
+    }
+
+    bool remove(const KeyType& key, LruBase::Data& data)
+    {
+        bool out = LruBase::remove(key, data);
+        data->remove_flows();
+        return out;
+    }
+};
 
 extern SO_PUBLIC HostCacheIp host_cache;
 
index a887d5ef1c00adb155d9c88e88b1eca8bf70d9b6..74ee2941baf0f009671010e8c096e19f41a35452 100644 (file)
 
 #include <algorithm>
 
+#include "flow/flow.h"
+#include "network_inspectors/rna/rna_flow.h"
+#include "utils/util.h"
+
 #include "host_cache.h"
 #include "host_cache_allocator.cc"
 #include "host_tracker.h"
-#include "utils/util.h"
 
 using namespace snort;
 using namespace std;
@@ -1025,6 +1028,52 @@ HostClient HostTracker::find_or_add_client(AppId id, const char* version, AppId
     return clients.back();
 }
 
+void HostTracker::add_flow(RNAFlow* fd)
+{
+    lock_guard<mutex> lck(flows_lock);
+    flows.insert(fd);
+}
+
+void HostTracker::remove_flow(RNAFlow* fd)
+{
+    lock_guard<mutex> lck(flows_lock);
+    flows.erase(fd);
+}
+
+void HostTracker::remove_flows()
+{
+    // To lock, or not to lock? That is the question!
+    //
+    // The only way we get here is from LRU::update(), called by the allocator.
+    // That is, we only get here from a HT::add_<> operation. All of those
+    // operations lock the HT, so the HT is already locked when we get here.
+    // Also, none of those operations modify the HT::flows set. So we should
+    // not lock the HT (because we'd cause a deadlock), nor do we need to
+    // (because there's no contention on HT::flows from those adds).
+    //
+    // However, this HT could be part of a different rna flow, which could
+    // go out of existence exactly at the time when this thread modifies
+    // the HT::flows set. The rna flow destructor calls on this
+    // HT::remove_flow(), which does modify HT::flows. The for loop itself
+    // does not modify the HT::flows set, but flows.clear() does - whether
+    // or not we call it here explicitly. We, therefore, need to protect the
+    // HT::flows() array with a lock on this host_tracker_lock.
+    //
+    // We have identified two situations with opposite requirements:
+    // one requires locking, the other requires not locking.
+    //
+    // Now, note that the thread contention is not on the host tracker itself,
+    // but on the HT::flows set. This means we may not lock the HT here,
+    // to avoid the deadlock from the first case, but we SHOULD lock on
+    // a different mutex to protect the HT::flows set.
+    lock_guard<mutex> lck(flows_lock);
+    for (auto& rna_flow : flows)
+    {
+        rna_flow->clear_ht(*this);
+    }
+    flows.clear();
+}
+
 HostApplicationInfo::HostApplicationInfo(const char *ver, const char *ven)
 {
     if ( ver )
index 026f6308095e5ca57897fe251471cac88df7b47c..bfee92dcd1ee14c9627e2f27b5029feb3635849f 100644 (file)
@@ -29,6 +29,7 @@
 #include <mutex>
 #include <list>
 #include <set>
+#include <unordered_set>
 #include <vector>
 
 #include "framework/counts.h"
@@ -48,6 +49,8 @@ struct HostTrackerStats
 
 extern THREAD_LOCAL struct HostTrackerStats host_tracker_stats;
 
+class RNAFlow;
+
 namespace snort
 {
 #define INFO_SIZE 32
@@ -79,7 +82,6 @@ struct HostApplicationInfo
     friend class HostTracker;
 private:
     bool visibility = true;
-
 };
 
 typedef HostCacheAllocIp<HostApplicationInfo> HostAppInfoAllocator;
@@ -407,8 +409,14 @@ public:
     }
 #endif
 
+    void add_flow(RNAFlow*);
+    void remove_flows();
+    void remove_flow(RNAFlow*);
+
 private:
+
     mutable std::mutex host_tracker_lock; // ensure that updates to a shared object are safe
+    mutable std::mutex flows_lock;        // protect the flows set separately
     uint8_t hops;                 // hops from the snort inspector, e.g., zero for ARP
     uint32_t last_seen;           // the last time this host was seen
     uint32_t last_event;          // the last time an event was generated
@@ -423,6 +431,9 @@ private:
     std::set<uint32_t, std::less<uint32_t>, HostCacheAllocIp<uint32_t>> udp_fpids;
     std::vector<DeviceFingerprint, HostDeviceFpAllocator> ua_fps;
 
+    // flows that we belong to
+    std::unordered_set<RNAFlow*> flows;
+
     bool vlan_tag_present = false;
     vlan::VlanTagHdr vlan_tag;
     HostType host_type = HOST_TYPE_HOST;
index a32afdf283be989aeee921ac06c9f3072d718461..5e1279733b01f70e786b495a2f57710804ec2ff6 100644 (file)
@@ -3,6 +3,7 @@ add_cpputest( host_cache_test
     SOURCES
         ../host_cache.cc
         ../host_tracker.cc
+        ../../network_inspectors/rna/test/rna_flow_mock.cc
         ../../sfip/sf_ip.cc
 )
 
@@ -14,6 +15,7 @@ add_cpputest( host_cache_module_test
         ../../framework/module.cc
         ../../framework/value.cc
         ../../hash/lru_cache_shared.cc
+        ../../network_inspectors/rna/test/rna_flow_mock.cc
         ../../sfip/sf_ip.cc
         $<TARGET_OBJECTS:catch_tests>
     LIBS
@@ -23,6 +25,7 @@ add_cpputest( host_cache_module_test
 add_cpputest( host_tracker_test
     SOURCES
         ../host_tracker.cc
+        ../../network_inspectors/rna/test/rna_flow_mock.cc
         ../../sfip/sf_ip.cc
 )
 
@@ -34,6 +37,7 @@ add_cpputest( host_tracker_module_test
         ../../framework/module.cc
         ../../framework/parameter.cc
         ../../framework/value.cc
+        ../../network_inspectors/rna/test/rna_flow_mock.cc
         ../../sfip/sf_ip.cc
         $<TARGET_OBJECTS:catch_tests>
     LIBS
@@ -43,10 +47,12 @@ add_cpputest( host_tracker_module_test
 add_cpputest( host_cache_allocator_ht_test
     SOURCES
         ../host_tracker.cc
+        ../../network_inspectors/rna/test/rna_flow_mock.cc
         ../../sfip/sf_ip.cc
 )
 
 add_cpputest( host_cache_allocator_test
     SOURCES
         ../host_tracker.cc
+        ../../network_inspectors/rna/test/rna_flow_mock.cc
 )
index d7746ca36d4281cd2e7484e8ec862ca2cd4db4d3..299abae4397910573f3e397381569860d1885525 100644 (file)
@@ -25,6 +25,7 @@
 
 #include "host_tracker/host_cache.h"
 #include "host_tracker/host_cache_allocator.cc"
+#include "network_inspectors/rna/rna_flow.h"
 
 #include <cstring>
 
index 22614e9fc1d042eb19d9488f7dc29320a31f1343..fa2d327a28ac2445e0d3de8aefaea2f6943874e8 100644 (file)
@@ -24,6 +24,7 @@
 
 #include "host_tracker/host_cache.h"
 #include "host_tracker/host_cache_allocator.cc"
+#include "network_inspectors/rna/rna_flow.h"
 
 #include <string>
 
index b852084938402db12f0f14d470f4694ed35541bc..51d80ea0e162c92333424a4956d4fa12734fa808 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "host_tracker/host_cache.h"
 #include "host_tracker/host_cache_allocator.cc"
+#include "network_inspectors/rna/rna_flow.h"
 
 #include <CppUTest/CommandLineTestRunner.h>
 #include <CppUTest/TestHarness.h>
index e7b7d63c7def1cb328c647122bc6a9110da5a134..35a260a6e3adf88b8889bc304df5bcee9a0ca010 100644 (file)
@@ -44,6 +44,7 @@ add_cpputest( appid_http_session_test
 add_cpputest( tp_lib_handler_test
     SOURCES
         tp_lib_handler_test.cc
+        ../../../network_inspectors/rna/test/rna_flow_mock.cc
         ../tp_lib_handler.cc
     LIBS
         dl
index eff65cbf241a291948e3b48d51ff0f2de67142c6..6bed3a5df9c96ed1939625ba6306cfa4c47fd5a1 100644 (file)
@@ -245,6 +245,8 @@ AppId HostTracker::get_appid(Port, IpProtocol, bool, bool)
     return APP_ID_NONE;
 }
 
+void HostTracker::remove_flows() {}
+
 // Stubs for ClientDiscovery
 void ClientDiscovery::initialize() {}
 void ClientDiscovery::reload() {}
index 85bae0cb37f8a1ceb0c6b25887e6a85a89ed6e28..54854381404a12773f60f9429d6cded63415b079 100644 (file)
@@ -3,6 +3,7 @@ set (RNA_INCLUDES
     rna_fingerprint_tcp.h
     rna_fingerprint_ua.h
     rna_fingerprint_udp.h
+    rna_flow.h
     rna_inspector.h
     rna_logger.h
     rna_name.h
@@ -17,12 +18,11 @@ set ( RNA_SOURCES
     rna_event_handler.cc
     rna_event_handler.h
     rna_fingerprint.cc
-    rna_fingerprint.h
     rna_fingerprint_tcp.cc
     rna_fingerprint_ua.cc
     rna_fingerprint_udp.cc
     rna_inspector.cc
-    rna_inspector.h
+    rna_flow.cc
     rna_logger.cc
     rna_logger_common.h
     rna_mac_cache.cc
index 18151a49d0815729ae500b601d6fd385d2826ea5..0e2580a0dfc13cda8e4a6562dbd87102116d17b4 100644 (file)
 #include "detection/detection_engine.h"
 #include "network_inspectors/appid/appid_session_api.h"
 
+#include "rna_flow.h"
 #include "rna_logger_common.h"
 
 using namespace snort;
 
-RnaTracker RnaAppDiscovery::get_server_rna_tracker(const Packet* p, RNAFlow*)
+RnaTracker RnaAppDiscovery::get_server_rna_tracker(const Packet* p, RNAFlow* rna_flow)
 {
-    return host_cache.find(p->flow->server_ip);
+    return rna_flow->get_server(p->flow->server_ip);
 }
 
-RnaTracker RnaAppDiscovery::get_client_rna_tracker(const Packet* p, RNAFlow*)
+RnaTracker RnaAppDiscovery::get_client_rna_tracker(const Packet* p, RNAFlow* rna_flow)
 {
-    return host_cache.find(p->flow->client_ip);
+    return rna_flow->get_client(p->flow->client_ip);
 }
 
 void RnaAppDiscovery::process(AppidEvent* appid_event, DiscoveryFilter& filter, RnaConfig* conf,
@@ -223,7 +224,7 @@ void RnaAppDiscovery::discover_payload(const Packet* p, DiscoveryFilter& filter,
 
     if ( !srt or !srt->is_visible() )
         return;
-    
+
     srt->update_last_seen();
     if ( conf and conf->max_payloads )
         max_payloads = conf->max_payloads;
@@ -407,7 +408,7 @@ void RnaAppDiscovery::analyze_user_agent_fingerprint(const Packet* p, DiscoveryF
     RNAFlow* rna_flow, const char* host, const char* uagent, RnaLogger& logger,
     UaFpProcessor& processor)
 {
-    if ( !host or !uagent or 
+    if ( !host or !uagent or
         !filter.is_host_monitored(p, nullptr, nullptr, FlowCheckDirection::DF_CLIENT))
         return;
 
@@ -424,7 +425,7 @@ void RnaAppDiscovery::analyze_user_agent_fingerprint(const Packet* p, DiscoveryF
     if ( uafp and rt->add_ua_fingerprint(uafp->fpid, uafp->fp_type, jail_broken,
         device_info, MAX_USER_AGENT_DEVICES) )
     {
-        logger.log(RNA_EVENT_NEW, NEW_OS, p, &rt, 
+        logger.log(RNA_EVENT_NEW, NEW_OS, p, &rt,
             (const struct in6_addr*)p->flow->client_ip.get_ip6_ptr(), rt->get_last_seen_mac(),
             (FpFingerprint*)uafp, packet_time(), device_info, jail_broken);
     }
index 389151915f9e72386462c2231e2ab63b09241773..9a723212ca6494ab5e73f3c73ff9d17792c67653 100644 (file)
@@ -36,6 +36,8 @@
 #include "protocols/tcp.h"
 #include "protocols/tcp_options.h"
 
+#include "rna_flow.h"
+
 using namespace snort;
 using namespace std;
 
index 47d15f133890e0dcf762474649e1fae40d8a2cbf..0307fcc186fe6b9d0a92ce03648bf573e99836c1 100644 (file)
 #ifndef RNA_FINGERPRINT_TCP_H
 #define RNA_FINGERPRINT_TCP_H
 
+#include <mutex>
 #include <unordered_map>
 #include <vector>
 
 #include "main/snort_types.h"
 #include "protocols/packet.h"
 #include "protocols/tcp.h"
+#include "sfip/sf_ip.h"
 
 #include "rna_fingerprint.h"
-#include "rna_logger.h"
 
 class RNAFlow;
 
@@ -137,21 +138,4 @@ struct FpFingerprintState
     bool set(const snort::Packet*);
 };
 
-class RNAFlow : public snort::FlowData
-{
-public:
-    FpFingerprintState state;
-
-
-    RNAFlow() : FlowData(inspector_id) { }
-    ~RNAFlow() override { }
-
-    static void init();
-    size_t size_of() override;
-
-    static unsigned inspector_id;
-    RnaTracker serverht = nullptr;
-    RnaTracker clientht = nullptr;
-};
-
 #endif
diff --git a/src/network_inspectors/rna/rna_flow.cc b/src/network_inspectors/rna/rna_flow.cc
new file mode 100644 (file)
index 0000000..f155f29
--- /dev/null
@@ -0,0 +1,95 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2021-2021 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// rna_flow.cc author Silviu Minut <sminut@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "rna_flow.h"
+
+#include "host_tracker/host_cache.h"
+
+using namespace snort;
+using namespace std;
+
+RNAFlow::~RNAFlow()
+{
+    // Do not call remove_flow() directly on our own server and client
+    // because those might be set to 0 between the null check and remove_flow().
+    // Use temporaries. We still need to lock the rna flow though, but
+    // this won't lead to a deadlock.
+    rna_mutex.lock();
+    auto serverht_loc = serverht;
+    auto clientht_loc = clientht;
+    rna_mutex.unlock();
+
+    if (serverht_loc)
+        serverht_loc->remove_flow(this);
+
+    if (clientht_loc)
+        clientht_loc->remove_flow(this);
+}
+
+void RNAFlow::clear_ht(HostTracker& ht)
+{
+    lock_guard<mutex> lck(rna_mutex);
+    if (&ht == clientht.get())
+        clientht = nullptr;
+    else if (&ht == serverht.get())
+        serverht = nullptr;
+}
+
+RnaTracker RNAFlow::get_server(const SfIp& ip)
+{
+    rna_mutex.lock();
+    auto loc_ht = serverht;
+    rna_mutex.unlock();
+
+    if ( !loc_ht )
+        loc_ht = host_cache.find(ip);
+
+    return loc_ht;
+}
+
+RnaTracker RNAFlow::get_client(const SfIp& ip)
+{
+    rna_mutex.lock();
+    auto loc_ht = clientht;
+    rna_mutex.unlock();
+
+    if ( !loc_ht )
+        loc_ht = host_cache.find(ip);
+
+    return loc_ht;
+}
+
+void RNAFlow::set_server(RnaTracker& ht)
+{
+    rna_mutex.lock();
+    serverht = ht;
+    rna_mutex.unlock();
+}
+
+void RNAFlow::set_client(RnaTracker& ht)
+{
+    rna_mutex.lock();
+    clientht = ht;
+    rna_mutex.unlock();
+}
diff --git a/src/network_inspectors/rna/rna_flow.h b/src/network_inspectors/rna/rna_flow.h
new file mode 100644 (file)
index 0000000..4bfd037
--- /dev/null
@@ -0,0 +1,62 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2021-2021 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// rna_flow.h author Silviu Minut <sminut@cisco.com>
+
+#ifndef RNA_FLOW_H
+#define RNA_FLOW_H
+
+#include <memory>
+#include <mutex>
+
+#include "flow/flow_data.h"
+#include "host_tracker/host_tracker.h"
+#include "sfip/sf_ip.h"
+
+#include "rna_fingerprint_tcp.h"
+
+using RnaTracker = std::shared_ptr<snort::HostTracker>;
+
+class RNAFlow : public snort::FlowData
+{
+public:
+    FpFingerprintState state;
+
+    RNAFlow() : FlowData(inspector_id) { }
+    ~RNAFlow() override;
+
+    static void init();
+    size_t size_of() override;
+
+    void clear_ht(snort::HostTracker& ht);
+
+    static unsigned inspector_id;
+    RnaTracker serverht = nullptr;
+    RnaTracker clientht = nullptr;
+
+    std::mutex rna_mutex;
+
+    RnaTracker get_server(const snort::SfIp&);
+    RnaTracker get_client(const snort::SfIp&);
+
+    void set_server(RnaTracker& ht);
+    void set_client(RnaTracker& ht);
+
+};
+
+#endif
index d227d0bef189a30630d2834c29694152f85a1bfa..2c64ca8de983d6468d49a4ea8bc38da3583b62c3 100644 (file)
@@ -39,6 +39,7 @@
 #include "rna_fingerprint_tcp.h"
 #include "rna_fingerprint_ua.h"
 #include "rna_fingerprint_udp.h"
+#include "rna_flow.h"
 #include "rna_mac_cache.h"
 #include "rna_module.h"
 #include "rna_pnd.h"
index ec68272546593f7c81943e8048baf3c22ec45fda..29c31c2c3ce6bb71df44f137a8e0b19df51539a8 100644 (file)
@@ -24,6 +24,8 @@
 #include "host_tracker/host_cache.h"
 #include "host_tracker/host_tracker.h"
 
+#include "rna_flow.h"
+
 namespace snort
 {
 class Flow;
@@ -31,8 +33,6 @@ struct Packet;
 class FpFingerprint;
 }
 
-using RnaTracker = std::shared_ptr<snort::HostTracker>;
-
 struct RnaLoggerEvent : public Event
 {
     RnaLoggerEvent (uint16_t t, uint16_t st, const uint8_t* mc, const RnaTracker* rt,
index c2be2efe043c90934afe5441f138b8a03a172663..5da304a274f0943c6241234baa45355de2eb703e 100644 (file)
@@ -39,6 +39,7 @@
 #include "rna_app_discovery.h"
 #include "rna_fingerprint_tcp.h"
 #include "rna_fingerprint_udp.h"
+#include "rna_flow.h"
 #include "rna_logger_common.h"
 
 #ifdef UNIT_TEST
@@ -199,6 +200,12 @@ void RnaPnd::discover_network(const Packet* p, uint8_t ttl)
             rna_flow = new RNAFlow();
             p->flow->set_flow_data(rna_flow);
         }
+        ht->add_flow(rna_flow);
+
+        if ( p->is_from_client() )
+            rna_flow->set_client(ht);
+        else
+            rna_flow->set_server(ht);
     }
 
     if ( new_host )
diff --git a/src/network_inspectors/rna/test/rna_flow_mock.cc b/src/network_inspectors/rna/test/rna_flow_mock.cc
new file mode 100644 (file)
index 0000000..8521aeb
--- /dev/null
@@ -0,0 +1,34 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2021-2021 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+// rna_flow_mock.h author Silviu Minut <sminut@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "network_inspectors/rna/rna_flow.h"
+
+#include "sfip/sf_ip.h"
+
+using namespace std;
+using namespace snort;
+
+void RNAFlow::clear_ht(HostTracker& ht) { UNUSED(ht); }
+RnaTracker RNAFlow::get_server(const SfIp& ) { return nullptr; }
+RnaTracker RNAFlow::get_client(const SfIp&) { return nullptr; }
index 918368f12a65d31f3c06809096df78f80afc4384..69b5eb469bbbdc95bc47f77035c23b5db9c54ca9 100644 (file)
@@ -126,4 +126,7 @@ Inspector* InspectorManager::get_inspector(const char*, bool, const SnortConfig*
 {
     return nullptr;
 }
+
+void HostTracker::remove_flows() { }
+
 #endif