]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #3088 in SNORT/snort3 from ~BSACHDEV/snort3:stress_smb2 to master
authorBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Thu, 7 Oct 2021 14:25:22 +0000 (14:25 +0000)
committerBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Thu, 7 Oct 2021 14:25:22 +0000 (14:25 +0000)
Squashed commit of the following:

commit 4de99c141ba599b04b6bb43fbb6af18b63ae836a
Author: Bhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Date:   Mon Oct 4 14:05:53 2021 +0000

    dce_smb: Optimised handling pruning of flows in stress environment

src/file_api/file_flows.cc
src/service_inspectors/dce_rpc/dce_smb2.cc
src/service_inspectors/dce_rpc/dce_smb2_file.h
src/service_inspectors/dce_rpc/dce_smb2_session.cc
src/service_inspectors/dce_rpc/dce_smb2_session.h
src/service_inspectors/dce_rpc/dce_smb2_session_cache.h
src/service_inspectors/dce_rpc/dce_smb2_tree.cc
src/service_inspectors/dce_rpc/dce_smb_inspector.cc

index df7ad70ec484d979bece40185c8dba8cfb91e464..ac49d75b936e938bc69361715460cfdd4f63a364 100644 (file)
@@ -150,7 +150,7 @@ FilePolicyBase* FileFlows::get_file_policy(Flow* flow)
 void FileFlows::set_current_file_context(FileContext* ctx)
 {
     // If we finished processing a file context object last time, delete it
-    if (current_context_delete_pending)
+    if (current_context_delete_pending and (current_context != ctx))
     {
         delete current_context;
         current_context_delete_pending = false;
index 5fff3e76e045ddd683905a8dbbc47fcfc5032392..05ae281c92267168cdc3cc1e4aac1cc422e9fb7d 100644 (file)
@@ -186,6 +186,9 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
 {
     const uint8_t* smb_data = (const uint8_t*)smb_hdr + SMB2_HEADER_LENGTH;
     uint16_t structure_size = alignedNtohs((const uint16_t*)smb_data);
+    uint16_t command = alignedNtohs(&(smb_hdr->command));
+    uint64_t session_id = Smb2Sid(smb_hdr);
+    Dce2Smb2SessionTrackerPtr session = find_session(session_id);
 
 // Macro and shorthand to save some repetitive code
 // Should only be used in this function
@@ -206,6 +209,11 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_ERROR_LEVEL, GET_CURRENT_PACKET, \
                 "%s : smb data beyond end detected\n", \
                 smb2_command_string[command]); \
+            if (session) \
+            { \
+                session->set_do_not_delete(false); \
+                session->set_prev_comand(SMB2_COM_MAX); \
+            } \
             return; \
         } \
     }
@@ -217,6 +225,11 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             dce2_smb_stats.v2_ ## counter ## _err_resp++; \
             SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_INFO_LEVEL, GET_CURRENT_PACKET, "%s_RESP: error\n", \
                 smb2_command_string[command]); \
+            if (session) \
+            { \
+                session->set_do_not_delete(false); \
+                session->set_prev_comand(SMB2_COM_MAX); \
+            } \
             return; \
         } \
     }
@@ -227,19 +240,34 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
         SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_ERROR_LEVEL \
            , GET_CURRENT_PACKET, "%s: invalid struct size\n", \
             smb2_command_string[command]); \
+        if (session) \
+        { \
+            session->set_do_not_delete(false); \
+            session->set_prev_comand(SMB2_COM_MAX); \
+        } \
         return; \
     }
 
-    uint16_t command = alignedNtohs(&(smb_hdr->command));
-    uint64_t session_id = Smb2Sid(smb_hdr);
     SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_INFO_LEVEL, GET_CURRENT_PACKET,
         "%s : flow %" PRIu32 " mid %" PRIu64 " sid %" PRIu64 " tid %" PRIu32 "\n",
         (command < SMB2_COM_MAX ? smb2_command_string[command] : "unknown"),
         flow_key, Smb2Mid(smb_hdr), session_id, Smb2Tid(smb_hdr));
+
+    // Handling case of two threads trying to do close same session at a time
+    if (command == SMB2_COM_CLOSE and (session and session->get_prev_command() !=  SMB2_COM_MAX))
+    {
+        session->set_do_not_delete(false);
+        return;
+    }
+
+    if (session)
+    {
+        session->set_do_not_delete(true);
+        session->set_prev_comand(command);
+    }
+
     // Try to find the session.
     // The case when session is not available will be handled per command.
-    Dce2Smb2SessionTrackerPtr session = find_session(session_id);
-
     switch (command)
     {
     //commands processed by flow
@@ -459,6 +487,11 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
         dce2_smb_stats.v2_msgs_uninspected++;
         break;
     }
+    if (session)
+    {
+        session->set_prev_comand(SMB2_COM_MAX);
+        session->set_do_not_delete(false);
+    }
 }
 
 // This is the main entry point for SMB2 processing.
@@ -467,6 +500,7 @@ void Dce2Smb2SessionData::process()
     Packet* p = DetectionEngine::get_current_packet();
     const uint8_t* data_ptr = p->data;
     uint16_t data_len = p->dsize;
+    Dce2Smb2SessionTrackerPtr session = nullptr;
 
     // Process the header
     if (p->is_pdu_start())
@@ -502,6 +536,8 @@ void Dce2Smb2SessionData::process()
                        SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID,
                        TRACE_ERROR_LEVEL, p, "bad next command offset\n");
                 dce2_smb_stats.v2_bad_next_cmd_offset++;
+                if (session)
+                    session->set_do_not_delete(false);
                 return;
             }
             if (next_command_offset)
@@ -516,6 +552,8 @@ void Dce2Smb2SessionData::process()
                        SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID,
                            TRACE_INFO_LEVEL, p, "compound request limit"
                     " reached %" PRIu8 "\n",compound_request_index);
+                if (session)
+                    session->set_do_not_delete(false);
                 return;
             }
         }
@@ -523,6 +561,12 @@ void Dce2Smb2SessionData::process()
     }
     else
     {
+        if ( tcp_file_tracker )
+        {
+             session = find_session(tcp_file_tracker->get_session_id());
+             if (session)
+                 session->set_do_not_delete(true);
+        }
         tcp_file_tracker_mutex.lock();
         if ( tcp_file_tracker and tcp_file_tracker->accepting_raw_data_from(flow_key))
         {
@@ -537,6 +581,8 @@ void Dce2Smb2SessionData::process()
         }
         tcp_file_tracker_mutex.unlock();
     }
+    if (session)
+        session->set_do_not_delete(false);
 }
 
 void Dce2Smb2SessionData::set_reassembled_data(uint8_t* nb_ptr, uint16_t co_len)
index 62ae349366efaab48e40a1367ce0d54a4e27267b..8187dcffb37d54b2c7d7644b5d359339a55fe0fe 100644 (file)
@@ -42,10 +42,11 @@ public:
     Dce2Smb2FileTracker(const Dce2Smb2FileTracker& arg) = delete;
     Dce2Smb2FileTracker& operator=(const Dce2Smb2FileTracker& arg) = delete;
 
-    Dce2Smb2FileTracker(uint64_t file_idv, const uint32_t flow_key, Dce2Smb2TreeTracker* p_tree) :
+    Dce2Smb2FileTracker(uint64_t file_idv, const uint32_t flow_key, Dce2Smb2TreeTracker* p_tree,
+        uint64_t sid) :
         ignore(true), file_name_len(0), file_flow_key(flow_key),
         file_id(file_idv), file_size(0), file_name_hash(0), file_name(nullptr),
-        direction(FILE_DOWNLOAD), parent_tree(p_tree)
+        direction(FILE_DOWNLOAD), parent_tree(p_tree), session_id(sid)
     {
            SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_DEBUG_LEVEL, GET_CURRENT_PACKET,
             "file tracker %" PRIu64 " created\n", file_id);
@@ -67,6 +68,7 @@ public:
     void set_direction(FileDirection dir) { direction = dir; }
     Dce2Smb2TreeTracker* get_parent() { return parent_tree; }
     uint64_t get_file_id() { return file_id; }
+    uint64_t get_session_id() { return session_id; }
 
 private:
     void file_detect();
@@ -81,6 +83,7 @@ private:
     FileDirection direction;
     Dce2Smb2TreeTracker* parent_tree;
     std::unordered_map<uint32_t, tcp_flow_state, std::hash<uint32_t> > flow_state;
+    uint64_t session_id;
     std::mutex process_file_mutex;
     std::mutex flow_state_mutex;
 };
index aa67f7432dc0417a45d5f10d72ec2251fc86cc0b..05ed9d89918cfd24cba9e137fe0e2861f09988e8 100644 (file)
@@ -160,6 +160,7 @@ Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(const uint32_t tree_id
 void Dce2Smb2SessionTracker::clean_file_context_from_flow(Dce2Smb2FileTracker* file_tracker,
     uint64_t file_id, uint64_t file_name_hash)
 {
+    attached_flows_mutex.lock();
     for (auto it_flow : attached_flows)
     {
         snort::FileFlows* file_flows = snort::FileFlows::get_file_flows(
@@ -168,6 +169,7 @@ void Dce2Smb2SessionTracker::clean_file_context_from_flow(Dce2Smb2FileTracker* f
             file_flows->remove_processed_file_context(file_name_hash, file_id);
         it_flow.second->reset_matching_tcp_file_tracker(file_tracker);
     }
+    attached_flows_mutex.unlock();
 }
 
 void Dce2Smb2SessionTracker::increase_size(const size_t size)
@@ -191,6 +193,16 @@ void Dce2Smb2SessionTracker::unlink()
 // Session Tracker is created and destroyed only from session cache
 Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker()
 {
+    if (!(fcfs_mutex.try_lock()))
+        return;
+
+    if (do_not_delete )
+    {
+        // Dont prune the session in LRU Cache
+        smb2_session_cache.find_id(get_key());
+        fcfs_mutex.unlock();
+        return;
+    }
     if (smb_module_is_up and (snort::is_packet_thread()))
     {
            SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, 
@@ -212,6 +224,7 @@ Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker()
     {
         delete tree;
     }
-
+    do_not_delete = false;
+    fcfs_mutex.unlock();
 }
 
index ce31f10d1f311f03e6ce01c6cd232762cb450ada..bf6a21d620534348e19e9904c40cd09414282be4 100644 (file)
@@ -36,7 +36,9 @@ public:
         session_id = key.sid;
         session_key = key;
         reload_prune = false;
-           SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_DEBUG_LEVEL, GET_CURRENT_PACKET, 
+        do_not_delete = false;
+        command_prev = SMB2_COM_MAX;
+        SMB_DEBUG(dce_smb_trace, DEFAULT_TRACE_OPTION_ID, TRACE_DEBUG_LEVEL, GET_CURRENT_PACKET,
             "session tracker %" PRIu64 "created\n", session_id);
     }
 
@@ -70,17 +72,30 @@ public:
     void process(const uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*, const uint32_t);
     void increase_size(const size_t size);
     void decrease_size(const size_t size);
-    void set_reload_prune() { reload_prune = true; }
+    void set_reload_prune(bool flag) { reload_prune = flag; }
+    uint64_t get_session_id() { return session_id; }
+    void set_do_not_delete(bool flag) { do_not_delete = flag; }
+    bool get_do_not_delete() { return do_not_delete; }
+    void set_prev_comand(uint16_t cmd) { command_prev = cmd; }
+    uint16_t get_prev_command() { return command_prev; }
 
 private:
+    // do_not_delete is to make sure when we are in processing we should not delete the context
+    // which is being processed
+    bool do_not_delete;
     Dce2Smb2TreeTracker* find_tree_for_message(const uint64_t, const uint32_t);
     uint64_t session_id;
+    //to keep the tab of previous command
+    uint16_t command_prev;
     Smb2SessionKey session_key;
     Dce2Smb2SessionDataMap attached_flows;
     Dce2Smb2TreeTrackerMap connected_trees;
     std::atomic<bool> reload_prune;
     std::mutex connected_trees_mutex;
     std::mutex attached_flows_mutex;
+    // fcfs_mutex is to make sure the mutex is taken at first come first basis if code 
+    // is being hit by two different paths
+    std::mutex fcfs_mutex;
 };
 
 #endif
index 982a095e909fc8aa52dd40e8033896db0597f08d..e82d260bfdec97273450a3d14d0c5e01f1b9310b 100644 (file)
@@ -42,23 +42,25 @@ public:
 
     using Data = std::shared_ptr<Value>;
 
+    Data find_id(Key key)
+    {
+        Data session = this->find(key);
+        return session;
+    }
+
     Data find_session(Key key, Dce2Smb2SessionData* ssd)
     {
-        flow_mutex.lock();
         Data session = this->find(key);
         if (session)
             session->attach_flow(ssd->get_flow_key(), ssd);
-        flow_mutex.unlock();
         return session;
     }
 
     Data find_else_create_session(Key& key, Dce2Smb2SessionData* ssd)
     {
         Data new_session = Data(new Value(key));
-        flow_mutex.lock();
         Data session = this->find_else_insert(key, new_session, nullptr);
         session->attach_flow(ssd->get_flow_key(), ssd);
-        flow_mutex.unlock();
         return session;
     }
 
@@ -93,7 +95,7 @@ public:
             data.emplace_back(list_iter->second); // increase reference count
             // This instructs the session_tracker to take a lock before detaching
             // from ssd, when it is getting destroyed.
-            list_iter->second->set_reload_prune();
+            list_iter->second->set_reload_prune(true);
             decrease_size(list_iter->second.get());
             map.erase(list_iter->first);
             list.erase(list_iter);
@@ -110,7 +112,6 @@ private:
     using LruBase::max_size;
     using LruBase::stats;
     using LruListIter = typename LruBase::LruListIter;
-    std::mutex flow_mutex;
     void increase_size(Value* value_ptr=nullptr) override
     {
         if (value_ptr) current_size += sizeof(*value_ptr);
index 5dc9e10aabf53c8c19115e17697282bc1741ecf8..59b435d4b2969a212c6433c0e79926e19d210e83 100644 (file)
@@ -39,7 +39,8 @@ uint64_t Smb2Mid(const Smb2Hdr* hdr)
 Dce2Smb2FileTracker* Dce2Smb2TreeTracker::open_file(const uint64_t file_id,
     const uint32_t current_flow_key)
 {
-    Dce2Smb2FileTracker* ftracker = new Dce2Smb2FileTracker(file_id, current_flow_key, this);
+    Dce2Smb2FileTracker* ftracker = new Dce2Smb2FileTracker(file_id, current_flow_key, this,
+        this->get_parent()->get_session_id());
     tree_tracker_mutex.lock();
     opened_files.insert(std::make_pair(file_id, ftracker));
     tree_tracker_mutex.unlock();
index f2ae621bd9d9b523e25e81ea31d469e390c1fcad..76f192106d87589420ebca16c19a0ec8777d732d 100644 (file)
@@ -27,6 +27,7 @@
 #include "dce_smb_module.h"
 #include "dce_smb_utils.h"
 #include "dce_smb2_session_cache.h"
+#include "main/thread_config.h"
 
 #define DCE_SMB_PROTOCOL_ID "netbios-ssn"
 
@@ -139,8 +140,9 @@ static Inspector* dce2_smb_ctor(Module* m)
     dce2SmbProtoConf config;
     mod->get_data(config);
     size_t max_smb_mem = DCE2_ScSmbMemcap(&config);
+    uint16_t num_threads = ThreadConfig::get_instance_max();
     smb_module_is_up = true;
-    smb2_session_cache.reload_prune(max_smb_mem);
+    smb2_session_cache.reload_prune(max_smb_mem*num_threads);
     return new Dce2Smb(config);
 }