]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2933 in SNORT/snort3 from ~DIPANDIT/snort3:multichannel_shared_pt...
authorBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Wed, 9 Jun 2021 06:30:14 +0000 (06:30 +0000)
committerBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Wed, 9 Jun 2021 06:30:14 +0000 (06:30 +0000)
Squashed commit of the following:

commit 98177702616043e80f1c7c20df6b4731696c763a
Author: Dipto Pandit (dipandit) <dipandit@cisco.com>
Date:   Tue Jun 8 13:37:02 2021 -0400

    dce_rpc: store shared pointer of session tracker

commit e6a88c3afe70c9d690489cd5f004ce2782bab9b6
Author: Dipto Pandit (dipandit) <dipandit@cisco.com>
Date:   Tue Jun 8 07:47:55 2021 -0400

    dce_rpc: handle reload prune for smb session cache

src/service_inspectors/dce_rpc/dce_smb2.cc
src/service_inspectors/dce_rpc/dce_smb2.h
src/service_inspectors/dce_rpc/dce_smb2_session.cc
src/service_inspectors/dce_rpc/dce_smb2_session.h
src/service_inspectors/dce_rpc/dce_smb2_session_cache.h
src/service_inspectors/dce_rpc/dce_smb_inspector.cc

index 1fe8b11b20d34eab952c7a009894919e6eb7b720..7e985db0d85fb51ef80b81ceb9f36ea2086f07c5 100644 (file)
@@ -142,21 +142,21 @@ Smb2SessionKey Dce2Smb2SessionData::get_session_key(uint64_t session_id)
     return key;
 }
 
-Dce2Smb2SessionTracker* Dce2Smb2SessionData::find_session(uint64_t session_id)
+Dce2Smb2SessionTrackerPtr Dce2Smb2SessionData::find_session(uint64_t session_id)
 {
     std::lock_guard<std::mutex> guard(session_data_mutex);
     auto it_session = connected_sessions.find(session_id);
 
     if (it_session != connected_sessions.end())
     {
-        Dce2Smb2SessionTracker* session = it_session->second;
+        Dce2Smb2SessionTrackerPtr session = it_session->second;
         //we already have the session, but call find to update the LRU
         smb2_session_cache.find_session(session->get_key(), this);
         return session;
     }
     else
     {
-        Dce2Smb2SessionTracker* session = smb2_session_cache.find_session(
+        Dce2Smb2SessionTrackerPtr session = smb2_session_cache.find_session(
             get_session_key(session_id), this);
         if (session)
             connected_sessions.insert(std::make_pair(session_id,session));
@@ -165,18 +165,20 @@ Dce2Smb2SessionTracker* Dce2Smb2SessionData::find_session(uint64_t session_id)
 }
 
 // Caller must ensure that the session is not already present in flow
-Dce2Smb2SessionTracker* Dce2Smb2SessionData::create_session(uint64_t session_id)
+Dce2Smb2SessionTrackerPtr Dce2Smb2SessionData::create_session(uint64_t session_id)
 {
     Smb2SessionKey session_key = get_session_key(session_id);
     std::lock_guard<std::mutex> guard(session_data_mutex);
-    Dce2Smb2SessionTracker* session = smb2_session_cache.find_else_create_session(session_key, this);
+    Dce2Smb2SessionTrackerPtr session = smb2_session_cache.find_else_create_session(session_key, this);
     connected_sessions.insert(std::make_pair(session_id, session));
     return session;
 }
 
-void Dce2Smb2SessionData::remove_session(uint64_t session_id)
+void Dce2Smb2SessionData::remove_session(uint64_t session_id, bool sync)
 {
+    if (sync) session_data_mutex.lock();
     connected_sessions.erase(session_id);
+    if (sync) session_data_mutex.unlock();
 }
 
 void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
@@ -235,7 +237,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
         flow_key, Smb2Mid(smb_hdr), session_id, Smb2Tid(smb_hdr));
     // Try to find the session.
     // The case when session is not available will be handled per command.
-    Dce2Smb2SessionTracker* session = find_session(session_id);
+    Dce2Smb2SessionTrackerPtr session = find_session(session_id);
 
     switch (command)
     {
index 1d04cee84db64ec3c5b0cbee4a90582e39ff712b..222c4e2e40bf104910af30adc76a36a9deb4f0d1 100644 (file)
@@ -295,8 +295,9 @@ struct Smb2TreeConnectResponseHdr
 class Dce2Smb2FileTracker;
 class Dce2Smb2SessionTracker;
 
+using Dce2Smb2SessionTrackerPtr = std::shared_ptr<Dce2Smb2SessionTracker>;
 using Dce2Smb2SessionTrackerMap =
-    std::unordered_map<uint64_t, Dce2Smb2SessionTracker*, std::hash<uint64_t> >;
+    std::unordered_map<uint64_t, Dce2Smb2SessionTrackerPtr, std::hash<uint64_t> >;
 
 PADDING_GUARD_BEGIN
 struct Smb2SessionKey
@@ -444,7 +445,7 @@ public:
     Dce2Smb2SessionData(const snort::Packet*, const dce2SmbProtoConf* proto);
     ~Dce2Smb2SessionData() override;
     void process() override;
-    void remove_session(uint64_t);
+    void remove_session(uint64_t, bool = false);
     void handle_retransmit(FilePosition, FileVerdict) override { }
     void reset_matching_tcp_file_tracker(Dce2Smb2FileTracker*);
     void set_reassembled_data(uint8_t*, uint16_t) override;
@@ -458,8 +459,8 @@ public:
 private:
     void process_command(const Smb2Hdr*, const uint8_t*);
     Smb2SessionKey get_session_key(uint64_t);
-    Dce2Smb2SessionTracker* create_session(uint64_t);
-    Dce2Smb2SessionTracker* find_session(uint64_t);
+    Dce2Smb2SessionTrackerPtr create_session(uint64_t);
+    Dce2Smb2SessionTrackerPtr find_session(uint64_t);
 
     uint32_t flow_key;
     Dce2Smb2FileTracker* tcp_file_tracker;
index d5f10faea63b150a4a62b205842ad35cfc023b3f..d823fe818f122b4a28ef813fc11297ca9d9365ec 100644 (file)
@@ -176,6 +176,14 @@ void Dce2Smb2SessionTracker::decrease_size(const size_t size)
     smb2_session_cache.decrease_size(size);
 }
 
+void Dce2Smb2SessionTracker::unlink()
+{
+    attached_flows_mutex.lock();
+    for (auto it_flow : attached_flows)
+        it_flow.second->remove_session(session_id, reload_prune.load());
+    attached_flows_mutex.unlock();
+}
+
 // Session Tracker is created and destroyed only from session cache
 Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker(void)
 {
@@ -185,15 +193,20 @@ Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker(void)
             "session tracker %" PRIu64 " terminating\n", session_id);
     }
 
+    std::vector<Dce2Smb2TreeTracker*> all_trees;
+    connected_trees_mutex.lock();
     auto it_tree = connected_trees.begin();
     while (it_tree != connected_trees.end())
     {
-        Dce2Smb2TreeTracker* tree = it_tree->second;
+        all_trees.push_back(it_tree->second);
         it_tree = connected_trees.erase(it_tree);
+    }
+    connected_trees_mutex.unlock();
+
+    for (Dce2Smb2TreeTracker* tree : all_trees)
+    {
         delete tree;
     }
 
-    for (auto it_flow : attached_flows)
-        it_flow.second->remove_session(session_id);
 }
 
index b99ff52bc3bb31700fedd0bea87cdcdfaa46b4dd..a96153c46eb070cec2a7357df0851d3ba41c9fbc 100644 (file)
@@ -35,6 +35,7 @@ public:
     {
         session_id = key.sid;
         session_key = key;
+        reload_prune = false;
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "session tracker %" PRIu64
             " created\n", session_id);
     }
@@ -64,10 +65,12 @@ public:
 
     Smb2SessionKey get_key() { return session_key; }
     void clean_file_context_from_flow(Dce2Smb2FileTracker*, uint64_t, uint64_t);
+    void unlink();
     Dce2Smb2SessionData* get_flow(uint32_t);
     void process(const uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*, const uint32_t);
     void increase_size(const size_t size);
     void decrease_size(const size_t size);
+    void set_reload_prune() { reload_prune = true; }
 
 private:
     Dce2Smb2TreeTracker* find_tree_for_message(const uint64_t, const uint32_t);
@@ -75,6 +78,7 @@ private:
     Smb2SessionKey session_key;
     Dce2Smb2SessionDataMap attached_flows;
     Dce2Smb2TreeTrackerMap connected_trees;
+    std::atomic<bool> reload_prune;
     std::mutex connected_trees_mutex;
     std::mutex attached_flows_mutex;
 };
index 6a19c46894e24c7a5e3570a7eee612d4af5b73fe..982a095e909fc8aa52dd40e8033896db0597f08d 100644 (file)
@@ -40,21 +40,23 @@ public:
     Dce2Smb2SharedCache(const size_t initial_size) :
         LruCacheShared<Key, Value, Hash, Eq, Purgatory>(initial_size) { }
 
-    Value* find_session(Key key, Dce2Smb2SessionData* ssd)
+    using Data = std::shared_ptr<Value>;
+
+    Data find_session(Key key, Dce2Smb2SessionData* ssd)
     {
         flow_mutex.lock();
-        Value* session = this->find(key).get();
+        Data session = this->find(key);
         if (session)
             session->attach_flow(ssd->get_flow_key(), ssd);
         flow_mutex.unlock();
         return session;
     }
 
-    Value* find_else_create_session(Key& key, Dce2Smb2SessionData* ssd)
+    Data find_else_create_session(Key& key, Dce2Smb2SessionData* ssd)
     {
-        std::shared_ptr<Value> new_session = std::shared_ptr<Value>(new Value(key));
+        Data new_session = Data(new Value(key));
         flow_mutex.lock();
-        Value* session = this->find_else_insert(key, new_session, nullptr).get();
+        Data session = this->find_else_insert(key, new_session, nullptr);
         session->attach_flow(ssd->get_flow_key(), ssd);
         flow_mutex.unlock();
         return session;
@@ -76,9 +78,38 @@ public:
         current_size -= size;
     }
 
+    // Since decrease_size() does not account for associated objects in smb2_session_cache,
+    // we will over-prune when we reach the new_size here, as more space will be freed up
+    // when actual objects are destroyed. We might need to do gradual pruning like how
+    // host cache does. For now over pruning is ok.
+    void reload_prune(size_t new_size)
+    {
+        Purgatory data;
+        std::lock_guard<std::mutex> cache_lock(cache_mutex);
+        max_size = new_size;
+        while (current_size > max_size && !list.empty())
+        {
+            LruListIter list_iter = --list.end();
+            data.emplace_back(list_iter->second); // increase reference count
+            // This instructs the session_tracker to take a lock before detaching
+            // from ssd, when it is getting destroyed.
+            list_iter->second->set_reload_prune();
+            decrease_size(list_iter->second.get());
+            map.erase(list_iter->first);
+            list.erase(list_iter);
+            ++stats.reload_prunes;
+        }
+    }
+
 private:
-    using LruCacheShared<Key, Value, Hash, Eq, Purgatory>::current_size;
-    using LruCacheShared<Key, Value, Hash, Eq, Purgatory>::cache_mutex;
+    using LruBase = LruCacheShared<Key, Value, Hash, Eq, Purgatory>;
+    using LruBase::cache_mutex;
+    using LruBase::current_size;
+    using LruBase::list;
+    using LruBase::map;
+    using LruBase::max_size;
+    using LruBase::stats;
+    using LruListIter = typename LruBase::LruListIter;
     std::mutex flow_mutex;
     void increase_size(Value* value_ptr=nullptr) override
     {
@@ -91,6 +122,8 @@ private:
         {
             assert(current_size >= sizeof(*value_ptr) );
             current_size -= sizeof(*value_ptr);
+            //This is going down, remove references from flow here
+            value_ptr->unlink();
         }
     }
 };
index 65924ab6b8553eb03393cca489e66505efbaee70..174b81a867b8f5e2e0d212900cc6027f5dd1e88e 100644 (file)
@@ -129,7 +129,7 @@ static Inspector* dce2_smb_ctor(Module* m)
     mod->get_data(config);
     size_t max_smb_mem = DCE2_ScSmbMemcap(&config);
     smb_module_is_up = true;
-    smb2_session_cache.set_max_size(max_smb_mem);
+    smb2_session_cache.reload_prune(max_smb_mem);
     return new Dce2Smb(config);
 }