]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2372 in SNORT/snort3 from ~NEHASH4/snort3:final_smb_changes to...
authorBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Wed, 5 Aug 2020 13:17:32 +0000 (13:17 +0000)
committerBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Wed, 5 Aug 2020 13:17:32 +0000 (13:17 +0000)
Squashed commit of the following:

commit 8d9dafa0addf0eab367debd9007efcd5bb5cea67
Author: Neha Sharma <nehash4@cisco.com>
Date:   Fri Jul 31 13:02:58 2020 -0400

    dce_rpc: fix for smb session cleanup issue

13 files changed:
src/file_api/file_config.h
src/file_api/file_flows.cc
src/file_api/file_flows.h
src/service_inspectors/dce_rpc/dce_db.h
src/service_inspectors/dce_rpc/dce_smb.cc
src/service_inspectors/dce_rpc/dce_smb2.cc
src/service_inspectors/dce_rpc/dce_smb2.h
src/service_inspectors/dce_rpc/dce_smb2_commands.cc
src/service_inspectors/dce_rpc/dce_smb2_commands.h
src/service_inspectors/dce_rpc/dce_smb2_utils.cc
src/service_inspectors/dce_rpc/dce_smb2_utils.h
src/service_inspectors/dce_rpc/dce_smb_module.h
src/service_inspectors/dce_rpc/smb_message.cc

index c750d1d02560c8f252d69d215696dcdef229f02d..a834c55e5958d1d3f94e2daf24a8f191a8cb1f02 100644 (file)
@@ -37,7 +37,7 @@
 #define DEFAULT_FILE_CAPTURE_MIN_SIZE       0           // 0
 #define DEFAULT_FILE_CAPTURE_BLOCK_SIZE     32768       // 32 KiB
 #define DEFAULT_MAX_FILES_CACHED            65536
-#define DEFAULT_MAX_FILES_PER_FLOW          32
+#define DEFAULT_MAX_FILES_PER_FLOW          128
 
 #define FILE_ID_NAME "file_id"
 #define FILE_ID_HELP "configure file identification"
index 76ac8f8db46988b2d1b2346db17f7ce581cdc827..d55f7fce0317c10c0e887ea25a39a462e6f22a23 100644 (file)
@@ -238,6 +238,21 @@ void FileFlows::remove_processed_file_context(uint64_t file_id)
         current_context_delete_pending = true;
 }
 
+// Remove file context explicitly
+void FileFlows::remove_processed_file_context(uint64_t file_id, uint64_t multi_file_processing_id)
+{
+    if (!multi_file_processing_id)
+        multi_file_processing_id = file_id;
+
+    FileContext* context = get_file_context(file_id, false, multi_file_processing_id);
+    if (context)
+    {
+        set_current_file_context(context);
+        context->processing_complete = true;
+        remove_processed_file_context(multi_file_processing_id);
+    }
+}
+
 /* This function is used to process file that is sent in pieces
  *
  * Return:
index 3e2563fec540ccbb353beb7f3d7645323b5b6dde..ff17fcb113ce72269e87d57a22f5f1fcd9f1cfe5 100644 (file)
@@ -78,6 +78,8 @@ public:
     // Remove a file from the flow object when processing is complete
     void remove_processed_file_context(uint64_t file_id);
 
+    void remove_processed_file_context(uint64_t file_id, uint64_t multi_file_processing_id);
+
     uint64_t get_new_file_instance();
 
     void set_file_name(const uint8_t* fname, uint32_t name_size, uint64_t file_id=0);
index b3653fd9c7cb64c1ebdbcb785dc4c870b906211e..7d7696dd79c916b5fe30a3e65c64bf951e783779 100644 (file)
 
 #include "main/snort_types.h"
 
-// Callbacks
-typedef void (* DCE2_DbDataFree)(void*);
-
 template<typename Key, typename Value, typename Hash>
 class DCE2_Db
 {
 public:
 
-    virtual void Init(const DCE2_DbDataFree func) = 0;
-    virtual DCE2_Ret Insert(const Key& key, Value data) = 0;
+    virtual void SetDoNotFree() = 0;
+    virtual bool Insert(const Key& key, Value data) = 0;
     virtual Value Find(const Key& key) = 0;
     virtual void Remove(const Key& key) = 0;
     virtual int GetSize() = 0;
@@ -52,50 +49,45 @@ class DCE2_DbMap : public DCE2_Db<Key, Value, Hash>
 {
 public:
 
-    DCE2_DbMap()
-    {
-        data_free = nullptr;
-    }
+    DCE2_DbMap() { }
 
     ~DCE2_DbMap()
     {
         auto it = Map.cbegin();
         while (it != Map.cend())
         {
-            if (data_free)
-                data_free((void*)(it->second));
-            else
+            if (!do_not_free)
                 delete it->second;
             it = Map.erase(it);
         }
     }
 
-    void Init(const DCE2_DbDataFree func);
-    DCE2_Ret Insert(const Key& key, Value data);
+    void SetDoNotFree();
+    bool Insert(const Key& key, Value data);
     Value Find(const Key& key);
     void Remove(const Key& key);
     int GetSize()
     {
         return Map.size();
     }
+
     std::vector< std::pair<Key, Value> > get_all_entry();
 
 private:
     std::unordered_map<Key, Value, Hash> Map;
-    DCE2_DbDataFree data_free;
+    bool do_not_free = false;
 };
 
 template<typename Key, typename Value, typename Hash>
-void DCE2_DbMap<Key, Value, Hash>::Init(const DCE2_DbDataFree df)
+void DCE2_DbMap<Key, Value, Hash>::SetDoNotFree()
 {
-    data_free = df;
+    do_not_free = true;
 }
 
 template<typename Key, typename Value, typename Hash>
-DCE2_Ret DCE2_DbMap<Key, Value, Hash>::Insert(const Key& key, Value data)
+bool DCE2_DbMap<Key, Value, Hash>::Insert(const Key& key, Value data)
 {
-    Map[key] = data;
-    return DCE2_RET__SUCCESS;
+    return Map.insert(std::make_pair(key,data)).second;
 }
 
 template<typename Key, typename Value, typename Hash>
@@ -113,17 +105,15 @@ void DCE2_DbMap<Key, Value, Hash>::Remove(const Key& key)
     auto elem = Map.find(key);
     if (elem != Map.end())
     {
-        if (data_free)
-            data_free((void*)(elem->second));
-        else
+        if (!do_not_free)
             delete elem->second;
+
         Map.erase(elem->first);
     }
 }
 
 template<typename Key, typename Value, typename Hash>
-std::vector< std::pair<Key, Value> >
-DCE2_DbMap<Key, Value, Hash>::get_all_entry()
+std::vector< std::pair<Key, Value> >DCE2_DbMap<Key, Value, Hash>::get_all_entry()
 {
     std::vector<std::pair<Key, Value> > vec;
 
@@ -134,4 +124,6 @@ DCE2_DbMap<Key, Value, Hash>::get_all_entry()
 
     return vec;
 }
+
 #endif
+
index bb339c88184f1590178dd8990e3133c8cc2187f1..76e8b668d67192d99d76c2cd80d907f1ec678c0e 100644 (file)
@@ -27,6 +27,7 @@
 #include "detection/detection_engine.h"
 #include "file_api/file_service.h"
 #include "protocols/packet.h"
+#include "managers/inspector_manager.h"
 
 #include "dce_context_data.h"
 #include "dce_smb_commands.h"
@@ -456,12 +457,23 @@ static void dce_smb_thread_term()
     delete smb2_session_cache;
 }
 
+static size_t get_max_smb_session(dce2SmbProtoConf* config)
+{
+    size_t smb_sess_storage_req = (sizeof(DCE2_Smb2SessionTracker) +
+        sizeof(DCE2_Smb2TreeTracker) +  sizeof(DCE2_Smb2RequestTracker) +
+        (sizeof(DCE2_Smb2FileTracker) * SMB_AVG_FILES_PER_SESSION));
+
+    size_t max_smb_sess = DCE2_ScSmbMemcap(config);
+
+    return (max_smb_sess/smb_sess_storage_req);
+}
+
 static Inspector* dce2_smb_ctor(Module* m)
 {
     Dce2SmbModule* mod = (Dce2SmbModule*)m;
     dce2SmbProtoConf config;
     mod->get_data(config);
-    session_cache_size = DCE2_ScSmbMemcap(&config)/1024;
+    session_cache_size = get_max_smb_session(&config);
     return new Dce2Smb(config);
 }
 
index 525424204ae4593d49b4b203b02e622aab9459b1..1afb65ea787e917cebfbe6bdb0c2676be72ca38d 100644 (file)
 #include "dce_smb2.h"
 #include "dce_smb2_commands.h"
 #include "detection/detection_util.h"
+#include "flow/flow_key.h"
 #include "main/snort_debug.h"
 
 using namespace snort;
 
+void get_flow_key(SmbFlowKey* key)
+{
+    const FlowKey* flow_key = DetectionEngine::get_current_packet()->flow->key;
+
+    key->ip_l[0] = flow_key->ip_l[0];
+    key->ip_l[1] = flow_key->ip_l[1];
+    key->ip_l[2] = flow_key->ip_l[2];
+    key->ip_l[3] = flow_key->ip_l[3];
+    key->ip_h[0] = flow_key->ip_h[0];
+    key->ip_h[1] = flow_key->ip_h[1];
+    key->ip_h[2] = flow_key->ip_h[2];
+    key->ip_h[3] = flow_key->ip_h[3];
+    key->mplsLabel = flow_key->mplsLabel;
+    key->port_l = flow_key->port_l;
+    key->port_h = flow_key->port_h;
+    key->vlan_tag = flow_key->vlan_tag;
+    key->addressSpaceId = flow_key->addressSpaceId;
+    key->ip_protocol = flow_key->ip_protocol;
+    key->pkt_type = (uint8_t)flow_key->pkt_type;
+    key->version = flow_key->version;
+    key->padding = 0;
+}
+
+DCE2_Smb2FileTracker::~DCE2_Smb2FileTracker(void)
+{
+    FileFlows* file_flows = FileFlows::get_file_flows(DetectionEngine::get_current_packet()->flow);
+    if (file_flows)
+        file_flows->remove_processed_file_context(file_name_hash, file_id);
+
+    if (file_name)
+        snort_free((void*)file_name);
+
+    memory::MemoryCap::update_deallocations(sizeof(*this));
+}
+
+DCE2_Smb2SessionTracker::~DCE2_Smb2SessionTracker(void)
+{
+    removeSessionFromAllConnection();
+    memory::MemoryCap::update_deallocations(sizeof(*this));
+}
+
+void DCE2_Smb2SessionTracker::removeSessionFromAllConnection()
+{
+    auto all_conn_trackers = conn_trackers.get_all_entry();
+    auto all_tree_trackers = tree_trackers.get_all_entry();
+    for ( auto& h : all_conn_trackers )
+    {
+        if (h.second->ftracker_tcp)
+        {
+            for (auto& t : all_tree_trackers)
+            {
+                DCE2_Smb2FileTracker* ftr = t.second->findFtracker(
+                    h.second->ftracker_tcp->file_id);
+                if (ftr and ftr == h.second->ftracker_tcp)
+                {
+                    h.second->ftracker_tcp = nullptr;
+                    break;
+                }
+            }
+        }
+        DCE2_Smb2RemoveSidInSsd(h.second, session_id);
+    }
+}
+
 static inline bool DCE2_Smb2FindSidTid(DCE2_Smb2SsnData* ssd, const uint64_t sid,
     const uint32_t tid, DCE2_Smb2SessionTracker** str, DCE2_Smb2TreeTracker** ttr)
 {
@@ -107,7 +172,7 @@ static void DCE2_Smb2Inspect(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
             return;
         }
 
-        DCE2_Smb2CloseCmd(ssd, smb_hdr, smb_data, end, ttr);
+        DCE2_Smb2CloseCmd(ssd, smb_hdr, smb_data, end, ttr, str);
         break;
     case SMB2_COM_TREE_CONNECT:
         dce2_smb_stats.v2_tree_cnct++;
@@ -198,24 +263,20 @@ void DCE2_Smb2Process(DCE2_Smb2SsnData* ssd)
     else if ( ssd->ftracker_tcp and (ssd->ftracker_tcp->smb2_pdu_state ==
         DCE2_SMB_PDU_STATE__RAW_DATA))
     {
-        debug_logf(dce_smb_trace, nullptr, "Processing raw data\n");
-        // continue processing raw data
-        FileDirection dir = p->is_from_client() ? FILE_UPLOAD : FILE_DOWNLOAD;
-        DCE2_Smb2ProcessFileData(ssd, data_ptr, data_len, dir);
+        debug_logf(dce_smb_trace, nullptr,
+            "processing raw data file_name_hash %" PRIu64 " fid %" PRIu64 "\n",
+            ssd->ftracker_tcp->file_name_hash, ssd->ftracker_tcp->file_id);
+
+        if (!DCE2_Smb2ProcessFileData(ssd, data_ptr, data_len))
+            return;
         ssd->ftracker_tcp->file_offset += data_len;
     }
 }
 
-static inline void DCE2_Smb2FreeSessionData(void* str)
-{
-    DCE2_Smb2SessionTracker* stracker = (DCE2_Smb2SessionTracker*)str;
-    DCE2_SmbSessionCacheRemove(stracker->get_session_id());
-}
-
 DCE2_Ret DCE2_Smb2InitData(DCE2_Smb2SsnData* ssd)
 {
     memset(&ssd->sd, 0, sizeof(DCE2_SsnData));
-    ssd->session_trackers.Init(DCE2_Smb2FreeSessionData);
+    ssd->session_trackers.SetDoNotFree();
     memset(&ssd->policy, 0, sizeof(DCE2_Policy));
     ssd->dialect_index = 0;
     ssd->ssn_state_flags = 0;
@@ -242,3 +303,4 @@ DCE2_SmbVersion DCE2_Smb2Version(const Packet* p)
 
     return DCE2_SMB_VERSION_NULL;
 }
+
index 1c5a4eb5d3225a15174c27faaecb0a57817de0d1..43a23a4f113a78bb660dd54595b8a70b2462ddab 100644 (file)
 
 #include "dce_db.h"
 #include "dce_smb.h"
+#include "hash/lru_cache_shared.h"
+#include "main/thread_config.h"
+#include "memory/memory_cap.h"
 #include "utils/util.h"
 
+#define SMB_AVG_FILES_PER_SESSION 5
+
 struct Smb2Hdr
 {
     uint8_t smb_idf[4];       /* contains 0xFE,’SMB’ */
@@ -93,37 +98,52 @@ public:
     DCE2_Smb2RequestTracker(const DCE2_Smb2RequestTracker& arg) = delete;
     DCE2_Smb2RequestTracker& operator=(const DCE2_Smb2RequestTracker& arg) = delete;
 
-    DCE2_Smb2RequestTracker(uint64_t offset_v, uint64_t file_id_v,
-        char* fname_v, uint16_t fname_len_v, DCE2_Smb2TreeTracker *ttr) :
-        fname_len(fname_len_v), fname(fname_v), offset(offset_v),
-        file_id(file_id_v), tree_trk(ttr)
+    DCE2_Smb2RequestTracker(uint64_t file_id_v, uint64_t offset_v = 0) :
+        file_id(file_id_v), offset(offset_v)
     {
-        // fname allocated by DCE2_SmbGetFileName
+        memory::MemoryCap::update_allocations(sizeof(*this));
+    }
+
+    DCE2_Smb2RequestTracker(char* fname_v, uint16_t fname_len_v) :
+        fname(fname_v), fname_len(fname_len_v)
+    {
+        memory::MemoryCap::update_allocations(sizeof(*this));
     }
 
     ~DCE2_Smb2RequestTracker()
     {
-        if (fname != nullptr)
-        {
-            snort_free((void*)fname);
-        }
+        if (!file_id and fname)
+            snort_free(fname);
+        memory::MemoryCap::update_deallocations(sizeof(*this));
+    }
+
+    uint64_t get_offset()
+    {
+        return offset;
     }
 
-    uint16_t get_file_name_len() { return fname_len; }
-    char* get_file_name()  { return fname; }
-    uint64_t get_offset() { return offset; }
-    uint64_t get_file_id() { return file_id; }
-    DCE2_Smb2TreeTracker* get_tree_tracker() { return tree_trk; }
+    uint64_t get_file_id()
+    {
+        return file_id;
+    }
+
+    void set_file_id(uint64_t fid)
+    {
+        file_id = fid;
+    }
+
+    char* fname = nullptr;
+    uint16_t fname_len = 0;
 
 private:
 
-    uint16_t fname_len;
-    char* fname;
-    uint64_t offset;
-    uint64_t file_id;
-    DCE2_Smb2TreeTracker* tree_trk;
+    uint64_t file_id = 0;
+    uint64_t offset = 0;
 };
 
+struct DCE2_Smb2SsnData;
+class DCE2_Smb2SessionTracker;
+
 class DCE2_Smb2FileTracker
 {
 public:
@@ -132,31 +152,29 @@ public:
     DCE2_Smb2FileTracker(const DCE2_Smb2FileTracker& arg) = delete;
     DCE2_Smb2FileTracker& operator=(const DCE2_Smb2FileTracker& arg) = delete;
 
-    DCE2_Smb2FileTracker(uint64_t file_id_v, char* file_name_v,
-         uint64_t file_size_v) : file_id(file_id_v), file_size(file_size_v)
+    DCE2_Smb2FileTracker(uint64_t file_id_v, DCE2_Smb2TreeTracker* ttr_v,
+         DCE2_Smb2SessionTracker* str_v) : file_id(file_id_v), ttr(ttr_v),
+         str(str_v)
     {
-        if (file_name_v)
-            file_name.assign(file_name_v);
-
-        file_offset = 0;
-        bytes_processed = 0;
+        memory::MemoryCap::update_allocations(sizeof(*this));
     }
 
-    ~DCE2_Smb2FileTracker()
-    {
-        // Nothing to be done
-    }
+    ~DCE2_Smb2FileTracker();
 
-    uint64_t bytes_processed;
-    uint64_t file_offset;
-    uint64_t file_id;
+    bool ignore = false;
+    bool upload = false;
+    uint16_t file_name_len = 0;
+    uint64_t bytes_processed = 0;
+    uint64_t file_offset = 0;
+    uint64_t file_id = 0;
     uint64_t file_size = 0;
     uint64_t file_name_hash = 0;
-    std::string file_name;
+    char* file_name = nullptr;
     DCE2_SmbPduState smb2_pdu_state;
+    DCE2_Smb2TreeTracker* ttr = nullptr;
+    DCE2_Smb2SessionTracker* str = nullptr;
 };
 
-
 typedef DCE2_DbMap<uint64_t, DCE2_Smb2FileTracker*, std::hash<uint64_t> > DCE2_DbMapFtracker;
 typedef DCE2_DbMap<uint64_t, DCE2_Smb2RequestTracker*, std::hash<uint64_t> > DCE2_DbMapRtracker;
 class DCE2_Smb2TreeTracker
@@ -170,77 +188,185 @@ public:
     DCE2_Smb2TreeTracker (uint32_t tid_v, uint8_t share_type_v) : share_type(
             share_type_v), tid(tid_v)
     {
+        memory::MemoryCap::update_allocations(sizeof(*this));
+    }
+
+    ~DCE2_Smb2TreeTracker()
+    {
+        memory::MemoryCap::update_deallocations(sizeof(*this));
     }
 
+    // File Tracker
     DCE2_Smb2FileTracker* findFtracker(uint64_t file_id)
     {
         return file_trackers.Find(file_id);
     }
 
-    void insertFtracker(uint64_t file_id, DCE2_Smb2FileTracker* ftracker)
+    bool insertFtracker(uint64_t file_id, DCE2_Smb2FileTracker* ftracker)
     {
-        file_trackers.Insert(file_id, ftracker);
+        return file_trackers.Insert(file_id, ftracker);
     }
 
     void removeFtracker(uint64_t file_id)
     {
-        removeDataRtrackerWithFid(file_id);
         file_trackers.Remove(file_id);
     }
 
-    DCE2_Smb2RequestTracker* findDataRtracker(uint64_t message_id)
+    // Request Tracker
+    DCE2_Smb2RequestTracker* findRtracker(uint64_t mid)
     {
-        return request_trackers.Find(message_id);
+        return req_trackers.Find(mid);
     }
 
-    void insertDataRtracker(uint64_t message_id, DCE2_Smb2RequestTracker* readtracker)
+    bool insertRtracker(uint64_t message_id, DCE2_Smb2RequestTracker* rtracker)
     {
-        request_trackers.Insert(message_id, readtracker);
+        return req_trackers.Insert(message_id, rtracker);
     }
 
-    void removeDataRtracker(uint64_t message_id)
+    void removeRtracker(uint64_t message_id)
     {
-        if (findDataRtracker(message_id))
-        {
-            request_trackers.Remove(message_id);
-        }
+        req_trackers.Remove(message_id);
     }
 
-    void removeDataRtrackerWithFid(uint64_t fid)
+    int getRtrackerSize()
     {
-        auto all_requests = request_trackers.get_all_entry();
-        for ( auto & h : all_requests )
-        {
-            if (h.second->get_file_id() == fid)
-                removeDataRtracker(h.first); // this is message id
-        }
+        return req_trackers.GetSize();
+    }
+
+    // common methods
+    uint8_t get_share_type()
+    {
+        return share_type;
     }
 
-    int getDataRtrackerSize()
+    uint32_t get_tid()
     {
-        return request_trackers.GetSize();
+        return tid;
     }
 
-    uint8_t get_share_type() { return share_type; }
-    uint32_t get_tid() { return tid; }
 private:
-    uint8_t share_type;
-    uint32_t tid;
+    uint8_t share_type = 0;
+    uint32_t tid = 0;
 
-    DCE2_DbMapRtracker request_trackers;
+    DCE2_DbMapRtracker req_trackers;
     DCE2_DbMapFtracker file_trackers;
 };
 
+struct SmbFlowKey
+{
+    uint32_t ip_l[4];   /* Low IP */
+    uint32_t ip_h[4];   /* High IP */
+    uint32_t mplsLabel;
+    uint16_t port_l;    /* Low Port - 0 if ICMP */
+    uint16_t port_h;    /* High Port - 0 if ICMP */
+    uint16_t vlan_tag;
+    uint16_t addressSpaceId;
+    uint8_t ip_protocol;
+    uint8_t pkt_type;
+    uint8_t version;
+    uint8_t padding;
+
+    bool operator==(const SmbFlowKey& other) const
+    {
+        return (ip_l[0] == other.ip_l[0] and
+               ip_l[1] == other.ip_l[1] and
+               ip_l[2] == other.ip_l[2] and
+               ip_l[3] == other.ip_l[3] and
+               ip_h[0] == other.ip_h[0] and
+               ip_l[1] == other.ip_l[1] and
+               ip_l[2] == other.ip_l[2] and
+               ip_l[3] == other.ip_l[3] and
+               mplsLabel == other.mplsLabel and
+               port_l == other.port_l and
+               port_h == other.port_h and
+               vlan_tag == other.vlan_tag and
+               addressSpaceId == other.addressSpaceId and
+               ip_protocol == other.ip_protocol and
+               pkt_type == other.pkt_type and
+               version == other.version);
+    }
+};
+
+void get_flow_key(SmbFlowKey* key);
+
+struct SmbFlowKeyHash
+{
+    size_t operator()(const struct SmbFlowKey& key) const
+    {
+        uint32_t a, b, c;
+        a = b = c = 133824503;
+
+        const uint32_t* d = (const uint32_t*)&key;
+
+        a += d[0];   // IPv6 lo[0]
+        b += d[1];   // IPv6 lo[1]
+        c += d[2];   // IPv6 lo[2]
+
+        mix(a, b, c);
+
+        a += d[3];   // IPv6 lo[3]
+        b += d[4];   // IPv6 hi[0]
+        c += d[5];   // IPv6 hi[1]
+
+        mix(a, b, c);
+
+        a += d[6];   // IPv6 hi[2]
+        b += d[7];   // IPv6 hi[3]
+        c += d[8];   // mpls label
+
+        mix(a, b, c);
+
+        a += d[9];   // port lo & port hi
+        b += d[10];  // vlan tag, address space id
+        c += d[11];  // ip_proto, pkt_type, version, and 8 bits of zeroed pad
+
+        finalize(a, b, c);
+
+        return c;
+    }
+
+private:
+    inline uint32_t rot(uint32_t x, unsigned k) const
+    { return (x << k) | (x >> (32 - k)); }
+
+    inline void mix(uint32_t& a, uint32_t& b, uint32_t& c) const
+    {
+        a -= c; a ^= rot(c, 4); c += b;
+        b -= a; b ^= rot(a, 6); a += c;
+        c -= b; c ^= rot(b, 8); b += a;
+        a -= c; a ^= rot(c,16); c += b;
+        b -= a; b ^= rot(a,19); a += c;
+        c -= b; c ^= rot(b, 4); b += a;
+    }
+
+    inline void finalize(uint32_t& a, uint32_t& b, uint32_t& c) const
+    {
+        c ^= b; c -= rot(b,14);
+        a ^= c; a -= rot(c,11);
+        b ^= a; b -= rot(a,25);
+        c ^= b; c -= rot(b,16);
+        a ^= c; a -= rot(c,4);
+        b ^= a; b -= rot(a,14);
+        c ^= b; c -= rot(b,24);
+    }
+};
+
 typedef DCE2_DbMap<uint32_t, DCE2_Smb2TreeTracker*, std::hash<uint32_t> > DCE2_DbMapTtracker;
+typedef DCE2_DbMap<struct SmbFlowKey, DCE2_Smb2SsnData*, SmbFlowKeyHash> DCE2_DbMapConntracker;
 class DCE2_Smb2SessionTracker
 {
 public:
 
-    DCE2_Smb2SessionTracker() { }
+    DCE2_Smb2SessionTracker() { memory::MemoryCap::update_allocations(sizeof(*this)); }
+
+    ~DCE2_Smb2SessionTracker();
+
+    void removeSessionFromAllConnection();
 
-    void insertTtracker(uint32_t tree_id, DCE2_Smb2TreeTracker* ttr)
+    // tree tracker
+    bool insertTtracker(uint32_t tree_id, DCE2_Smb2TreeTracker* ttr)
     {
-        tree_trackers.Insert(tree_id, ttr);
+        return tree_trackers.Insert(tree_id, ttr);
     }
 
     DCE2_Smb2TreeTracker* findTtracker(uint32_t tree_id)
@@ -250,55 +376,50 @@ public:
 
     void removeTtracker(uint32_t tree_id)
     {
-        // Remove any dangling request trackers with tree id
-        removeRtrackerWithTid(tree_id);
         tree_trackers.Remove(tree_id);
     }
 
-    DCE2_Smb2RequestTracker* findRtracker(uint64_t mid)
+    // ssd tracker
+    bool insertConnTracker(SmbFlowKey key, DCE2_Smb2SsnData* ssd)
     {
-        return create_request_trackers.Find(mid);
+        return conn_trackers.Insert(key, ssd);
     }
 
-    void insertRtracker(uint64_t message_id, DCE2_Smb2RequestTracker* rtracker)
+    DCE2_Smb2SsnData* findConnTracker(SmbFlowKey key)
     {
-        create_request_trackers.Insert(message_id, rtracker);
+        return conn_trackers.Find(key);
     }
 
-    void removeRtracker(uint64_t message_id)
+    void removeConnTracker(SmbFlowKey key)
     {
-        create_request_trackers.Remove(message_id);
+        conn_trackers.Remove(key);
     }
 
-    void removeRtrackerWithTid(uint32_t tid)
+    int getConnTrackerSize()
     {
-        auto all_requests = create_request_trackers.get_all_entry();
-        for ( auto & h : all_requests )
-        {
-            if (h.second->get_tree_tracker() and h.second->get_tree_tracker()->get_tid() == tid)
-                removeRtracker(h.first); // this is message id
-        }
+        return conn_trackers.GetSize();
     }
 
     uint16_t getTotalRequestsPending()
     {
         uint16_t total_count = 0;
         auto all_tree_trackers = tree_trackers.get_all_entry();
-        for ( auto & h : all_tree_trackers )
+        for ( auto& h : all_tree_trackers )
         {
-            total_count += h.second->getDataRtrackerSize(); // all read/write
+            total_count += h.second->getRtrackerSize();
         }
-        total_count += create_request_trackers.GetSize(); // all create
         return total_count;
     }
 
-    void set_session_id(uint64_t sid) { session_id = sid; }
-    uint64_t get_session_id() { return session_id; }
+    void set_session_id(uint64_t sid)
+    {
+        session_id = sid;
+        conn_trackers.SetDoNotFree();
+    }
 
-private:
-    uint64_t session_id;
+    DCE2_DbMapConntracker conn_trackers;
     DCE2_DbMapTtracker tree_trackers;
-    DCE2_DbMapRtracker create_request_trackers;
+    uint64_t session_id = 0;
 };
 
 typedef DCE2_DbMap<uint64_t, DCE2_Smb2SessionTracker*, std::hash<uint64_t> > DCE2_DbMapStracker;
@@ -419,6 +540,9 @@ struct Smb2CreateRequestHdr
     uint32_t create_contexts_length;  /* length of contexts */
 };
 
+// file attribute for create response
+#define SMB2_CREATE_RESPONSE_DIRECTORY 0x10
+
 struct Smb2CreateResponseHdr
 {
     uint16_t structure_size;          /* This MUST be set to 89 */
index c71c6d47ab37516de6c105b7174e3ad8807b432b..b17dea4a5c1c18fa8c527a559e11aec49137afe5 100644 (file)
 #include "dce_smb2_commands.h"
 #include "hash/hash_key_operations.h"
 #include "log/messages.h"
-#include "main/snort_debug.h"
 #include "packet_io/active.h"
 #include "protocols/packet.h"
 
 using namespace snort;
 #define UNKNOWN_FILE_SIZE (~0)
-#define SMB2_CHECK_HDR_ERROR(smb_data, end, strcuture_size, counter)\
-{ \
-    if ((smb_data + (strcuture_size)) > end)\
-    {\
-        counter++;\
-        return;\
-    }\
-}
 
-static inline FileContext* get_smb_file_context(uint64_t file_id, uint64_t multi_file_processing_id,
+#define SMB2_CHECK_HDR_ERROR(smb_data, end, strcuture_size, counter) \
+    { \
+        if ((smb_data + (strcuture_size)) > end) \
+        { \
+            counter ++; \
+            return; \
+        } \
+    }
+
+static inline FileContext* get_smb_file_context(uint64_t file_id, uint64_t
+    multi_file_processing_id,
     bool to_create = false)
 {
     FileFlows* file_flows = FileFlows::get_file_flows(DetectionEngine::get_current_packet()->flow);
@@ -57,8 +58,23 @@ static inline FileContext* get_smb_file_context(uint64_t file_id, uint64_t multi
     return file_flows->get_file_context(file_id, to_create, multi_file_processing_id);
 }
 
-void DCE2_Smb2ProcessFileData(DCE2_Smb2SsnData* ssd, const uint8_t* file_data,
-    uint32_t data_size, FileDirection dir)
+static void DCE2_Smb2CleanFtrackerTcpRef(DCE2_Smb2SessionTracker* str, uint64_t file_id)
+{
+    auto all_conn_trackers = str->conn_trackers.get_all_entry();
+    for ( auto& h : all_conn_trackers )
+    {
+        if (h.second->ftracker_tcp)
+        {
+            if (h.second->ftracker_tcp->file_id == file_id)
+            {
+                h.second->ftracker_tcp = nullptr;
+            }
+        }
+    }
+}
+
+bool DCE2_Smb2ProcessFileData(DCE2_Smb2SsnData* ssd, const uint8_t* file_data,
+    uint32_t data_size)
 {
     int64_t file_detection_depth = DCE2_ScSmbFileDepth((dce2SmbProtoConf*)ssd->sd.config);
     int64_t detection_size = 0;
@@ -83,18 +99,34 @@ void DCE2_Smb2ProcessFileData(DCE2_Smb2SsnData* ssd, const uint8_t* file_data,
 
     Packet* p = DetectionEngine::get_current_packet();
     ssd->ftracker_tcp->bytes_processed += detection_size;
+    FileDirection dir = ssd->ftracker_tcp->upload ? FILE_UPLOAD : FILE_DOWNLOAD;
 
     // Do not process data beyond file size if file size is known.
     FileFlows* file_flows = FileFlows::get_file_flows(p->flow);
-    if ( !file_flows or ( ssd->ftracker_tcp->file_size and
-        ssd->ftracker_tcp->bytes_processed > ssd->ftracker_tcp->file_size ) )
+    if ( !file_flows or (ssd->ftracker_tcp->file_size and
+        ssd->ftracker_tcp->bytes_processed > ssd->ftracker_tcp->file_size) )
     {
         dce2_smb_stats.v2_extra_file_data_err++;
-        return;
+
+        DCE2_Smb2TreeTracker* ttr = ssd->ftracker_tcp->ttr;
+        uint64_t file_id = ssd->ftracker_tcp->file_id;
+        DCE2_Smb2CleanFtrackerTcpRef(ssd->ftracker_tcp->str, file_id);
+        ttr->removeFtracker(file_id);
+
+        return false;
     }
 
-    file_flows->file_process(p, ssd->ftracker_tcp->file_name_hash, file_data, data_size,
-        ssd->ftracker_tcp->file_offset, dir, ssd->ftracker_tcp->file_id);
+    if (!file_flows->file_process(p, ssd->ftracker_tcp->file_name_hash, file_data, data_size,
+        ssd->ftracker_tcp->file_offset, dir, ssd->ftracker_tcp->file_id))
+    {
+        DCE2_Smb2TreeTracker* ttr = ssd->ftracker_tcp->ttr;
+        uint64_t file_id = ssd->ftracker_tcp->file_id;
+        DCE2_Smb2CleanFtrackerTcpRef(ssd->ftracker_tcp->str, file_id);
+        ttr->removeFtracker(file_id);
+
+        return false;
+    }
+    return true;
 }
 
 //-------------------------------------------------------------------------
@@ -140,9 +172,11 @@ void DCE2_Smb2TreeConnect(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
             smb_data, end, SMB2_TREE_CONNECT_RESPONSE_STRUC_SIZE,
             dce2_smb_stats.v2_tree_cnct_resp_hdr_err)
 
-        if (!DCE2_Smb2FindElseCreateTid(ssd, tid,
-                ((const Smb2TreeConnectResponseHdr*)smb_data)->share_type, str))
+        if (!DCE2_Smb2InsertTid(ssd, tid,
+            ((const Smb2TreeConnectResponseHdr*)smb_data)->share_type, str))
+        {
             dce2_smb_stats.v2_tree_cnct_ignored++;
+        }
     }
     else if (structure_size != SMB2_TREE_CONNECT_REQUEST_STRUC_SIZE)
     {
@@ -198,22 +232,19 @@ static void DCE2_Smb2CreateRequest(DCE2_Smb2SsnData* ssd,
             return;
         }
 
-        char* file_name = DCE2_SmbGetFileName(file_data, size, true,
-            &name_len);
-
         if (ssd->max_outstanding_requests > str->getTotalRequestsPending())
         {
-            DCE2_Smb2RequestTracker* rtracker = str->findRtracker(mid);
+            DCE2_Smb2RequestTracker* rtracker = ttr->findRtracker(mid);
             if (rtracker) // Cleanup existing tracker
-                str->removeRtracker(mid);
+                ttr->removeRtracker(mid);
 
-            rtracker = new DCE2_Smb2RequestTracker(0, 0, file_name, name_len, ttr);
+            char* file_name = DCE2_SmbGetFileName(file_data, size, true, &name_len);
 
-            str->insertRtracker(mid, rtracker);
+            rtracker = new DCE2_Smb2RequestTracker(file_name, name_len);
+            ttr->insertRtracker(mid, rtracker);
         }
         else
         {
-            snort_free(file_name);
             dce_alert(GID_DCE2, DCE2_SMB_MAX_REQS_EXCEEDED, (dce2CommonStats*)&dce2_smb_stats,
                 ssd->sd);
         }
@@ -226,50 +257,59 @@ static void DCE2_Smb2CreateRequest(DCE2_Smb2SsnData* ssd,
 
 //-------------------------------------------------------------------------
 // Process create response to create file tracker with file id and file
-// size. Request tracker is cleaned after updating file name in file tracker 
+// size. Request tracker is cleaned after updating file name in file tracker
 //-------------------------------------------------------------------------
 static void DCE2_Smb2CreateResponse(DCE2_Smb2SsnData*,
-    const Smb2CreateResponseHdr* smb_create_hdr,
-    DCE2_Smb2RequestTracker* rtracker)
+    const Smb2CreateResponseHdr* smb_create_hdr, DCE2_Smb2RequestTracker* rtracker,
+    DCE2_Smb2TreeTracker* ttr, DCE2_Smb2SessionTracker* str, uint64_t fileId_persistent)
 {
     uint64_t file_size = 0;
-    uint64_t fileId_persistent = alignedNtohq((const uint64_t*)(&(smb_create_hdr->fileId_persistent)));
 
     if (smb_create_hdr->end_of_file)
     {
         file_size = alignedNtohq((const uint64_t*)(&(smb_create_hdr->end_of_file)));
     }
 
-    DCE2_Smb2FileTracker* ftracker = rtracker->get_tree_tracker()->findFtracker(fileId_persistent);
+    DCE2_Smb2FileTracker* ftracker = ttr->findFtracker(fileId_persistent);
     if (!ftracker)
     {
-        ftracker = new DCE2_Smb2FileTracker(
-            fileId_persistent, rtracker->get_file_name(), file_size);
-    }
-    else // compounded create request + read request case
-    {
-        ftracker->file_name.assign(rtracker->get_file_name());
-        ftracker->file_size = file_size;
+        ftracker = new DCE2_Smb2FileTracker(fileId_persistent, ttr, str);
+        ttr->insertFtracker(fileId_persistent, ftracker);
     }
+    ftracker->file_name = rtracker->fname;
+    ftracker->file_name_len = rtracker->fname_len;
+    ftracker->file_size = file_size;
 
-    ftracker->file_name_hash = str_to_hash(
-        (const uint8_t *)rtracker->get_file_name(), rtracker->get_file_name_len());
-
-    if (rtracker->get_file_name() and rtracker->get_file_name_len())
+    if (rtracker->fname and rtracker->fname_len)
     {
-        FileContext* file = get_smb_file_context(ftracker->file_name_hash, fileId_persistent, true);
-        if (file and file->verdict == FILE_VERDICT_UNKNOWN)
+        ftracker->file_name_hash = str_to_hash(
+            (const uint8_t*)rtracker->fname, rtracker->fname_len);
+
+        FileContext* file = get_smb_file_context(ftracker->file_name_hash, fileId_persistent,
+            true);
+        if (file)
+        {
+            if (file->verdict == FILE_VERDICT_UNKNOWN)
+            {
+                file->set_file_size(!file_size ? UNKNOWN_FILE_SIZE : file_size);
+                file->set_file_name(ftracker->file_name, ftracker->file_name_len);
+            }
+        }
+        else
         {
-            file->set_file_size(!file_size ? UNKNOWN_FILE_SIZE : file_size);
-            file->set_file_name(rtracker->get_file_name(), rtracker->get_file_name_len());
+            ftracker->ignore = true; // could not create file context, hence this file transfer
+                                     // cant be inspected
         }
+        rtracker->set_file_id(fileId_persistent); // to ensure file tracker will free file name
+    }
+    else
+    {
+        ftracker->ignore = true; // file can not be inspected as file name is null
     }
-
-    rtracker->get_tree_tracker()->insertFtracker(fileId_persistent, ftracker);
 }
 
 //-------------------------------------------------------------------------
-// Process create request to handle mid stream sessions by adding tree 
+// Process create request to handle mid stream sessions by adding tree
 // tracker if not already present. Process create response for only disk
 // share type.
 //-------------------------------------------------------------------------
@@ -277,19 +317,14 @@ void DCE2_Smb2Create(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
     const uint8_t* smb_data, const uint8_t* end, uint64_t mid, uint64_t sid, uint32_t tid)
 {
     DCE2_Smb2SessionTracker* str = DCE2_Smb2FindElseCreateSid(ssd, sid);
+    DCE2_Smb2TreeTracker* ttr = str->findTtracker(tid);
     uint16_t structure_size = alignedNtohs((const uint16_t*)smb_data);
 
     if (structure_size == SMB2_ERROR_RESPONSE_STRUC_SIZE and Smb2Error(smb_hdr))
     {
-        //in case of compound create + read, a ftracker is already created, remove it
-        DCE2_Smb2RequestTracker* rtr = str->findRtracker(mid);
-        if ( rtr and rtr->get_tree_tracker() and rtr->get_file_id() )
-        {
-            if (ssd->ftracker_tcp->file_id == rtr->get_file_id())
-                ssd->ftracker_tcp = NULL;
-            rtr->get_tree_tracker()->removeFtracker(rtr->get_file_id());
-        }
-        str->removeRtracker(mid);
+        if (ttr)
+            ttr->removeRtracker(mid);
+
         dce2_smb_stats.v2_crt_err_resp++;
     }
     // Using structure size to decide whether it is response or request
@@ -299,15 +334,16 @@ void DCE2_Smb2Create(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
             smb_data, end, SMB2_CREATE_REQUEST_STRUC_SIZE - 1,
             dce2_smb_stats.v2_crt_req_hdr_err)
 
-        DCE2_Smb2TreeTracker* ttr = str->findTtracker(tid);
         if (!ttr)
         {
             ttr = DCE2_Smb2InsertTid(ssd, tid, SMB2_SHARE_TYPE_DISK, str);
+            if (!ttr)
+            {
+                return;
+            }
         }
         else if (SMB2_SHARE_TYPE_DISK != ttr->get_share_type())
         {
-            debug_logf(dce_smb_trace, nullptr, "Not handling create request for IPC with TID (%u)\n",
-                tid);
             dce2_smb_stats.v2_crt_req_ipc++;
             return;
         }
@@ -319,27 +355,36 @@ void DCE2_Smb2Create(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
             smb_data, end, SMB2_CREATE_RESPONSE_STRUC_SIZE - 1,
             dce2_smb_stats.v2_crt_resp_hdr_err)
 
-        DCE2_Smb2RequestTracker* rtr = str->findRtracker(mid);
+        if (!ttr)
+        {
+            dce2_smb_stats.v2_crt_tree_trkr_misng++;
+            return;
+        }
+
+        DCE2_Smb2RequestTracker* rtr = ttr->findRtracker(mid);
         if (!rtr)
         {
-            debug_logf(dce_smb_trace, nullptr,
-                "No create request received for this MID (%" PRIu64 ")\n", mid);
             dce2_smb_stats.v2_crt_rtrkr_misng++;
             return;
         }
-        // Check required only for null tree tracker since for IPC,
-        // the request tracker itself is not added.
-        if (!rtr->get_tree_tracker())
+
+        uint64_t fileId_persistent = alignedNtohq((const uint64_t*)(
+                &(((const Smb2CreateResponseHdr*)smb_data)->fileId_persistent)));
+
+        if (((const Smb2CreateResponseHdr*)smb_data)->file_attributes &
+            SMB2_CREATE_RESPONSE_DIRECTORY)
         {
-            debug_logf(dce_smb_trace, nullptr,
-                "Tree tracker is missing for create request\n");
-            dce2_smb_stats.v2_crt_tree_trkr_misng++;
-            str->removeRtracker(mid);
+            ttr->removeRtracker(mid);
+            if (ssd->ftracker_tcp and ssd->ftracker_tcp->file_id == fileId_persistent)
+                ssd->ftracker_tcp = nullptr;
+            DCE2_Smb2CleanFtrackerTcpRef(str, fileId_persistent);
+            ttr->removeFtracker(fileId_persistent);
             return;
         }
-        
-        DCE2_Smb2CreateResponse(ssd, (const Smb2CreateResponseHdr*)smb_data, rtr);
-        str->removeRtracker(mid);
+
+        DCE2_Smb2CreateResponse(ssd, (const Smb2CreateResponseHdr*)smb_data, rtr, ttr,
+            str, fileId_persistent);
+        ttr->removeRtracker(mid);
     }
     else
     {
@@ -352,7 +397,8 @@ void DCE2_Smb2Create(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
 // download request with unknown size.
 //-------------------------------------------------------------------------
 void DCE2_Smb2CloseCmd(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
-    const uint8_t* smb_data, const uint8_t* end, DCE2_Smb2TreeTracker* ttr)
+    const uint8_t* smb_data, const uint8_t* end, DCE2_Smb2TreeTracker* ttr,
+    DCE2_Smb2SessionTracker* str)
 {
     uint16_t structure_size = alignedNtohs((const uint16_t*)smb_data);
 
@@ -367,7 +413,8 @@ void DCE2_Smb2CloseCmd(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
             smb_data, end, SMB2_CLOSE_REQUEST_STRUC_SIZE,
             dce2_smb_stats.v2_cls_req_hdr_err)
 
-        uint64_t fileId_persistent = alignedNtohq(&(((const Smb2CloseRequestHdr*)smb_data)->fileId_persistent));
+        uint64_t fileId_persistent = alignedNtohq(&(((const
+            Smb2CloseRequestHdr*)smb_data)->fileId_persistent));
         DCE2_Smb2FileTracker* ftracker =  ttr->findFtracker(fileId_persistent);
         if (!ftracker)
         {
@@ -375,32 +422,28 @@ void DCE2_Smb2CloseCmd(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
             return;
         }
 
-        if (!ftracker->file_size and ftracker->file_offset)
+        if (!ftracker->ignore and !ftracker->file_size and ftracker->file_offset)
         {
-            // If close command request comes just after create response, we dont have 
-            // information to know the direction, hence below code was included.
-            FileDirection dir = DetectionEngine::get_current_packet()->is_from_client() ?
-                FILE_UPLOAD : FILE_DOWNLOAD;
-
             ftracker->file_size = ftracker->file_offset;
             FileContext* file = get_smb_file_context(ftracker->file_name_hash, fileId_persistent);
             if (file)
             {
                 file->set_file_size(ftracker->file_size);
             }
-           
+
+            ssd->ftracker_tcp = ftracker;
+
             // In case of upload/download of file with UNKNOWN size, we will not be able to
             // detect malicious file during write request or read response. Once the close
             // command request comes, we will go for file inspection and block an subsequent
             // upload/download request for this file even with unknown size
-            DCE2_Smb2ProcessFileData(ssd, nullptr, 0, dir);
+            DCE2_Smb2ProcessFileData(ssd, nullptr, 0);
+        }
+        else
+        {
+            DCE2_Smb2CleanFtrackerTcpRef(str, fileId_persistent);
+            ttr->removeFtracker(fileId_persistent);
         }
-
-        if (ssd->ftracker_tcp and ssd->ftracker_tcp->file_id == fileId_persistent)
-            ssd->ftracker_tcp = nullptr;
-
-        ttr->removeFtracker(fileId_persistent);
-
     }
     else if (structure_size != SMB2_CLOSE_RESPONSE_STRUC_SIZE)
     {
@@ -435,23 +478,26 @@ void DCE2_Smb2SetInfo(DCE2_Smb2SsnData*, const Smb2Hdr* smb_hdr,
         {
             uint64_t file_size = alignedNtohq((const uint64_t*)file_data);
             uint64_t fileId_persistent = alignedNtohq(&(smb_set_info_hdr->fileId_persistent));
-
             DCE2_Smb2FileTracker* ftracker = ttr->findFtracker(fileId_persistent);
-            if (ftracker)
+            if (ftracker and !ftracker->ignore)
             {
                 ftracker->file_size = file_size;
-
-                FileContext* file = get_smb_file_context(ftracker->file_name_hash, fileId_persistent);
+                FileContext* file = get_smb_file_context(ftracker->file_name_hash,
+                    fileId_persistent);
                 if (file)
                 {
                     file->set_file_size(ftracker->file_size);
                 }
             }
             else
+            {
                 dce2_smb_stats.v2_stinf_req_ftrkr_misng++;
+            }
         }
         else
+        {
             dce2_smb_stats.v2_stinf_req_hdr_err++;
+        }
     }
     else if (structure_size != SMB2_SET_INFO_RESPONSE_STRUC_SIZE)
     {
@@ -466,16 +512,18 @@ static void DCE2_Smb2ReadRequest(DCE2_Smb2SsnData* ssd,
     const Smb2ReadRequestHdr* smb_read_hdr, const uint8_t*, DCE2_Smb2SessionTracker* str,
     DCE2_Smb2TreeTracker* ttr, uint64_t message_id)
 {
-    DCE2_Smb2RequestTracker* readtracker = nullptr;
-
     uint64_t offset = alignedNtohq((const uint64_t*)(&(smb_read_hdr->offset)));
-    uint64_t fileId_persistent = alignedNtohq((const uint64_t*)(&(smb_read_hdr->fileId_persistent)));
+    uint64_t fileId_persistent = alignedNtohq((const
+        uint64_t*)(&(smb_read_hdr->fileId_persistent)));
 
     if (ssd->max_outstanding_requests > str->getTotalRequestsPending())
     {
-         readtracker = new DCE2_Smb2RequestTracker(
-             offset, fileId_persistent, nullptr, 0, nullptr);
-         ttr->insertDataRtracker(message_id, readtracker);
+        DCE2_Smb2RequestTracker* readtracker = ttr->findRtracker(message_id);
+        if (!readtracker)
+        {
+            readtracker = new DCE2_Smb2RequestTracker(fileId_persistent, offset);
+            ttr->insertRtracker(message_id, readtracker);
+        }
     }
     else
     {
@@ -487,7 +535,7 @@ static void DCE2_Smb2ReadRequest(DCE2_Smb2SsnData* ssd,
     DCE2_Smb2FileTracker* ftracker =  ttr->findFtracker(fileId_persistent);
     if (!ftracker) // compounded create request + read request case
     {
-        ftracker = new DCE2_Smb2FileTracker(fileId_persistent, nullptr, 0);
+        ftracker = new DCE2_Smb2FileTracker(fileId_persistent, ttr, str);
         ttr->insertFtracker(fileId_persistent, ftracker);
     }
 
@@ -511,7 +559,7 @@ static void DCE2_Smb2ReadResponse(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
     uint16_t data_offset;
     DCE2_Smb2RequestTracker* request;
 
-    request = ttr->findDataRtracker(message_id);
+    request = ttr->findRtracker(message_id);
     if (!request)
     {
         dce2_smb_stats.v2_read_rtrkr_misng++;
@@ -524,18 +572,18 @@ static void DCE2_Smb2ReadResponse(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
     }
 
     DCE2_Smb2FileTracker* ftracker =  ttr->findFtracker(request->get_file_id());
-    if ( ftracker ) // file tracker can never be NULL for read response
+    if ( ftracker and !ftracker->ignore )
     {
         ftracker->file_offset = request->get_offset();
-        ttr->removeDataRtracker(message_id);
+        ttr->removeRtracker(message_id);
+
         ssd->ftracker_tcp = ftracker;
 
-        DCE2_Smb2ProcessFileData(ssd, file_data, data_size, FILE_DOWNLOAD);
+        if (!DCE2_Smb2ProcessFileData(ssd, file_data, data_size))
+            return;
         ftracker->file_offset += data_size;
 
         uint32_t total_data_length = alignedNtohl((const uint32_t*)&(smb_read_hdr->length));
-        debug_logf(dce_smb_trace, nullptr, "smbv2 total_data=%d data_size=%d ssd=%p\n", total_data_length,data_size,
-           (void*)ssd);
         if (total_data_length > (uint32_t)data_size)
         {
             ftracker->smb2_pdu_state = DCE2_SMB_PDU_STATE__RAW_DATA;
@@ -554,14 +602,13 @@ void DCE2_Smb2Read(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
 
     if (Smb2Error(smb_hdr) and structure_size == SMB2_ERROR_RESPONSE_STRUC_SIZE)
     {
-        DCE2_Smb2RequestTracker* rtr = ttr->findDataRtracker(mid);
+        DCE2_Smb2RequestTracker* rtr = ttr->findRtracker(mid);
         if (rtr and rtr->get_file_id())
         {
-            if (ssd->ftracker_tcp->file_id == rtr->get_file_id())
-                ssd->ftracker_tcp = NULL;
+            DCE2_Smb2CleanFtrackerTcpRef(str, rtr->get_file_id());
             ttr->removeFtracker(rtr->get_file_id());
         }
-        ttr->removeDataRtracker(mid);
+        ttr->removeRtracker(mid);
         dce2_smb_stats.v2_read_err_resp++;
     }
     // Using structure size to decide whether it is response or request
@@ -598,15 +645,17 @@ static void DCE2_Smb2WriteRequest(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
     int data_size = end - file_data;
     uint64_t fileId_persistent, offset;
     uint16_t data_offset;
-    DCE2_Smb2RequestTracker* writetracker = nullptr;
 
     fileId_persistent = alignedNtohq((const uint64_t*)(&(smb_write_hdr->fileId_persistent)));
 
     if (ssd->max_outstanding_requests > str->getTotalRequestsPending())
     {
-         writetracker = new DCE2_Smb2RequestTracker(
-              0, fileId_persistent, nullptr, 0, nullptr);
-         ttr->insertDataRtracker(mid, writetracker);
+        DCE2_Smb2RequestTracker* writetracker = ttr->findRtracker(mid);
+        if (!writetracker)
+        {
+            writetracker = new DCE2_Smb2RequestTracker(fileId_persistent);
+            ttr->insertRtracker(mid, writetracker);
+        }
     }
     else
     {
@@ -623,7 +672,7 @@ static void DCE2_Smb2WriteRequest(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
 
     offset = alignedNtohq((const uint64_t*)(&(smb_write_hdr->offset)));
     DCE2_Smb2FileTracker* ftracker = ttr->findFtracker(fileId_persistent);
-    if (ftracker) // file tracker can not be NULL here
+    if (ftracker and !ftracker->ignore) // file tracker can not be nullptr here
     {
         if (ftracker->file_size and (offset > ftracker->file_size))
         {
@@ -631,11 +680,14 @@ static void DCE2_Smb2WriteRequest(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
                 ssd->sd);
         }
         ftracker->file_offset = offset;
+        ftracker->upload = true;
+
         ssd->ftracker_tcp = ftracker;
-        DCE2_Smb2ProcessFileData(ssd, file_data, data_size, FILE_UPLOAD);
+
+        if (!DCE2_Smb2ProcessFileData(ssd, file_data, data_size))
+            return;
         ftracker->file_offset += data_size;
         uint32_t total_data_length = alignedNtohl((const uint32_t*)&(smb_write_hdr->length));
-        debug_logf(dce_smb_trace, nullptr, "smbv2 total_data=%d data_size=%d\n",total_data_length,data_size);
         if (total_data_length > (uint32_t)data_size)
         {
             ftracker->smb2_pdu_state = DCE2_SMB_PDU_STATE__RAW_DATA;
@@ -654,14 +706,13 @@ void DCE2_Smb2Write(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
 
     if (structure_size == SMB2_ERROR_RESPONSE_STRUC_SIZE and Smb2Error(smb_hdr))
     {
-        DCE2_Smb2RequestTracker* wtr = ttr->findDataRtracker(mid);
+        DCE2_Smb2RequestTracker* wtr = ttr->findRtracker(mid);
         if (wtr and wtr->get_file_id())
         {
-            if (ssd->ftracker_tcp->file_id == wtr->get_file_id())
-                ssd->ftracker_tcp = NULL;
+            DCE2_Smb2CleanFtrackerTcpRef(str, wtr->get_file_id());
             ttr->removeFtracker(wtr->get_file_id());
         }
-        ttr->removeDataRtracker(mid);
+        ttr->removeRtracker(mid);
         dce2_smb_stats.v2_wrt_err_resp++;
     }
     // Using structure size to decide whether it is response or request
@@ -670,11 +721,12 @@ void DCE2_Smb2Write(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
         SMB2_CHECK_HDR_ERROR(
             smb_data, end, SMB2_WRITE_REQUEST_STRUC_SIZE - 1,
             dce2_smb_stats.v2_wrt_req_hdr_err)
-        DCE2_Smb2WriteRequest(ssd, smb_hdr, (const Smb2WriteRequestHdr*)smb_data, end, str, ttr, mid);
+        DCE2_Smb2WriteRequest(ssd, smb_hdr, (const Smb2WriteRequestHdr*)smb_data, end, str, ttr,
+            mid);
     }
     else if (structure_size == SMB2_WRITE_RESPONSE_STRUC_SIZE)
     {
-        ttr->removeDataRtracker(mid);
+        ttr->removeRtracker(mid);
     }
     else
     {
@@ -687,14 +739,20 @@ void DCE2_Smb2Write(DCE2_Smb2SsnData* ssd, const Smb2Hdr* smb_hdr,
 // trackers and their corresponding file trackers
 //-------------------------------------------------------------------------
 void DCE2_Smb2Logoff(DCE2_Smb2SsnData* ssd, const uint8_t* smb_data,
-        const uint64_t sid)
+    const uint64_t sid)
 {
     if (alignedNtohs((const uint16_t*)smb_data) == SMB2_LOGOFF_REQUEST_STRUC_SIZE)
     {
-        DCE2_Smb2RemoveSidInSsd(ssd, sid);
+        DCE2_Smb2SessionTracker* str = DCE2_Smb2FindSidInSsd(ssd, sid);
+        if (str)
+        {
+            str->removeSessionFromAllConnection();
+            DCE2_SmbSessionCacheRemove(sid);
+        }
     }
     else
     {
         dce2_smb_stats.v2_logoff_inv_str_sz++;
     }
 }
+
index f67553870da9c7df506d3c5835cab01c1d751bec..97a86d9bc00abe23b2a8b74aede3d0b605834b5d 100644 (file)
@@ -52,14 +52,15 @@ void DCE2_Smb2Write(DCE2_Smb2SsnData*, const Smb2Hdr*,
 void DCE2_Smb2SetInfo(DCE2_Smb2SsnData*, const Smb2Hdr*,
     const uint8_t* smb_data, const uint8_t* end, DCE2_Smb2TreeTracker* ttr);
 
-void DCE2_Smb2ProcessFileData(DCE2_Smb2SsnData*, const uint8_t* file_data,
-    uint32_t data_size, FileDirection dir);
+bool DCE2_Smb2ProcessFileData(DCE2_Smb2SsnData*, const uint8_t* file_data,
+    uint32_t data_size);
 
 void DCE2_Smb2CloseCmd(DCE2_Smb2SsnData*, const Smb2Hdr*,
-    const uint8_t* smb_data, const uint8_t* end, DCE2_Smb2TreeTracker* ttr);
+    const uint8_t* smb_data, const uint8_t* end, DCE2_Smb2TreeTracker* ttr,
+    DCE2_Smb2SessionTracker* str);
 
 void DCE2_Smb2Logoff(DCE2_Smb2SsnData*, const uint8_t* smb_data,
-        const uint64_t sid);
+    const uint64_t sid);
 
 #endif
 
index 995ba95997942512f99ec706b2d267df0a0af5a0..9e68f5d523811d3937ccfc9a2f280da720acfc8d 100644 (file)
@@ -34,6 +34,16 @@ using namespace snort;
 size_t session_cache_size;
 THREAD_LOCAL SmbSessionCache* smb2_session_cache;
 
+Smb2SidHashKey get_key(uint64_t sid)
+{
+    Smb2SidHashKey key;
+    Flow* flow = DetectionEngine::get_current_packet()->flow;
+    memcpy(&key.cip, &flow->client_ip, sizeof(SfIp));
+    memcpy(&key.sip, &flow->server_ip, sizeof(SfIp));
+    key.sid = sid;
+    return key;
+}
+
 DCE2_Smb2SessionTracker* DCE2_Smb2FindElseCreateSid(DCE2_Smb2SsnData* ssd, const
     uint64_t sid)
 {
@@ -47,16 +57,16 @@ DCE2_Smb2SessionTracker* DCE2_Smb2FindElseCreateSid(DCE2_Smb2SsnData* ssd, const
         stracker = DCE2_SmbSessionCacheFindElseCreate(sid, &entry_created);
         assert(stracker);
         if (entry_created)
-        {
             stracker->set_session_id(sid);
-        }
+
         DCE2_Smb2InsertSidInSsd(ssd, sid, stracker);
     }
 
     return stracker;
 }
 
-DCE2_Smb2TreeTracker* DCE2_Smb2InsertTid(DCE2_Smb2SsnData* ssd, const uint32_t tid, uint8_t share_type,
+DCE2_Smb2TreeTracker* DCE2_Smb2InsertTid(DCE2_Smb2SsnData* ssd, const uint32_t tid, uint8_t
+    share_type,
     DCE2_Smb2SessionTracker* str)
 {
     if (share_type == SMB2_SHARE_TYPE_DISK and
@@ -67,20 +77,33 @@ DCE2_Smb2TreeTracker* DCE2_Smb2InsertTid(DCE2_Smb2SsnData* ssd, const uint32_t t
         return nullptr;
     }
 
-    DCE2_Smb2TreeTracker* ttracker = new DCE2_Smb2TreeTracker(tid, share_type);
-    str->insertTtracker(tid, ttracker);
+    DCE2_Smb2TreeTracker* ttracker = str->findTtracker(tid);
+    if (!ttracker)
+    {
+        ttracker = new DCE2_Smb2TreeTracker(tid, share_type);
+        str->insertTtracker(tid, ttracker);
+    }
+
     return ttracker;
 }
 
-DCE2_Smb2TreeTracker* DCE2_Smb2FindElseCreateTid(DCE2_Smb2SsnData* ssd, const uint32_t tid,
-    uint8_t share_type, DCE2_Smb2SessionTracker* str)
+void DCE2_Smb2RemoveAllSession(DCE2_Smb2SsnData* ssd)
 {
-    DCE2_Smb2TreeTracker* ttr = str->findTtracker(tid);
-    if (!ttr)
+    SmbFlowKey key;
+    get_flow_key(&key);
+    ssd->ftracker_tcp = nullptr;
+
+    // iterate over smb sessions for this tcp connection and cleanup its instance from them
+    auto all_session_trackers = ssd->session_trackers.get_all_entry();
+    for ( auto& h : all_session_trackers )
     {
-        ttr = DCE2_Smb2InsertTid(ssd, tid, share_type, str);
+        ssd->session_trackers.Remove(h.second->session_id);  // remove session tracker from this
+                                                             // tcp conn
+        h.second->removeConnTracker(key); // remove tcp connection from session tracker
+        if (!h.second->getConnTrackerSize()) // if no tcp connection present in session tracker,
+                                             // delete session tracker
+        {
+            DCE2_SmbSessionCacheRemove(h.second->session_id);
+        }
     }
-
-    return ttr;
 }
-
index f8fe5a57d958a9a0d736dda81d00e0eaeb4c5d5c..8375f5e796cf93ac15b55e94f18fcc983445b987 100644 (file)
 #include "dce_smb.h"
 #include "dce_smb2.h"
 #include "file_api/file_flows.h"
-#include "hash/lru_cache_shared.h"
+#include "sfip/sf_ip.h"
 
-#define SMB2_SID_HASH(sid) std::hash<uint64_t>()(sid)
+struct Smb2SidHashKey
+{
+    snort::SfIp cip; // client ip
+    snort::SfIp sip; // server ip
+    uint64_t sid;
+    bool operator==(const Smb2SidHashKey& other) const
+    {
+        return( sid == other.sid and
+               cip == other.cip and
+               sip == other.sip );
+    }
+};
+
+struct Smb2SidHash
+{
+    size_t operator()(const Smb2SidHashKey& key) const
+    {
+        const uint32_t* cip64 = key.cip.get_ip6_ptr();
+        const uint32_t* sip64 = key.cip.get_ip6_ptr();
+        const uint32_t sid_lo = key.sid & 0xFFFFFFFF;
+        const uint32_t sid_hi = key.sid >> 32;
+        uint32_t a, b, c;
+        a = b = c = 133824503;
+        a += cip64[0]; b += cip64[1]; c += cip64[2];
+        mix(a, b, c);
+        a += cip64[3]; b += sip64[0]; c += sip64[2];
+        mix(a, b, c);
+        a += sip64[3]; b += sid_lo; c += sid_hi;
+        finalize(a, b, c);
+        return c;
+    }
+
+private:
+    inline uint32_t rot(uint32_t x, unsigned k) const
+    { return (x << k) | (x >> (32 - k)); }
+
+    inline void mix(uint32_t& a, uint32_t& b, uint32_t& c) const
+    {
+        a -= c; a ^= rot(c, 4); c += b;
+        b -= a; b ^= rot(a, 6); a += c;
+        c -= b; c ^= rot(b, 8); b += a;
+        a -= c; a ^= rot(c,16); c += b;
+        b -= a; b ^= rot(a,19); a += c;
+        c -= b; c ^= rot(b, 4); b += a;
+    }
+
+    inline void finalize(uint32_t& a, uint32_t& b, uint32_t& c) const
+    {
+        c ^= b; c -= rot(b,14);
+        a ^= c; a -= rot(c,11);
+        b ^= a; b -= rot(a,25);
+        c ^= b; c -= rot(b,16);
+        a ^= c; a -= rot(c,4);
+        b ^= a; b -= rot(a,14);
+        c ^= b; c -= rot(b,24);
+    }
+};
+
+Smb2SidHashKey get_key(uint64_t sid);
 
 template<typename Key, typename Value, typename Hash>
 class SmbSessionCache_map : public LruCacheShared<Key, Value, Hash>
@@ -36,11 +94,13 @@ public:
     SmbSessionCache_map() = delete;
     SmbSessionCache_map(const SmbSessionCache_map& arg) = delete;
     SmbSessionCache_map& operator=(const SmbSessionCache_map& arg) = delete;
-    SmbSessionCache_map(const size_t initial_size) : LruCacheShared<Key, Value, Hash>(initial_size) {}
-    virtual ~SmbSessionCache_map() {}
+    SmbSessionCache_map(const size_t initial_size) : LruCacheShared<Key, Value, Hash>(initial_size)
+    {
+    }
+    virtual ~SmbSessionCache_map() { }
 };
 
-typedef SmbSessionCache_map<uint64_t, DCE2_Smb2SessionTracker, std::hash<uint64_t> > SmbSessionCache;
+typedef SmbSessionCache_map<Smb2SidHashKey, DCE2_Smb2SessionTracker, Smb2SidHash> SmbSessionCache;
 
 extern THREAD_LOCAL SmbSessionCache* smb2_session_cache;
 extern size_t session_cache_size;
@@ -53,18 +113,18 @@ inline void DCE2_SmbSessionCacheInit(const size_t cache_size)
 
 inline DCE2_Smb2SessionTracker* DCE2_SmbSessionCacheFind(uint64_t sid)
 {
-    return (smb2_session_cache->find(SMB2_SID_HASH(sid))).get();
+    return (smb2_session_cache->find(get_key(sid))).get();
 }
 
 inline DCE2_Smb2SessionTracker* DCE2_SmbSessionCacheFindElseCreate(uint64_t sid,
     bool* entry_created)
 {
-    return (smb2_session_cache->find_else_create(SMB2_SID_HASH(sid), entry_created)).get();
+    return (smb2_session_cache->find_else_create(get_key(sid), entry_created)).get();
 }
 
 inline bool DCE2_SmbSessionCacheRemove(uint64_t sid)
 {
-    return smb2_session_cache->remove(SMB2_SID_HASH(sid));
+    return smb2_session_cache->remove(get_key(sid));
 }
 
 // SMB2 functions for fetching sid, tid, request type and so on.
@@ -97,6 +157,11 @@ inline DCE2_Smb2SessionTracker* DCE2_Smb2FindSidInSsd(DCE2_Smb2SsnData* ssd, con
 inline void DCE2_Smb2InsertSidInSsd(DCE2_Smb2SsnData* ssd, const uint64_t sid,
     DCE2_Smb2SessionTracker* stracker)
 {
+    // add ssd in session tracker's tcp trackers database
+    SmbFlowKey key;
+    get_flow_key(&key);
+    stracker->insertConnTracker(key, ssd);
+
     ssd->session_trackers.Insert(sid, stracker);
 }
 
@@ -108,12 +173,11 @@ inline void DCE2_Smb2RemoveSidInSsd(DCE2_Smb2SsnData* ssd, const uint64_t sid)
 DCE2_Smb2TreeTracker* DCE2_Smb2InsertTid(DCE2_Smb2SsnData*, const uint32_t tid, uint8_t share_type,
     DCE2_Smb2SessionTracker*);
 
-DCE2_Smb2TreeTracker* DCE2_Smb2FindElseCreateTid(DCE2_Smb2SsnData*, const uint32_t tid,
-    uint8_t share_type, DCE2_Smb2SessionTracker*);
-
 DCE2_Smb2SessionTracker* DCE2_Smb2FindElseCreateSid(DCE2_Smb2SsnData*, const uint64_t sid);
 
 DCE2_Ret DCE2_Smb2InitData(DCE2_Smb2SsnData*);
 
+void DCE2_Smb2RemoveAllSession(DCE2_Smb2SsnData* ssd);
+
 #endif
 
index e863032db3517f2de76a40b6aebedd525533bb9a..98779b20be71fe88ddd0ce9c839128610de195f9 100644 (file)
@@ -114,24 +114,28 @@ inline DCE2_List* DCE2_ScSmbInvalidShares(const dce2SmbProtoConf* sc)
     return sc->smb_invalid_shares;
 }
 
+#define SMB_DEFAULT_MAX_CREDIT        8192
+#define SMB_DEFAULT_MEMCAP            8388608
+#define SMB_DEFAULT_MAX_COMPOUND_REQ  3
+
 inline uint16_t DCE2_ScSmbMaxCredit(const dce2SmbProtoConf* sc)
 {
     if (sc == nullptr)
-        return 8192;
+        return SMB_DEFAULT_MAX_CREDIT;
     return sc->smb_max_credit;
 }
 
 inline size_t DCE2_ScSmbMemcap(const dce2SmbProtoConf* sc)
 {
     if (sc == nullptr)
-        return 8388608;
+        return SMB_DEFAULT_MEMCAP;
     return sc->memcap;
 }
 
 inline uint16_t DCE2_ScSmbMaxCompound(const dce2SmbProtoConf* sc)
 {
     if (sc == nullptr)
-        return 3;
+        return SMB_DEFAULT_MAX_COMPOUND_REQ;
     return sc->smb_max_compound;
 }
 
index e2767699f217e5f3a582a5abf221bcec7a48ed0a..0ff63b079b0665262ca9a46ed45ff9cfad3168b2 100644 (file)
@@ -32,6 +32,7 @@
 #include "detection/detect.h"
 #include "file_api/file_service.h"
 #include "main/snort_debug.h"
+#include "memory/memory_cap.h"
 #include "packet_io/active.h"
 #include "protocols/packet.h"
 #include "utils/util.h"
@@ -829,7 +830,8 @@ static void DCE2_SmbProcessCommand(DCE2_SmbSsnData* ssd, const SmbNtHdr* smb_hdr
         if (smb_com2 == SMB_COM_NO_ANDX_COMMAND)
             break;
 
-        debug_logf(dce_smb_trace, nullptr, "Chained SMB command: %s\n", get_smb_com_string(smb_com2));
+        debug_logf(dce_smb_trace, nullptr, "Chained SMB command: %s\n", get_smb_com_string(
+            smb_com2));
 
         num_chained++;
         if (DCE2_ScSmbMaxChain((dce2SmbProtoConf*)ssd->sd.config) &&
@@ -1030,7 +1032,8 @@ static DCE2_SmbRequestTracker* DCE2_SmbInspect(DCE2_SmbSsnData* ssd, const SmbNt
 {
     int smb_com = SmbCom(smb_hdr);
 
-    debug_logf(dce_smb_trace, nullptr, "SMB command: %s (0x%02X)\n", get_smb_com_string(smb_com), smb_com);
+    debug_logf(dce_smb_trace, nullptr, "SMB command: %s (0x%02X)\n", get_smb_com_string(smb_com),
+        smb_com);
 
     if (smb_com_funcs[smb_com] == nullptr)
     {
@@ -1297,6 +1300,7 @@ Dce2SmbFlowData::Dce2SmbFlowData() : FlowData(inspector_id)
         dce2_smb_stats.max_concurrent_sessions = dce2_smb_stats.concurrent_sessions;
     smb_version = DCE2_SMB_VERSION_NULL;
     dce2_smb_session_data = nullptr;
+    memory::MemoryCap::update_allocations(sizeof(*this));
 }
 
 Dce2SmbFlowData::~Dce2SmbFlowData()
@@ -1308,10 +1312,13 @@ Dce2SmbFlowData::~Dce2SmbFlowData()
     }
     else
     {
+        DCE2_Smb2RemoveAllSession((DCE2_Smb2SsnData*)dce2_smb_session_data);
         delete (DCE2_Smb2SsnData*)dce2_smb_session_data;
+        memory::MemoryCap::update_deallocations(sizeof(*(DCE2_Smb2SsnData*)dce2_smb_session_data));
     }
     assert(dce2_smb_stats.concurrent_sessions > 0);
     dce2_smb_stats.concurrent_sessions--;
+    memory::MemoryCap::update_deallocations(sizeof(*this));
 }
 
 unsigned Dce2SmbFlowData::inspector_id = 0;
@@ -1615,11 +1622,12 @@ void DCE2_Smb1Process(DCE2_SmbSsnData* ssd)
                 {
                     // Upgrade connection to SMBv2
                     dce2SmbProtoConf* config = (dce2SmbProtoConf*)ssd->sd.config;
-                    Dce2SmbFlowData* fd = (Dce2SmbFlowData*)p->flow->get_flow_data(Dce2SmbFlowData::inspector_id);
+                    Dce2SmbFlowData* fd = (Dce2SmbFlowData*)p->flow->get_flow_data(
+                        Dce2SmbFlowData::inspector_id);
                     p->flow->free_flow_data(fd);
                     DCE2_Smb2SsnData* dce2_smb2_sess = dce2_create_new_smb2_session(p, config);
                     DCE2_Smb2Process(dce2_smb2_sess);
-                    if(!dce2_detected)
+                    if (!dce2_detected)
                         DCE2_Detect(&dce2_smb2_sess->sd);
                 }
                 else
@@ -2552,6 +2560,7 @@ static inline DCE2_Smb2SsnData* set_new_dce2_smb2_session(Packet* p)
     Dce2SmbFlowData* fd = new Dce2SmbFlowData;
     fd->smb_version = DCE2_SMB_VERSION_2;
     fd->dce2_smb_session_data = new DCE2_Smb2SsnData();
+    memory::MemoryCap::update_allocations(sizeof(*(DCE2_Smb2SsnData*)(fd->dce2_smb_session_data)));
     DCE2_Smb2InitData((DCE2_Smb2SsnData*)fd->dce2_smb_session_data);
     p->flow->set_flow_data(fd);
     return((DCE2_Smb2SsnData*)fd->dce2_smb_session_data);
@@ -2581,12 +2590,12 @@ DCE2_Smb2SsnData* dce2_create_new_smb2_session(Packet* p, dce2SmbProtoConf* conf
 
 DCE2_SsnData* get_dce2_session_data(snort::Flow* flow)
 {
-   Dce2SmbFlowData* fd = (Dce2SmbFlowData*)flow->get_flow_data(Dce2SmbFlowData::inspector_id);
-   if (fd and fd->dce2_smb_session_data)
-       return (fd->smb_version == DCE2_SMB_VERSION_1) ?
-           (DCE2_SsnData*)(&((DCE2_SmbSsnData*)fd->dce2_smb_session_data)->sd) :
-           (DCE2_SsnData*)(&((DCE2_Smb2SsnData*)fd->dce2_smb_session_data)->sd);
-   else
-       return nullptr;
+    Dce2SmbFlowData* fd = (Dce2SmbFlowData*)flow->get_flow_data(Dce2SmbFlowData::inspector_id);
+    if (fd and fd->dce2_smb_session_data)
+        return (fd->smb_version == DCE2_SMB_VERSION_1) ?
+               (DCE2_SsnData*)(&((DCE2_SmbSsnData*)fd->dce2_smb_session_data)->sd) :
+               (DCE2_SsnData*)(&((DCE2_Smb2SsnData*)fd->dce2_smb_session_data)->sd);
+    else
+        return nullptr;
 }