]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3444: dce_rpc: converting tree tracker to shared ptr
authorBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Tue, 31 May 2022 11:12:31 +0000 (11:12 +0000)
committerBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Tue, 31 May 2022 11:12:31 +0000 (11:12 +0000)
Merge in SNORT/snort3 from ~UMUNNIKR/snort3:tree_tracker_shared_ptr to master

Squashed commit of the following:

commit 7e04875cd7ad8cb7122469b985fe8f02575dba4d
Author: Unnikrishnan M <umunnikr@cisco.com>
Date:   Tue May 17 12:11:41 2022 +0530

    dce_rpc: converting tree tracker to shared ptr

src/service_inspectors/dce_rpc/dce_co.cc
src/service_inspectors/dce_rpc/dce_smb2_file.cc
src/service_inspectors/dce_rpc/dce_smb2_file.h
src/service_inspectors/dce_rpc/dce_smb2_session.cc
src/service_inspectors/dce_rpc/dce_smb2_session.h
src/service_inspectors/dce_rpc/dce_smb2_tree.cc
src/service_inspectors/dce_rpc/dce_smb2_tree.h

index 69b65225ed3ed31177f6af7ded6cb9c957d83b46..5717f7b51c18de7cb6d48299bd8d83229c308d0f 100644 (file)
@@ -1370,7 +1370,7 @@ static Packet* dce_co_reassemble(DCE2_SsnData* sd, DCE2_CoTracker* cot,
 
     DCE2_RpktType rpkt_type;
     Packet* rpkt = DCE2_CoGetRpkt(sd, cot, co_rtype, &rpkt_type);
-    if (rpkt == nullptr)
+    if (!rpkt || !rpkt->data)
     {
         return nullptr;
     }
@@ -2246,7 +2246,7 @@ static Packet* DCE2_CoGetSegRpkt(DCE2_SsnData* sd,
     case DCE2_TRANS_TYPE__SMB:
         rpkt = DCE2_GetRpkt(p, DCE2_RPKT_TYPE__SMB_CO_SEG, data_ptr, data_len);
 
-        if ( !rpkt )
+        if ( !rpkt  || !rpkt->data )
             return nullptr;
 
         set_smb_reassembled_data(const_cast<uint8_t*>(rpkt->data),
index 1ab85f677d12c81ca9c185c9a2e9ffeb4de03a6b..1d4a2859d5ce2e45b74124e40068aa6e28a4b515 100644 (file)
@@ -303,6 +303,6 @@ Dce2Smb2FileTracker::~Dce2Smb2FileTracker(void)
         snort_free((void*)file_name);
 
     file_name = nullptr;
-    parent_tree = nullptr;
+    parent_tree.reset();
 }
 
index 4aed699bfb04a99472ddd2cf52ae917ddf11f343..efb26aff84dd5f8a213c43df179f273f119eaadc 100644 (file)
@@ -27,6 +27,7 @@
 #include <atomic>
 
 class Dce2Smb2TreeTracker;
+using Dce2Smb2TreeTrackerPtr = std::shared_ptr<Dce2Smb2TreeTracker>;
 
 typedef struct _tcp_flow_state
 {
@@ -43,7 +44,7 @@ public:
     Dce2Smb2FileTracker(const Dce2Smb2FileTracker& arg) = delete;
     Dce2Smb2FileTracker& operator=(const Dce2Smb2FileTracker& arg) = delete;
 
-    Dce2Smb2FileTracker(uint64_t file_idv, const uint32_t flow_key, Dce2Smb2TreeTracker* p_tree,
+    Dce2Smb2FileTracker(uint64_t file_idv, const uint32_t flow_key, Dce2Smb2TreeTrackerPtr p_tree,
         uint64_t sid) :
         ignore(true), file_name_len(0), file_flow_key(flow_key),
         file_id(file_idv), file_size(0), file_name_hash(0), file_name(nullptr),
@@ -67,8 +68,8 @@ public:
     void stop_accepting_raw_data_from(uint32_t);
 
     void set_direction(FileDirection dir) { direction = dir; }
-    Dce2Smb2TreeTracker* get_parent() { return parent_tree; }
-    void set_parent(Dce2Smb2TreeTracker* pt) { parent_tree = pt; }
+    Dce2Smb2TreeTrackerPtr get_parent() { return parent_tree; }
+    void set_parent(Dce2Smb2TreeTrackerPtr pt) { parent_tree = pt; }
     uint64_t get_file_id() { return file_id; }
     uint64_t get_file_name_hash() { return file_name_hash; }
     uint64_t get_session_id() { return session_id; }
@@ -89,7 +90,7 @@ private:
     uint64_t file_name_hash;
     char* file_name;
     FileDirection direction;
-    Dce2Smb2TreeTracker* parent_tree;
+    Dce2Smb2TreeTrackerPtr parent_tree;
     std::unordered_map<uint32_t, tcp_flow_state, std::hash<uint32_t> > flow_state;
     uint64_t session_id;
     std::mutex process_file_mutex;
index 55b25f0a37338cb413ed68f7eaeb1503065eac70..d2d2c6786134a37fe9574c2f0de1c04927550e3f 100644 (file)
@@ -40,7 +40,7 @@ Dce2Smb2SessionData* Dce2Smb2SessionTracker::get_flow(uint32_t flow_key)
     return (it_flow != attached_flows.end()) ? it_flow->second : nullptr;
 }
 
-Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::find_tree_for_message(
+Dce2Smb2TreeTrackerPtr Dce2Smb2SessionTracker::find_tree_for_message(
     const uint64_t message_id, const uint32_t flow_key)
 {
     std::lock_guard<std::mutex> guard(connected_trees_mutex);
@@ -53,10 +53,20 @@ Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::find_tree_for_message(
     return nullptr;
 }
 
+Dce2Smb2TreeTrackerPtr Dce2Smb2SessionTracker::find_tree_for_tree_id(
+    const uint32_t tree_id)
+{
+    std::lock_guard<std::mutex> guard(connected_trees_mutex);
+    auto it_tree = connected_trees.find(tree_id);
+    if (it_tree != connected_trees.end())
+        return it_tree->second;
+    return nullptr;
+}
+
 void Dce2Smb2SessionTracker::process(const uint16_t command, uint8_t command_type,
     const Smb2Hdr* smb_header, const uint8_t* end, const uint32_t current_flow_key)
 {
-    Dce2Smb2TreeTracker* tree = nullptr;
+    Dce2Smb2TreeTrackerPtr tree;
     uint32_t tree_id = Smb2Tid(smb_header);
 
     if (tree_id)
@@ -121,7 +131,7 @@ void Dce2Smb2SessionTracker::process(const uint16_t command, uint8_t command_typ
     }
 }
 
-Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(const uint32_t tree_id,
+Dce2Smb2TreeTrackerPtr Dce2Smb2SessionTracker::connect_tree(const uint32_t tree_id,
     const uint32_t current_flow_key, const uint8_t share_type)
 {
     Dce2Smb2SessionData* current_flow = get_flow(current_flow_key);
@@ -135,7 +145,7 @@ Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(const uint32_t tree_id
         dce2_smb_stats.v2_tree_cnct_ignored++;
         return nullptr;
     }
-    Dce2Smb2TreeTracker* tree = nullptr;
+    Dce2Smb2TreeTrackerPtr tree;
     connected_trees_mutex.lock();
     auto it_tree = connected_trees.find(tree_id);
     if (it_tree != connected_trees.end())
@@ -143,7 +153,7 @@ Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(const uint32_t tree_id
     connected_trees_mutex.unlock();
     if (!tree)
     {
-        tree = new Dce2Smb2TreeTracker(tree_id, this, share_type);
+        tree = std::make_shared<Dce2Smb2TreeTracker>(tree_id, this, share_type);
         connected_trees_mutex.lock();
         connected_trees.insert(std::make_pair(tree_id, tree));
         connected_trees_mutex.unlock();
@@ -206,20 +216,17 @@ Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker(void)
         return;
     }
 
-    std::vector<Dce2Smb2TreeTracker*> all_trees;
     connected_trees_mutex.lock();
     auto it_tree = connected_trees.begin();
     while (it_tree != connected_trees.end())
     {
-        all_trees.push_back(it_tree->second);
-        it_tree = connected_trees.erase(it_tree);
+        auto next_it_tree = std::next(it_tree);
+        it_tree->second->close_all_files();
+        disconnect_tree(it_tree->second->get_tree_id());
+        it_tree = next_it_tree;
     }
     connected_trees_mutex.unlock();
 
-    for (Dce2Smb2TreeTracker* tree : all_trees)
-    {
-        delete tree;
-    }
     do_not_delete = false;
     fcfs_mutex.unlock();
 }
index 8a813e1059183bbde0ab1f437879ad27c4099164..aa7c4ae44c760c90cccbbaafa02537d7e12b1dab 100644 (file)
@@ -45,11 +45,10 @@ public:
     }
 
     ~Dce2Smb2SessionTracker();
-    Dce2Smb2TreeTracker* connect_tree(const uint32_t, const uint32_t,
+    Dce2Smb2TreeTrackerPtr connect_tree(const uint32_t, const uint32_t,
         uint8_t=SMB2_SHARE_TYPE_DISK);
     void disconnect_tree(uint32_t tree_id)
     {
-        std::lock_guard<std::mutex> guard(connected_trees_mutex);
         connected_trees.erase(tree_id);
         decrease_size(sizeof(Dce2Smb2TreeTracker));
     }
@@ -90,12 +89,13 @@ public:
             dce2_smb_stats.total_encrypted_sessions++; 
     }
     bool get_encryption_flag() { return encryption_flag; }
+    Dce2Smb2TreeTrackerPtr find_tree_for_tree_id(const uint32_t);
 private:
     // do_not_delete is to make sure when we are in processing we should not delete the context
     // which is being processed
     bool do_not_delete;
     bool file_context_cleaned;
-    Dce2Smb2TreeTracker* find_tree_for_message(const uint64_t, const uint32_t);
+    Dce2Smb2TreeTrackerPtr find_tree_for_message(const uint64_t, const uint32_t);
     uint64_t session_id;
     //to keep the tab of previous command
     uint16_t command_prev;
index 2a7bb73e43fadff824e28bc736395b307e4c2821..4d5c451600385737e3b93c8525b6d1c9f98efbe9 100644 (file)
@@ -40,8 +40,11 @@ uint64_t Smb2Mid(const Smb2Hdr* hdr)
 Dce2Smb2FileTrackerPtr Dce2Smb2TreeTracker::open_file(const uint64_t file_id,
     const uint32_t current_flow_key)
 {
+    Dce2Smb2TreeTrackerPtr tree_ptr = parent_session->find_tree_for_tree_id(tree_id);
+    if (!tree_ptr)
+        return nullptr;
     std::shared_ptr<Dce2Smb2FileTracker> ftracker =  std::make_shared<Dce2Smb2FileTracker> (
-        file_id, current_flow_key, this, this->get_parent()->get_session_id());
+        file_id, current_flow_key, tree_ptr, this->get_parent()->get_session_id());
     tree_tracker_mutex.lock();
     opened_files.insert(std::make_pair(file_id, ftracker));
     tree_tracker_mutex.unlock();
@@ -65,7 +68,7 @@ void Dce2Smb2TreeTracker::close_file(uint64_t file_id, bool destroy)
     if (it_file != opened_files.end())
     {
         Dce2Smb2FileTrackerPtr file = it_file->second;
-        it_file->second->set_parent(nullptr);
+        it_file->second->get_parent().reset();
         if (opened_files.erase(file_id) and destroy)
         {
             parent_session->decrease_size(sizeof(Dce2Smb2FileTracker));
@@ -76,6 +79,19 @@ void Dce2Smb2TreeTracker::close_file(uint64_t file_id, bool destroy)
     tree_tracker_mutex.unlock();
 }
 
+void Dce2Smb2TreeTracker::close_all_files()
+{
+    tree_tracker_mutex.lock();
+    auto it_file = opened_files.begin();
+    while (it_file != opened_files.end())
+    {
+        get_parent()->clean_file_context_from_flow(it_file->second->get_file_id(),
+            it_file->second->get_file_name_hash());
+        it_file = opened_files.erase(it_file);
+    }
+    tree_tracker_mutex.unlock();
+}
+
 Dce2Smb2RequestTracker* Dce2Smb2TreeTracker::find_request(const uint64_t message_id,
     const uint32_t current_flow_key)
 {
@@ -603,12 +619,11 @@ Dce2Smb2TreeTracker::~Dce2Smb2TreeTracker(void)
     {
         get_parent()->clean_file_context_from_flow(it_file.second->get_file_id(),
             it_file.second->get_file_name_hash());
-        it_file.second->set_parent(nullptr);
+        it_file.second->get_parent().reset();
         parent_session->decrease_size(sizeof(Dce2Smb2FileTracker));
     }
 
     tree_tracker_mutex.unlock();
 
-    parent_session->disconnect_tree(tree_id);
 }
 
index d679d6d0a67ff01570eb06945dd2fdd25c0307ea..0df0b3f4b73cc6124acb3180b5d571e51c5a5f8c 100644 (file)
@@ -60,6 +60,7 @@ public:
 
     Dce2Smb2FileTrackerPtr open_file(const uint64_t, const uint32_t);
     void close_file(uint64_t, bool);
+    void close_all_files();
     Dce2Smb2FileTrackerPtr find_file(uint64_t);
     Dce2Smb2RequestTracker* find_request(const uint64_t, const uint32_t);
     void process(uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*, const uint32_t);
@@ -98,8 +99,9 @@ private:
     std::mutex tree_tracker_mutex;
 };
 
+using Dce2Smb2TreeTrackerPtr = std::shared_ptr<Dce2Smb2TreeTracker>;
 using Dce2Smb2TreeTrackerMap =
-    std::unordered_map<uint32_t, Dce2Smb2TreeTracker*, std::hash<uint32_t> >;
+    std::unordered_map<uint32_t, Dce2Smb2TreeTrackerPtr, std::hash<uint32_t> >;
 
 #endif