]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2876 in SNORT/snort3 from ~DIPANDIT/snort3:multichannel to master
authorBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Fri, 4 Jun 2021 21:44:56 +0000 (21:44 +0000)
committerBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Fri, 4 Jun 2021 21:44:56 +0000 (21:44 +0000)
Squashed commit of the following:

commit 3efdf0d7995c31a869edcfc34f1a57bf42cbed0a
Author: Dipto Pandit (dipandit) <dipandit@cisco.com>
Date:   Sun May 30 12:00:55 2021 -0400

    dce_rpc: SMB multichannel - handle negotiate command to create expected flow

commit 5bb575762f0ea11171a167deb59e199177576ae9
Author: Dipto Pandit (dipandit) <dipandit@cisco.com>
Date:   Tue Mar 23 03:31:05 2021 -0400

    dce_rpc: SMB multichannel - own memory tracking in global cache

commit 82b15dd65de7c4d44b36622c3ffd15f3199b877e
Author: Dipto Pandit (dipandit) <dipandit@cisco.com>
Date:   Wed May 5 10:11:55 2021 -0400

    dce_rpc: SMB multichannel - added smb multichannel file support

commit bc61306af569d4dd0b4d865db70597d590760efd
Author: Dipto Pandit (dipandit) <dipandit@cisco.com>
Date:   Wed Feb 17 07:25:36 2021 -0500

    dce_rpc: SMB multichannel - make session cache global

commit c24c372ee61232a27cff1e0a5d92734b96aa106d
Author: Dipto Pandit <dipandit@cisco.com>
Date:   Thu Oct 8 06:55:59 2020 -0400

    dce_rpc: SMB multichannel - introduce locks

14 files changed:
src/service_inspectors/dce_rpc/dce_smb2.cc
src/service_inspectors/dce_rpc/dce_smb2.h
src/service_inspectors/dce_rpc/dce_smb2_file.cc
src/service_inspectors/dce_rpc/dce_smb2_file.h
src/service_inspectors/dce_rpc/dce_smb2_request.h
src/service_inspectors/dce_rpc/dce_smb2_session.cc
src/service_inspectors/dce_rpc/dce_smb2_session.h
src/service_inspectors/dce_rpc/dce_smb2_session_cache.h
src/service_inspectors/dce_rpc/dce_smb2_tree.cc
src/service_inspectors/dce_rpc/dce_smb2_tree.h
src/service_inspectors/dce_rpc/dce_smb_common.cc
src/service_inspectors/dce_rpc/dce_smb_common.h
src/service_inspectors/dce_rpc/dce_smb_inspector.cc
src/service_inspectors/dce_rpc/dce_smb_inspector.h

index ae72bca5bd81f271a0a1ab75bdba7e09b9b7bf70..1fe8b11b20d34eab952c7a009894919e6eb7b720 100644 (file)
@@ -27,6 +27,7 @@
 
 #include "detection/detection_util.h"
 #include "flow/flow_key.h"
+#include "stream/stream.h"
 
 #include "dce_smb2_file.h"
 #include "dce_smb2_session.h"
@@ -35,7 +36,7 @@
 
 using namespace snort;
 
-THREAD_LOCAL Dce2Smb2SessionCache* smb2_session_cache;
+Dce2Smb2SessionCache smb2_session_cache(SMB_DEFAULT_MEMCAP);
 
 const char* smb2_command_string[SMB2_COM_MAX] = {
     "SMB2_COM_NEGOTIATE",
@@ -69,10 +70,9 @@ static inline bool Smb2Error(const Smb2Hdr* hdr)
     return (SMB_NT_STATUS_SEVERITY__ERROR == (uint8_t)(hdr->status >> 30));
 }
 
-Smb2FlowKey get_smb2_flow_key(void)
+uint32_t get_smb2_flow_key(const FlowKey* flow_key)
 {
     Smb2FlowKey key;
-    const FlowKey* flow_key = DetectionEngine::get_current_packet()->flow->key;
 
     key.ip_l[0] = flow_key->ip_l[0];
     key.ip_l[1] = flow_key->ip_l[1];
@@ -94,7 +94,8 @@ Smb2FlowKey get_smb2_flow_key(void)
     key.version = flow_key->version;
     key.padding = 0;
 
-    return key;
+    Smb2KeyHash hasher;
+    return hasher(key);
 }
 
 //Dce2Smb2SessionData member functions
@@ -103,24 +104,26 @@ Dce2Smb2SessionData::Dce2Smb2SessionData(const Packet* p,
     const dce2SmbProtoConf* proto) : Dce2SmbSessionData(p, proto)
 {
     tcp_file_tracker = nullptr;
-    flow_key = get_smb2_flow_key();
+    flow_key = get_smb2_flow_key(tcp_flow->key);
     debug_logf(dce_smb_trace, p, "smb2 session created\n");
     memory::MemoryCap::update_allocations(sizeof(*this));
 }
 
 Dce2Smb2SessionData::~Dce2Smb2SessionData(void)
 {
+    session_data_mutex.lock();
     for (auto it_session : connected_sessions)
     {
-        if (it_session.second->detach_flow(flow_key))
-            smb2_session_cache->remove(it_session.second->get_key());
+        it_session.second->detach_flow(flow_key);
     }
+    session_data_mutex.unlock();
     memory::MemoryCap::update_deallocations(sizeof(*this));
 }
 
 void Dce2Smb2SessionData::reset_matching_tcp_file_tracker(
     Dce2Smb2FileTracker* file_tracker)
 {
+    std::lock_guard<std::mutex> guard(tcp_file_tracker_mutex);
     if (tcp_file_tracker == file_tracker)
         tcp_file_tracker = nullptr;
 }
@@ -141,23 +144,22 @@ Smb2SessionKey Dce2Smb2SessionData::get_session_key(uint64_t session_id)
 
 Dce2Smb2SessionTracker* Dce2Smb2SessionData::find_session(uint64_t session_id)
 {
+    std::lock_guard<std::mutex> guard(session_data_mutex);
     auto it_session = connected_sessions.find(session_id);
+
     if (it_session != connected_sessions.end())
     {
         Dce2Smb2SessionTracker* session = it_session->second;
         //we already have the session, but call find to update the LRU
-        smb2_session_cache->find_session(session->get_key());
+        smb2_session_cache.find_session(session->get_key(), this);
         return session;
     }
     else
     {
-        Dce2Smb2SessionTracker* session = smb2_session_cache->find_session(
-            get_session_key(session_id));
+        Dce2Smb2SessionTracker* session = smb2_session_cache.find_session(
+            get_session_key(session_id), this);
         if (session)
-        {
-            session->attach_flow(flow_key, this);
             connected_sessions.insert(std::make_pair(session_id,session));
-        }
         return session;
     }
 }
@@ -166,9 +168,8 @@ Dce2Smb2SessionTracker* Dce2Smb2SessionData::find_session(uint64_t session_id)
 Dce2Smb2SessionTracker* Dce2Smb2SessionData::create_session(uint64_t session_id)
 {
     Smb2SessionKey session_key = get_session_key(session_id);
-    Dce2Smb2SessionTracker* session = smb2_session_cache->find_else_create_session(session_key);
-    session->init(session_id, session_key);
-    session->attach_flow(flow_key, this);
+    std::lock_guard<std::mutex> guard(session_data_mutex);
+    Dce2Smb2SessionTracker* session = smb2_session_cache.find_else_create_session(session_key, this);
     connected_sessions.insert(std::make_pair(session_id, session));
     return session;
 }
@@ -176,7 +177,6 @@ Dce2Smb2SessionTracker* Dce2Smb2SessionData::create_session(uint64_t session_id)
 void Dce2Smb2SessionData::remove_session(uint64_t session_id)
 {
     connected_sessions.erase(session_id);
-    smb2_session_cache->remove(get_session_key(session_id));
 }
 
 void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
@@ -230,9 +230,9 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
     uint16_t command = alignedNtohs(&(smb_hdr->command));
     uint64_t session_id = Smb2Sid(smb_hdr);
     debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
-        "%s : mid %" PRIu64 " sid %" PRIu64 " tid %" PRIu32 "\n",
+        "%s : flow %" PRIu32 " mid %" PRIu64 " sid %" PRIu64 " tid %" PRIu32 "\n",
         (command < SMB2_COM_MAX ? smb2_command_string[command] : "unknown"),
-        Smb2Mid(smb_hdr), session_id, Smb2Tid(smb_hdr));
+        flow_key, Smb2Mid(smb_hdr), session_id, Smb2Tid(smb_hdr));
     // Try to find the session.
     // The case when session is not available will be handled per command.
     Dce2Smb2SessionTracker* session = find_session(session_id);
@@ -240,6 +240,27 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
     switch (command)
     {
     //commands processed by flow
+    case SMB2_COM_NEGOTIATE:
+        if (SMB2_COMMAND_TYPE(NEGOTIATE, RESPONSE))
+        {
+            const Smb2NegotiateResponseHdr* neg_resp_hdr = (const Smb2NegotiateResponseHdr*)smb_data;
+            if (neg_resp_hdr->capabilities & SMB2_GLOBAL_CAP_MULTI_CHANNEL)
+            {
+                Packet *p = DetectionEngine::get_current_packet();
+                Dce2SmbFlowData* fd =
+                    create_expected_smb_flow_data(p, (dce2SmbProtoConf *)sd.config);
+                if (fd)
+                {
+                    int result = Stream::set_snort_protocol_id_expected(p, PktType::TCP,
+                        IpProtocol::TCP, p->ptrs.ip_api.get_dst() , 0 ,p->ptrs.ip_api.get_src(),
+                        p->flow->server_port , snort_protocol_id_smb, fd, false, true);
+                
+                    if (result < 0)
+                        delete fd;
+                }
+            }
+        }
+        break;
     case SMB2_COM_SESSION_SETUP:
         dce2_smb_stats.v2_setup++;
         SMB2_HANDLE_ERROR_RESPONSE(setup)
@@ -256,21 +277,24 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
     case SMB2_COM_LOGOFF:
         dce2_smb_stats.v2_logoff++;
         if (SMB2_COMMAND_TYPE(LOGOFF, REQUEST))
-            remove_session(session_id);
+        {
+            session_data_mutex.lock();
+            smb2_session_cache.remove(get_session_key(session_id));
+            session_data_mutex.unlock();
+        }
         else
             SMB2_HANDLE_INVALID_STRUC_SIZE(logoff)
         break;
     //commands processed by session
     case SMB2_COM_TREE_CONNECT:
         dce2_smb_stats.v2_tree_cnct++;
-
         SMB2_HANDLE_ERROR_RESPONSE(tree_cnct)
         if (SMB2_COMMAND_TYPE(TREE_CONNECT, RESPONSE))
         {
             SMB2_HANDLE_HEADER_ERROR(TREE_CONNECT, RESPONSE, tree_cnct_resp)
             if (!session)
                 session = create_session(session_id);
-            session->process(command, SMB2_CMD_TYPE_RESPONSE, smb_hdr, end);
+            session->process(command, SMB2_CMD_TYPE_RESPONSE, smb_hdr, end, flow_key);
         }
         else if (!SMB2_COMMAND_TYPE(TREE_CONNECT,REQUEST))
             SMB2_HANDLE_INVALID_STRUC_SIZE(tree_cnct)
@@ -284,7 +308,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             if (SMB2_COMMAND_TYPE(TREE_DISCONNECT, REQUEST))
             {
                 SMB2_HANDLE_HEADER_ERROR(TREE_DISCONNECT, REQUEST, tree_discn_req)
-                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end);
+                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end, flow_key);
             }
             else
             {
@@ -298,7 +322,6 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
     case SMB2_COM_CREATE:
     {
         dce2_smb_stats.v2_crt++;
-
         uint8_t command_type = SMB2_GET_COMMAND_TYPE(CREATE);
         if (SMB2_CMD_TYPE_INVALID == command_type)
             SMB2_HANDLE_INVALID_STRUC_SIZE(crt)
@@ -309,7 +332,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
 
         if (!session)
             session = create_session(session_id);
-        session->process(command, command_type, smb_hdr, end);
+        session->process(command, command_type, smb_hdr, end, flow_key);
     }
         break;
 
@@ -321,7 +344,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             if (SMB2_COMMAND_TYPE(CLOSE, REQUEST))
             {
                 SMB2_HANDLE_HEADER_ERROR(CLOSE, REQUEST, cls_req)
-                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end);
+                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end, flow_key);
             }
             else if (!SMB2_COMMAND_TYPE(CLOSE, RESPONSE))
             {
@@ -340,7 +363,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             if (SMB2_COMMAND_TYPE(SET_INFO, REQUEST))
             {
                 SMB2_HANDLE_HEADER_ERROR(SET_INFO, REQUEST, stinf_req)
-                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end);
+                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end, flow_key);
             }
             else if (!SMB2_COMMAND_TYPE(SET_INFO, RESPONSE))
             {
@@ -370,7 +393,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             }
             else
                 SMB2_HANDLE_INVALID_STRUC_SIZE(read)
-            session->process(command, command_type, smb_hdr, end);
+            session->process(command, command_type, smb_hdr, end, flow_key);
         }
         else
             dce2_smb_stats.v2_session_ignored++;
@@ -395,7 +418,7 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             }
             else
                 SMB2_HANDLE_INVALID_STRUC_SIZE(wrt)
-            session->process(command, command_type, smb_hdr, end);
+            session->process(command, command_type, smb_hdr, end, flow_key);
         }
         else
             dce2_smb_stats.v2_session_ignored++;
@@ -406,12 +429,12 @@ void Dce2Smb2SessionData::process_command(const Smb2Hdr* smb_hdr,
             if (SMB2_COMMAND_TYPE(IOCTL, REQUEST))
             {
                 SMB2_HANDLE_HEADER_ERROR(IOCTL, REQUEST, ioctl_req)
-                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end);
+                session->process(command, SMB2_CMD_TYPE_REQUEST, smb_hdr, end, flow_key);
             }
             else if ( SMB2_COMMAND_TYPE(IOCTL, RESPONSE))
             {
                 SMB2_HANDLE_HEADER_ERROR(IOCTL, RESPONSE, ioctl_resp)
-                session->process(command, SMB2_CMD_TYPE_RESPONSE, smb_hdr, end);
+                session->process(command, SMB2_CMD_TYPE_RESPONSE, smb_hdr, end, flow_key);
             }
             else
                 SMB2_HANDLE_INVALID_STRUC_SIZE(ioctl)
@@ -455,7 +478,6 @@ void Dce2Smb2SessionData::process()
         do
         {
             process_command(smb_hdr, data_ptr +  data_len);
-
             // In case of message compounding, find the offset of the next smb command
             next_command_offset = alignedNtohl(&(smb_hdr->next_command));
             if (next_command_offset + (const uint8_t*)smb_hdr > (data_ptr + data_len))
@@ -482,13 +504,16 @@ void Dce2Smb2SessionData::process()
         }
         while (next_command_offset and smb_hdr);
     }
-    else if ( tcp_file_tracker and tcp_file_tracker->accepting_raw_data())
+    else
     {
-        debug_logf(dce_smb_trace, p, "processing raw data for file id %" PRIu64 "\n",
-            tcp_file_tracker->get_file_id());
-
-        if (!tcp_file_tracker->process_data(data_ptr, data_len))
-            tcp_file_tracker->get_parent()->close_file(tcp_file_tracker->get_file_id());
+        tcp_file_tracker_mutex.lock();
+        if ( tcp_file_tracker and tcp_file_tracker->accepting_raw_data())
+        {
+            debug_logf(dce_smb_trace, p, "processing raw data for file id %" PRIu64 "\n",
+                tcp_file_tracker->get_file_id());
+            tcp_file_tracker->process_data(flow_key, data_ptr, data_len);
+        }
+        tcp_file_tracker_mutex.unlock();
     }
 }
 
@@ -497,7 +522,10 @@ void Dce2Smb2SessionData::set_reassembled_data(uint8_t* nb_ptr, uint16_t co_len)
     NbssHdr* nb_hdr = (NbssHdr*)nb_ptr;
     SmbNtHdr* smb_hdr = (SmbNtHdr*)((uint8_t*)nb_hdr + sizeof(NbssHdr));
 
+    tcp_file_tracker_mutex.lock();
     uint32_t tid = (tcp_file_tracker) ? tcp_file_tracker->get_parent()->get_tree_id() : 0;
+    tcp_file_tracker_mutex.unlock();
+
     smb_hdr->smb_tid = alignedNtohl((const uint32_t*)&tid);
 
     if (DetectionEngine::get_current_packet()->is_from_client())
@@ -509,6 +537,8 @@ void Dce2Smb2SessionData::set_reassembled_data(uint8_t* nb_ptr, uint16_t co_len)
             nb_len = UINT16_MAX;
         write->structure_size = SMB2_WRITE_REQUEST_STRUC_SIZE;
         nb_hdr->length = htons((uint16_t)nb_len);
+
+        tcp_file_tracker_mutex.lock();
         if (tcp_file_tracker)
         {
             uint64_t fid = tcp_file_tracker->get_file_id();
@@ -517,6 +547,7 @@ void Dce2Smb2SessionData::set_reassembled_data(uint8_t* nb_ptr, uint16_t co_len)
         }
         else
             write->fileId_persistent = write->fileId_volatile = 0;
+        tcp_file_tracker_mutex.unlock();
 
         write->length = alignedNtohs(&co_len);
     }
@@ -533,4 +564,3 @@ void Dce2Smb2SessionData::set_reassembled_data(uint8_t* nb_ptr, uint16_t co_len)
         read->length = alignedNtohs(&co_len);
     }
 }
-
index 32c94093444689e41382085018fddde11e3e0642..1d04cee84db64ec3c5b0cbee4a90582e39ff712b 100644 (file)
@@ -28,6 +28,7 @@
 #include "main/thread_config.h"
 #include "memory/memory_cap.h"
 #include "utils/util.h"
+#include <mutex>
 
 #include "dce_smb_common.h"
 
@@ -101,6 +102,23 @@ struct Smb2SyncHdr
     uint8_t signature[16];    /* signature of the message */
 };
 
+struct Smb2NegotiateResponseHdr
+{
+    uint16_t structure_size;
+    uint16_t security_mode;
+    uint16_t dialect_revision;
+    uint16_t negotiate_context_count;
+    uint64_t servier_guid[2];
+    uint32_t capabilities;
+    uint32_t max_transaction_size;
+    uint32_t max_read_size;
+    uint32_t max_write_size;
+    uint64_t system_time;
+    uint64_t server_start_time;
+    uint16_t security_buffer_offset;
+    uint16_t security_buffer_length;
+};
+
 struct Smb2WriteRequestHdr
 {
     uint16_t structure_size;  /* This MUST be set to 49 */
@@ -237,6 +255,9 @@ struct Smb2TreeConnectResponseHdr
 
 #define SMB2_ERROR_RESPONSE_STRUC_SIZE 9
 
+#define SMB2_NEGOTIATE_REQUEST_STRUC_SIZE 36
+#define SMB2_NEGOTIATE_RESPONSE_STRUC_SIZE 65
+
 #define SMB2_CREATE_REQUEST_STRUC_SIZE 57
 #define SMB2_CREATE_RESPONSE_STRUC_SIZE 89
 
@@ -257,8 +278,6 @@ struct Smb2TreeConnectResponseHdr
 #define SMB2_TREE_DISCONNECT_REQUEST_STRUC_SIZE 4
 #define SMB2_TREE_DISCONNECT_RESPONSE_STRUC_SIZE 4
 
-#define SMB2_FILE_ENDOFFILE_INFO 0x14
-
 #define SMB2_SETUP_REQUEST_STRUC_SIZE 25
 #define SMB2_SETUP_RESPONSE_STRUC_SIZE 9
 
@@ -268,6 +287,9 @@ struct Smb2TreeConnectResponseHdr
 #define SMB2_IOCTL_REQUEST_STRUC_SIZE 57
 #define SMB2_IOCTL_RESPONSE_STRUC_SIZE 49
 
+#define SMB2_FILE_ENDOFFILE_INFO 0x14
+#define SMB2_GLOBAL_CAP_MULTI_CHANNEL 0x08
+
 #define GET_CURRENT_PACKET snort::DetectionEngine::get_current_packet()
 
 class Dce2Smb2FileTracker;
@@ -319,27 +341,18 @@ struct Smb2FlowKey
     uint8_t pkt_type;
     uint8_t version;
     uint8_t padding;
+};
 
-    bool operator==(const Smb2FlowKey& other) const
+struct Smb2MessageKey
+{
+    uint64_t mid;
+    uint32_t flow_key;
+    uint32_t padding;
+
+    bool operator==(const Smb2MessageKey& other) const
     {
-        return (ip_l[0] == other.ip_l[0] and
-               ip_l[1] == other.ip_l[1] and
-               ip_l[2] == other.ip_l[2] and
-               ip_l[3] == other.ip_l[3] and
-               ip_h[0] == other.ip_h[0] and
-               ip_l[1] == other.ip_l[1] and
-               ip_l[2] == other.ip_l[2] and
-               ip_l[3] == other.ip_l[3] and
-               mplsLabel == other.mplsLabel and
-               port_l == other.port_l and
-               port_h == other.port_h and
-               group_l == other.group_l and
-               group_h == other.group_h and
-               vlan_tag == other.vlan_tag and
-               addressSpaceId == other.addressSpaceId and
-               ip_protocol == other.ip_protocol and
-               pkt_type == other.pkt_type and
-               version == other.version);
+        return (mid == other.mid and
+               flow_key == other.flow_key);
     }
 };
 PADDING_GUARD_END
@@ -359,6 +372,11 @@ struct Smb2KeyHash
         return do_hash_session_key((const uint32_t*)&key);
     }
 
+    size_t operator()(const Smb2MessageKey& key) const
+    {
+        return do_hash_message_key((const uint32_t*)&key);
+    }
+
 private:
     size_t do_hash_flow_key(const uint32_t* d) const
     {
@@ -383,6 +401,15 @@ private:
         return c;
     }
 
+    size_t do_hash_message_key(const uint32_t* d) const
+    {
+        uint32_t a, b, c;
+        a = b = c = SMB_KEY_HASH_HARDENER;
+        a += d[0]; b += d[1]; c += d[2]; mix(a, b, c);
+        finalize(a, b, c);
+        return c;
+    }
+
     inline uint32_t rot(uint32_t x, unsigned k) const
     { return (x << k) | (x >> (32 - k)); }
 
@@ -408,7 +435,7 @@ private:
     }
 };
 
-Smb2FlowKey get_smb2_flow_key(void);
+uint32_t get_smb2_flow_key(const snort::FlowKey*);
 
 class Dce2Smb2SessionData : public Dce2SmbSessionData
 {
@@ -420,9 +447,13 @@ public:
     void remove_session(uint64_t);
     void handle_retransmit(FilePosition, FileVerdict) override { }
     void reset_matching_tcp_file_tracker(Dce2Smb2FileTracker*);
-    void set_tcp_file_tracker(Dce2Smb2FileTracker* file_tracker)
-    { tcp_file_tracker = file_tracker; }
     void set_reassembled_data(uint8_t*, uint16_t) override;
+    uint32_t get_flow_key() { return flow_key; }
+    void set_tcp_file_tracker(Dce2Smb2FileTracker* file_tracker)
+    {
+        std::lock_guard<std::mutex> guard(session_data_mutex);
+        tcp_file_tracker = file_tracker;
+    }
 
 private:
     void process_command(const Smb2Hdr*, const uint8_t*);
@@ -430,13 +461,15 @@ private:
     Dce2Smb2SessionTracker* create_session(uint64_t);
     Dce2Smb2SessionTracker* find_session(uint64_t);
 
-    Smb2FlowKey flow_key;
+    uint32_t flow_key;
     Dce2Smb2FileTracker* tcp_file_tracker;
     Dce2Smb2SessionTrackerMap connected_sessions;
+    std::mutex session_data_mutex;
+    std::mutex tcp_file_tracker_mutex;
 };
 
 using Dce2Smb2SessionDataMap =
-    std::unordered_map<Smb2FlowKey, Dce2Smb2SessionData*, Smb2KeyHash>;
+    std::unordered_map<uint32_t, Dce2Smb2SessionData*, std::hash<uint32_t> >;
 
 #endif  /* _DCE_SMB2_H_ */
 
index 106d94ceae75fbf87b5df6d65a94f15db0bbb4c5..d74ca11ae2afe48aa586083c805b0ee2c118878a 100644 (file)
@@ -50,57 +50,83 @@ inline void Dce2Smb2FileTracker::file_detect()
     dce2_detected = 1;
 }
 
-void Dce2Smb2FileTracker::set_info(char* file_name_v, uint16_t name_len_v,
-    uint64_t size_v, bool create)
+std::pair<bool, Dce2Smb2SessionData*> Dce2Smb2FileTracker::update_processing_flow(
+    Dce2Smb2SessionData* current_flow)
 {
-    if (file_name_v and name_len_v)
+    std::lock_guard<std::mutex> guard(process_file_mutex);
+    bool switched = false;
+    Dce2Smb2SessionData* processing_flow = parent_tree->get_parent()->get_flow(file_flow_key);
+    if (!processing_flow)
     {
-        file_name = file_name_v;
+        switched = true;
+        if (current_flow)
+            processing_flow = current_flow;
+        else
+        {
+            Flow* flow = DetectionEngine::get_current_packet()->flow;
+            Dce2SmbFlowData* current_flow_data = (Dce2SmbFlowData*)(flow->get_flow_data(Dce2SmbFlowData::inspector_id));
+            processing_flow = (Dce2Smb2SessionData*)current_flow_data->get_smb_session_data();
+        }
+        file_flow_key = processing_flow->get_flow_key();
+    }
+    return std::make_pair(switched, processing_flow);
+}
+
+void Dce2Smb2FileTracker::set_info(char* file_name_v, uint16_t name_len_v, uint64_t size_v)
+{
+    if (file_name_v and name_len_v and !file_name)
+    {
+        file_name = (char*)snort_alloc(name_len_v + 1);
+        memcpy(file_name, file_name_v, name_len_v);
         file_name_len = name_len_v;
         file_name_hash = str_to_hash((uint8_t*)file_name, file_name_len);
     }
     file_size = size_v;
-    FileContext* file = get_smb_file_context(file_name_hash, file_id, create);
+    auto updated_flow = update_processing_flow();
+    Flow* flow = updated_flow.second->get_tcp_flow();
+    FileContext* file = get_smb_file_context(flow, file_name_hash, file_id, true);
     debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "set file info: file size %"
         PRIu64 " fid %" PRIu64 " file_name_hash %" PRIu64 " file context "
-        "%sfound\n", file_size, file_id, file_name_hash, (file ? "" : "not "));
+        "%sfound\n", size_v, file_id, file_name_hash, (file ? "" : "not "));
     if (file)
     {
         ignore = false;
         if (file->verdict == FILE_VERDICT_UNKNOWN)
         {
-            if (file_name_v and name_len_v)
+            if ((file_name_v and name_len_v) or updated_flow.first)
                 file->set_file_name(file_name, file_name_len);
-            file->set_file_size(file_size ? file_size : UNKNOWN_FILE_SIZE);
+            file->set_file_size(size_v ? size_v : UNKNOWN_FILE_SIZE);
         }
     }
 }
 
-bool Dce2Smb2FileTracker::close()
+bool Dce2Smb2FileTracker::close(const uint32_t current_flow_key)
 {
+    uint64_t file_offset = file_offsets[current_flow_key];
     if (!ignore and !file_size and file_offset)
     {
         file_size = file_offset;
-        FileContext* file =
-            get_smb_file_context(file_name_hash, file_id, false);
+        Dce2Smb2SessionData* processing_flow = update_processing_flow().second;
+        Flow* flow = processing_flow->get_tcp_flow();
+        FileContext* file = get_smb_file_context(flow, file_name_hash, file_id, false);
         if (file)
             file->set_file_size(file_size);
-        return (!process_data(nullptr, 0));
+        return (!process_data(current_flow_key, nullptr, 0));
     }
     return true;
 }
 
-bool Dce2Smb2FileTracker::process_data(const uint8_t* file_data,
-    uint32_t data_size, uint64_t offset)
+bool Dce2Smb2FileTracker::process_data(const uint32_t current_flow_key, const uint8_t* file_data,
+    uint32_t data_size, const uint64_t offset)
 {
-    file_offset = offset;
-    return process_data(file_data, data_size);
+    file_offsets[current_flow_key] = offset;
+    return process_data(current_flow_key, file_data, data_size);
 }
 
-bool Dce2Smb2FileTracker::process_data(const uint8_t* file_data,
+bool Dce2Smb2FileTracker::process_data(const uint32_t current_flow_key, const uint8_t* file_data,
     uint32_t data_size)
 {
-    Dce2Smb2SessionData* current_flow = parent_tree->get_parent()->get_current_flow();
+    Dce2Smb2SessionData* current_flow = parent_tree->get_parent()->get_flow(current_flow_key);
 
     if (parent_tree->get_share_type() != SMB2_SHARE_TYPE_DISK)
     {
@@ -108,13 +134,14 @@ bool Dce2Smb2FileTracker::process_data(const uint8_t* file_data,
         {
             data_size = UINT16_MAX;
         }
-        DCE2_CoProcess(current_flow->get_dce2_session_data(), get_parent()->get_cotracker(),
+        DCE2_CoProcess(current_flow->get_dce2_session_data(), parent_tree->get_cotracker(),
             file_data, data_size);
         return true;
     }
 
     int64_t file_detection_depth = current_flow->get_smb_file_depth();
     int64_t detection_size = 0;
+    uint64_t file_offset = file_offsets[current_flow_key];
 
     if (file_detection_depth == 0)
         detection_size = data_size;
@@ -145,45 +172,51 @@ bool Dce2Smb2FileTracker::process_data(const uint8_t* file_data,
             &dce2_smb_stats, *(current_flow->get_dce2_session_data()));
     }
 
+    auto updated_flow = update_processing_flow(current_flow);
+    Dce2Smb2SessionData* processing_flow = updated_flow.second;
+
     debug_logf(dce_smb_trace, p, "file_process fid %" PRIu64 " data_size %"
         PRIu32 " offset %" PRIu64 "\n", file_id, data_size, file_offset);
 
-    FileFlows* file_flows = FileFlows::get_file_flows(p->flow);
+    FileFlows* file_flows = FileFlows::get_file_flows(processing_flow->get_tcp_flow());
 
     if (!file_flows)
         return true;
 
-    if (!file_flows->file_process(p, file_name_hash, file_data, data_size,
-        file_offset, direction, file_id))
+    if (updated_flow.first)
+    {
+        // update the new file context in case of flow switch
+        FileContext* file = file_flows->get_file_context(file_name_hash, true, file_id);
+        file->set_file_name(file_name, file_name_len);
+        file->set_file_size(file_size.load() ? file_size.load() : UNKNOWN_FILE_SIZE);
+    }
+
+    process_file_mutex.lock();
+    bool continue_processing = file_flows->file_process(p, file_name_hash, file_data, data_size,
+        file_offset, direction, file_id);
+    process_file_mutex.unlock();
+    if (!continue_processing)
     {
         debug_logf(dce_smb_trace, p, "file_process completed\n");
         return false;
     }
 
     file_offset += data_size;
+    file_offsets[current_flow_key] = file_offset;
     return true;
 }
 
 Dce2Smb2FileTracker::~Dce2Smb2FileTracker(void)
 {
-    debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
-        "file tracker %" PRIu64 " file name hash %" PRIu64 " terminating\n", file_id, file_name_hash);
-
-    if (file_name)
-        snort_free((void*)file_name);
-
-    Dce2Smb2SessionDataMap attached_flows = parent_tree->get_parent()->get_attached_flows();
-
-    for (auto it_flow : attached_flows)
+    if (smb_module_is_up)
     {
-        FileFlows* file_flows = FileFlows::get_file_flows(it_flow.second->get_flow(), false);
-        if (file_flows)
-            file_flows->remove_processed_file_context(file_name_hash, file_id);
-        it_flow.second->reset_matching_tcp_file_tracker(this);
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "file tracker %" PRIu64
+            " file name hash %" PRIu64 " terminating\n", file_id, file_name_hash);
     }
 
-    parent_tree->close_file(file_id, false);
+    if (file_name)
+        snort_free((void*)file_name);
 
-    memory::MemoryCap::update_deallocations(sizeof(*this));
+    parent_tree->get_parent()->clean_file_context_from_flow(this, file_id, file_name_hash);
 }
 
index 0384c7a9bb640d731ed4dece6a395840658836b9..c5aa334d1ea7455bc3b7f937b45c71242816b4a0 100644 (file)
@@ -35,23 +35,21 @@ public:
     Dce2Smb2FileTracker(const Dce2Smb2FileTracker& arg) = delete;
     Dce2Smb2FileTracker& operator=(const Dce2Smb2FileTracker& arg) = delete;
 
-    Dce2Smb2FileTracker(uint64_t file_idv, Dce2Smb2TreeTracker* p_tree) : ignore(true),
-        file_name_len(0), file_offset(0), file_id(file_idv), file_size(0), file_name_hash(0),
-        file_name(nullptr), direction(FILE_DOWNLOAD), smb2_pdu_state(DCE2_SMB_PDU_STATE__COMMAND),
-        parent_tree(p_tree)
+    Dce2Smb2FileTracker(uint64_t file_idv, const uint32_t flow_key, Dce2Smb2TreeTracker* p_tree) :
+        ignore(true), file_name_len(0), file_flow_key(flow_key),
+        file_id(file_idv), file_size(0), file_name_hash(0), file_name(nullptr),
+        direction(FILE_DOWNLOAD), smb2_pdu_state(DCE2_SMB_PDU_STATE__COMMAND), parent_tree(p_tree)
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
             "file tracker %" PRIu64 " created\n", file_id);
-        memory::MemoryCap::update_allocations(sizeof(*this));
     }
 
     ~Dce2Smb2FileTracker();
-    bool process_data(const uint8_t*, uint32_t, uint64_t);
-    bool process_data(const uint8_t*, uint32_t);
-    bool close();
-    void set_info(char*, uint16_t, uint64_t, bool = false);
+    bool process_data(const uint32_t, const uint8_t*, uint32_t, const uint64_t);
+    bool process_data(const uint32_t, const uint8_t*, uint32_t);
+    bool close(const uint32_t);
+    void set_info(char*, uint16_t, uint64_t);
     void accept_raw_data_from(Dce2Smb2SessionData*);
-
     bool accepting_raw_data()
     { return (smb2_pdu_state == DCE2_SMB_PDU_STATE__RAW_DATA); }
 
@@ -61,16 +59,19 @@ public:
 
 private:
     void file_detect();
+    std::pair<bool, Dce2Smb2SessionData*> update_processing_flow(Dce2Smb2SessionData* = nullptr);
     bool ignore;
     uint16_t file_name_len;
-    uint64_t file_offset;
+    uint32_t file_flow_key;
     uint64_t file_id;
-    uint64_t file_size;
+    std::atomic<uint64_t> file_size;
     uint64_t file_name_hash;
     char* file_name;
     FileDirection direction;
-    Dce2SmbPduState smb2_pdu_state;
+    std::atomic<Dce2SmbPduState> smb2_pdu_state;
     Dce2Smb2TreeTracker* parent_tree;
+    std::unordered_map<uint32_t, uint64_t,std::hash<uint32_t> > file_offsets;
+    std::mutex process_file_mutex;
 };
 
 using  Dce2Smb2FileTrackerMap =
index b78230f6988dd2775a9bd863851f19778d28333b..76d496a34b13d9935d372ed779d110742b1aec1b 100644 (file)
@@ -38,25 +38,22 @@ public:
         : fname(nullptr), fname_len(0), file_id(file_id_v), offset(offset_v)
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "request tracker created\n");
-        memory::MemoryCap::update_allocations(sizeof(*this));
     }
 
     Dce2Smb2RequestTracker(char* fname_v, uint16_t fname_len_v)
         : fname(fname_v), fname_len(fname_len_v), file_id(0), offset(0)
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "request tracker created\n");
-        memory::MemoryCap::update_allocations(sizeof(*this));
     }
 
     ~Dce2Smb2RequestTracker()
     {
-        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "request tracker terminating\n");
+        if (smb_module_is_up)
+            debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "request tracker terminating\n");
         if (fname)
             snort_free(fname);
-        memory::MemoryCap::update_deallocations(sizeof(*this));
     }
 
-    void reset_file_name() { fname = nullptr; fname_len = 0; }
     uint64_t get_offset() { return offset; }
     uint64_t get_file_id() { return file_id; }
     char* get_file_name() { return fname; }
@@ -70,7 +67,7 @@ private:
 };
 
 using Dce2Smb2RequestTrackerMap =
-    std::unordered_map<uint64_t, Dce2Smb2RequestTracker*, std::hash<uint64_t> >;
+    std::unordered_map<Smb2MessageKey, Dce2Smb2RequestTracker*, Smb2KeyHash>;
 
 #endif
 
index bff04ba4777ff75d24a1402c87ddd424f20e7959..d5f10faea63b150a4a62b205842ad35cfc023b3f 100644 (file)
 
 #include "dce_smb2_session.h"
 
+#include "dce_smb2_session_cache.h"
+
+#include "file_api/file_flows.h"
+
 uint32_t Smb2Tid(const Smb2Hdr* hdr)
 {
     return snort::alignedNtohl(&(((const Smb2SyncHdr*)hdr)->tree_id));
 }
 
-//init must be called when a session tracker is created.
-void Dce2Smb2SessionTracker::init(uint64_t sid,
-    const Smb2SessionKey& session_key_v)
-{
-    session_id = sid;
-    session_key = session_key_v;
-    debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "session tracker %" PRIu64
-        " created\n", session_id);
-}
-
-Dce2Smb2SessionData* Dce2Smb2SessionTracker::get_current_flow()
+Dce2Smb2SessionData* Dce2Smb2SessionTracker::get_flow(uint32_t flow_key)
 {
-    Smb2FlowKey flow_key = get_smb2_flow_key();
+    std::lock_guard<std::mutex> guard(attached_flows_mutex);
     auto it_flow = attached_flows.find(flow_key);
     return (it_flow != attached_flows.end()) ? it_flow->second : nullptr;
 }
 
 Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::find_tree_for_message(
-    uint64_t message_id)
+    const uint64_t message_id, const uint32_t flow_key)
 {
+    std::lock_guard<std::mutex> guard(connected_trees_mutex);
     for (auto it_tree : connected_trees)
     {
-        Dce2Smb2RequestTracker* request = it_tree.second->find_request(message_id);
+        Dce2Smb2RequestTracker* request = it_tree.second->find_request(message_id, flow_key);
         if (request)
             return it_tree.second;
     }
     return nullptr;
 }
 
-void Dce2Smb2SessionTracker::process(uint16_t command, uint8_t command_type,
-    const Smb2Hdr* smb_header, const uint8_t* end)
+void Dce2Smb2SessionTracker::process(const uint16_t command, uint8_t command_type,
+    const Smb2Hdr* smb_header, const uint8_t* end, const uint32_t current_flow_key)
 {
     Dce2Smb2TreeTracker* tree = nullptr;
     uint32_t tree_id = Smb2Tid(smb_header);
 
     if (tree_id)
     {
+        connected_trees_mutex.lock();
         auto it_tree = connected_trees.find(tree_id);
         if (it_tree != connected_trees.end())
             tree = it_tree->second;
+        connected_trees_mutex.unlock();
     }
     else
     {
         //async response case
-        tree = find_tree_for_message(Smb2Mid(smb_header));
+        tree = find_tree_for_message(Smb2Mid(smb_header), current_flow_key);
     }
 
     switch (command)
@@ -82,15 +79,16 @@ void Dce2Smb2SessionTracker::process(uint16_t command, uint8_t command_type,
     {
         uint8_t share_type = ((const Smb2TreeConnectResponseHdr*)
             ((const uint8_t*)smb_header + SMB2_HEADER_LENGTH))->share_type;
-        connect_tree(tree_id, share_type);
+        connect_tree(tree_id, current_flow_key, share_type);
     }
     break;
-
     case SMB2_COM_TREE_DISCONNECT:
         if (tree)
         {
             delete tree;
+            connected_trees_mutex.lock();
             connected_trees.erase(tree_id);
+            connected_trees_mutex.unlock();
         }
         else
             dce2_smb_stats.v2_tree_discn_ignored++;
@@ -103,7 +101,7 @@ void Dce2Smb2SessionTracker::process(uint16_t command, uint8_t command_type,
             debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
                 "%s_REQ: mid-stream session detected\n",
                 smb2_command_string[command]);
-            tree = connect_tree(tree_id);
+            tree = connect_tree(tree_id, current_flow_key);
             if (!tree)
             {
                 debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
@@ -114,7 +112,7 @@ void Dce2Smb2SessionTracker::process(uint16_t command, uint8_t command_type,
     // fallthrough
     default:
         if (tree)
-            tree->process(command, command_type, smb_header, end);
+            tree->process(command, command_type, smb_header, end, current_flow_key);
         else
         {
             debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
@@ -125,11 +123,12 @@ void Dce2Smb2SessionTracker::process(uint16_t command, uint8_t command_type,
     }
 }
 
-Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(uint32_t tree_id,
-    uint8_t share_type)
+Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(const uint32_t tree_id,
+    const uint32_t current_flow_key, const uint8_t share_type)
 {
-    Dce2Smb2SessionData* current_flow = get_current_flow();
-    if ((SMB2_SHARE_TYPE_DISK == share_type) and (-1 == current_flow->get_max_file_depth()) and
+    Dce2Smb2SessionData* current_flow = get_flow(current_flow_key);
+    if ((SMB2_SHARE_TYPE_DISK == share_type) and current_flow and
+        (-1 == current_flow->get_max_file_depth()) and
         (-1 == current_flow->get_smb_file_depth()))
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "Not inserting TID (%u) "
@@ -138,34 +137,54 @@ Dce2Smb2TreeTracker* Dce2Smb2SessionTracker::connect_tree(uint32_t tree_id,
         return nullptr;
     }
     Dce2Smb2TreeTracker* tree = nullptr;
+    connected_trees_mutex.lock();
     auto it_tree = connected_trees.find(tree_id);
     if (it_tree != connected_trees.end())
         tree = it_tree->second;
+    connected_trees_mutex.unlock();
     if (!tree)
     {
         tree = new Dce2Smb2TreeTracker(tree_id, this, share_type);
+        connected_trees_mutex.lock();
         connected_trees.insert(std::make_pair(tree_id, tree));
+        connected_trees_mutex.unlock();
+        increase_size(sizeof(Dce2Smb2TreeTracker));
     }
     return tree;
 }
 
-void Dce2Smb2SessionTracker::attach_flow(Smb2FlowKey flow_key,
-    Dce2Smb2SessionData* ssd)
+void Dce2Smb2SessionTracker::clean_file_context_from_flow(Dce2Smb2FileTracker* file_tracker,
+    uint64_t file_id, uint64_t file_name_hash)
+{
+    for (auto it_flow : attached_flows)
+    {
+        snort::FileFlows* file_flows = snort::FileFlows::get_file_flows(
+            it_flow.second->get_tcp_flow(), false);
+        if (file_flows)
+            file_flows->remove_processed_file_context(file_name_hash, file_id);
+        it_flow.second->reset_matching_tcp_file_tracker(file_tracker);
+    }
+}
+
+void Dce2Smb2SessionTracker::increase_size(const size_t size)
 {
-    attached_flows.insert(std::make_pair(flow_key,ssd));
+    smb2_session_cache.increase_size(size);
 }
 
-bool Dce2Smb2SessionTracker::detach_flow(Smb2FlowKey& flow_key)
+void Dce2Smb2SessionTracker::decrease_size(const size_t size)
 {
-    attached_flows.erase(flow_key);
-    return (0 == attached_flows.size());
+    smb2_session_cache.decrease_size(size);
 }
 
 // Session Tracker is created and destroyed only from session cache
 Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker(void)
 {
-    debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "session tracker %" PRIu64
-        " terminating\n", session_id);
+    if (smb_module_is_up)
+    {
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
+            "session tracker %" PRIu64 " terminating\n", session_id);
+    }
+
     auto it_tree = connected_trees.begin();
     while (it_tree != connected_trees.end())
     {
@@ -176,7 +195,5 @@ Dce2Smb2SessionTracker::~Dce2Smb2SessionTracker(void)
 
     for (auto it_flow : attached_flows)
         it_flow.second->remove_session(session_id);
-
-    memory::MemoryCap::update_deallocations(sizeof(*this));
 }
 
index e7472061928979a3f1e6cd2cb3e5eb521f40ad53..b99ff52bc3bb31700fedd0bea87cdcdfaa46b4dd 100644 (file)
@@ -31,30 +31,52 @@ uint32_t Smb2Tid(const Smb2Hdr* hdr);
 class Dce2Smb2SessionTracker
 {
 public:
-    Dce2Smb2SessionTracker()
+    Dce2Smb2SessionTracker(const Smb2SessionKey& key)
     {
-        session_id = 0;
-        session_key = { };
-        memory::MemoryCap::update_allocations(sizeof(*this));
+        session_id = key.sid;
+        session_key = key;
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "session tracker %" PRIu64
+            " created\n", session_id);
     }
 
     ~Dce2Smb2SessionTracker();
-    void init(uint64_t, const Smb2SessionKey&);
-    void attach_flow(Smb2FlowKey, Dce2Smb2SessionData*);
-    bool detach_flow(Smb2FlowKey&);
-    void process(uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*);
-    void disconnect_tree(uint32_t tree_id) { connected_trees.erase(tree_id); }
-    Dce2Smb2SessionData* get_current_flow();
+    Dce2Smb2TreeTracker* connect_tree(const uint32_t, const uint32_t,
+        uint8_t=SMB2_SHARE_TYPE_DISK);
+    void disconnect_tree(uint32_t tree_id)
+    {
+        std::lock_guard<std::mutex> guard(connected_trees_mutex);
+        connected_trees.erase(tree_id);
+        decrease_size(sizeof(Dce2Smb2TreeTracker));
+    }
+
+    void attach_flow(uint32_t flow_key, Dce2Smb2SessionData* ssd)
+    {
+        std::lock_guard<std::mutex> guard(attached_flows_mutex);
+        attached_flows.insert(std::make_pair(flow_key,ssd));
+    }
+
+    bool detach_flow(uint32_t flow_key)
+    {
+        std::lock_guard<std::mutex> guard(attached_flows_mutex);
+        attached_flows.erase(flow_key);
+        return (0 == attached_flows.size());
+    }
+
     Smb2SessionKey get_key() { return session_key; }
-    Dce2Smb2SessionDataMap get_attached_flows() { return attached_flows; }
-    Dce2Smb2TreeTracker* connect_tree(uint32_t, uint8_t=SMB2_SHARE_TYPE_DISK);
+    void clean_file_context_from_flow(Dce2Smb2FileTracker*, uint64_t, uint64_t);
+    Dce2Smb2SessionData* get_flow(uint32_t);
+    void process(const uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*, const uint32_t);
+    void increase_size(const size_t size);
+    void decrease_size(const size_t size);
 
 private:
-    Dce2Smb2TreeTracker* find_tree_for_message(uint64_t);
+    Dce2Smb2TreeTracker* find_tree_for_message(const uint64_t, const uint32_t);
     uint64_t session_id;
     Smb2SessionKey session_key;
     Dce2Smb2SessionDataMap attached_flows;
     Dce2Smb2TreeTrackerMap connected_trees;
+    std::mutex connected_trees_mutex;
+    std::mutex attached_flows_mutex;
 };
 
 #endif
index 47f3f93c2ea3b30f6567e18210328cdf94505a69..6a19c46894e24c7a5e3570a7eee612d4af5b73fe 100644 (file)
 
 #define SMB_AVG_FILES_PER_SESSION 5
 
-template<typename Key, typename Value, typename Hash>
-class Dce2Smb2SharedCache : public LruCacheShared<Key, Value, Hash>
+template<typename Key, typename Value, typename Hash, typename Eq = std::equal_to<Key>,
+    typename Purgatory = std::vector<std::shared_ptr<Value> > >
+class Dce2Smb2SharedCache : public LruCacheShared<Key, Value, Hash, Eq, Purgatory>
 {
 public:
     Dce2Smb2SharedCache() = delete;
     Dce2Smb2SharedCache(const Dce2Smb2SharedCache& arg) = delete;
     Dce2Smb2SharedCache& operator=(const Dce2Smb2SharedCache& arg) = delete;
     Dce2Smb2SharedCache(const size_t initial_size) :
-        LruCacheShared<Key, Value, Hash>(initial_size) { }
-    virtual ~Dce2Smb2SharedCache() { }
+        LruCacheShared<Key, Value, Hash, Eq, Purgatory>(initial_size) { }
 
-    Value* find_session(Key key)
-    { return this->find(key).get(); }
-    Value* find_else_create_session(Key key)
+    Value* find_session(Key key, Dce2Smb2SessionData* ssd)
     {
-        std::shared_ptr<Value> new_session = std::shared_ptr<Value>(new Value());
-        return this->find_else_insert(key, new_session, nullptr).get();
+        flow_mutex.lock();
+        Value* session = this->find(key).get();
+        if (session)
+            session->attach_flow(ssd->get_flow_key(), ssd);
+        flow_mutex.unlock();
+        return session;
+    }
+
+    Value* find_else_create_session(Key& key, Dce2Smb2SessionData* ssd)
+    {
+        std::shared_ptr<Value> new_session = std::shared_ptr<Value>(new Value(key));
+        flow_mutex.lock();
+        Value* session = this->find_else_insert(key, new_session, nullptr).get();
+        session->attach_flow(ssd->get_flow_key(), ssd);
+        flow_mutex.unlock();
+        return session;
+    }
+
+    size_t mem_size() override
+    {
+        return current_size;
+    }
+
+    void increase_size(size_t size)
+    {
+        current_size += size;
+    }
+
+    void decrease_size(size_t size)
+    {
+        assert(current_size >= size);
+        current_size -= size;
+    }
+
+private:
+    using LruCacheShared<Key, Value, Hash, Eq, Purgatory>::current_size;
+    using LruCacheShared<Key, Value, Hash, Eq, Purgatory>::cache_mutex;
+    std::mutex flow_mutex;
+    void increase_size(Value* value_ptr=nullptr) override
+    {
+        if (value_ptr) current_size += sizeof(*value_ptr);
+    }
+
+    void decrease_size(Value* value_ptr=nullptr) override
+    {
+        if (value_ptr)
+        {
+            assert(current_size >= sizeof(*value_ptr) );
+            current_size -= sizeof(*value_ptr);
+        }
     }
 };
 
 using Dce2Smb2SessionCache =
     Dce2Smb2SharedCache<Smb2SessionKey, Dce2Smb2SessionTracker, Smb2KeyHash>;
 
-extern THREAD_LOCAL Dce2Smb2SessionCache* smb2_session_cache;
-
-inline void DCE2_SmbSessionCacheInit(const size_t cache_size)
-{
-    smb2_session_cache = new Dce2Smb2SessionCache(cache_size);
-}
+extern Dce2Smb2SessionCache smb2_session_cache;
 
 #endif
 
index 4958b9f5a880452eaa1df1e9c9d12b57c22c7414..b1f6a1df0893a9518995b1abf5128ace8249d0b2 100644 (file)
@@ -36,8 +36,20 @@ uint64_t Smb2Mid(const Smb2Hdr* hdr)
     return alignedNtohq(&(hdr->message_id));
 }
 
+Dce2Smb2FileTracker* Dce2Smb2TreeTracker::open_file(const uint64_t file_id,
+    const uint32_t current_flow_key)
+{
+    Dce2Smb2FileTracker* ftracker = new Dce2Smb2FileTracker(file_id, current_flow_key, this);
+    tree_tracker_mutex.lock();
+    opened_files.insert(std::make_pair(file_id, ftracker));
+    tree_tracker_mutex.unlock();
+    parent_session->increase_size(sizeof(Dce2Smb2FileTracker));
+    return ftracker;
+}
+
 Dce2Smb2FileTracker* Dce2Smb2TreeTracker::find_file(uint64_t file_id)
 {
+    std::lock_guard<std::mutex> guard(tree_tracker_mutex);
     auto it_file = opened_files.find(file_id);
     if (it_file != opened_files.end())
         return it_file->second;
@@ -46,29 +58,41 @@ Dce2Smb2FileTracker* Dce2Smb2TreeTracker::find_file(uint64_t file_id)
 
 void Dce2Smb2TreeTracker::close_file(uint64_t file_id, bool destroy)
 {
+    tree_tracker_mutex.lock();
     auto it_file = opened_files.find(file_id);
     if (it_file != opened_files.end())
     {
         Dce2Smb2FileTracker* file = it_file->second;
         if (opened_files.erase(file_id) and destroy)
+        {
+            parent_session->decrease_size(sizeof(Dce2Smb2FileTracker));
+            tree_tracker_mutex.unlock();
             delete file;
+            return;
+        }
     }
+    tree_tracker_mutex.unlock();
 }
 
-Dce2Smb2RequestTracker* Dce2Smb2TreeTracker::find_request(uint64_t message_id)
+Dce2Smb2RequestTracker* Dce2Smb2TreeTracker::find_request(const uint64_t message_id,
+    const uint32_t current_flow_key)
 {
-    auto request_it = active_requests.find(message_id);
-    return (request_it == active_requests.end()) ?
-        nullptr : request_it->second;
+    Smb2MessageKey message_key = { message_id, current_flow_key, 0 };
+    std::lock_guard<std::mutex> guard(tree_tracker_mutex);
+    auto request_it = active_requests.find(message_key);
+    return (request_it == active_requests.end()) ? nullptr : request_it->second;
 }
 
-bool Dce2Smb2TreeTracker::remove_request(uint64_t message_id)
+bool Dce2Smb2TreeTracker::remove_request(const uint64_t message_id,
+    const uint32_t current_flow_key)
 {
-    auto request_it = active_requests.find(message_id);
+    Smb2MessageKey message_key = { message_id, current_flow_key, 0 };
+    std::lock_guard<std::mutex> guard(tree_tracker_mutex);
+    auto request_it = active_requests.find(message_key);
     if (request_it != active_requests.end())
     {
         delete request_it->second;
-        return active_requests.erase(message_id);
+        return active_requests.erase(message_key);
     }
     return false;
 }
@@ -101,7 +125,8 @@ void Dce2Smb2TreeTracker::process_set_info_request(const Smb2Hdr* smb_header)
     }
 }
 
-void Dce2Smb2TreeTracker::process_close_request(const Smb2Hdr* smb_header)
+void Dce2Smb2TreeTracker::process_close_request(const Smb2Hdr* smb_header,
+    const uint32_t current_flow_key)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     uint64_t file_id = alignedNtohq(&(((const Smb2CloseRequestHdr*)
@@ -114,7 +139,7 @@ void Dce2Smb2TreeTracker::process_close_request(const Smb2Hdr* smb_header)
             smb2_command_string[SMB2_COM_CLOSE], file_id);
         return;
     }
-    if (file_tracker->close())
+    if (file_tracker->close(current_flow_key))
         close_file(file_id);
 
     if (share_type != SMB2_SHARE_TYPE_DISK)
@@ -165,8 +190,8 @@ uint64_t Dce2Smb2TreeTracker::get_durable_file_id(
     return 0;
 }
 
-void Dce2Smb2TreeTracker::process_create_response(uint64_t message_id,
-    const Smb2Hdr* smb_header)
+void Dce2Smb2TreeTracker::process_create_response(const uint64_t message_id,
+    const uint32_t current_flow_key, const Smb2Hdr* smb_header)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     const Smb2CreateResponseHdr* create_res_hdr = (const Smb2CreateResponseHdr*)smb_data;
@@ -182,20 +207,17 @@ void Dce2Smb2TreeTracker::process_create_response(uint64_t message_id,
     }
     else
     {
-        Dce2Smb2RequestTracker* create_request = find_request(message_id);
+        Dce2Smb2RequestTracker* create_request = find_request(message_id, current_flow_key);
         if (create_request)
         {
             Dce2Smb2FileTracker* file_tracker = find_file(file_id);
             if (!file_tracker)
-            {
-                file_tracker = new Dce2Smb2FileTracker(file_id, this);
-                opened_files.insert(std::make_pair(file_id, file_tracker));
-            }
+                file_tracker = open_file(file_id, current_flow_key);
+
             if (share_type == SMB2_SHARE_TYPE_DISK)
             {
                 file_tracker->set_info(create_request->get_file_name(),
-                    create_request->get_file_name_size(), file_size, true);
-                create_request->reset_file_name();
+                    create_request->get_file_name_size(), file_size);
             }
         }
         else
@@ -207,8 +229,8 @@ void Dce2Smb2TreeTracker::process_create_response(uint64_t message_id,
     }
 }
 
-void Dce2Smb2TreeTracker::process_create_request(uint64_t message_id,
-    const Smb2Hdr* smb_header, const uint8_t* end)
+void Dce2Smb2TreeTracker::process_create_request(const uint64_t message_id,
+    const uint32_t current_flow_key, const Smb2Hdr* smb_header, const uint8_t* end)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     const Smb2CreateRequestHdr* create_req_hdr = (const Smb2CreateRequestHdr*)smb_data;
@@ -236,7 +258,11 @@ void Dce2Smb2TreeTracker::process_create_request(uint64_t message_id,
     char* file_name = get_smb_file_name(file_name_offset, file_name_size, true, &name_len);
     //keep a request tracker with the available info
     Dce2Smb2RequestTracker* create_request = new Dce2Smb2RequestTracker(file_name, name_len);
-    store_request(message_id, create_request);
+    if (!store_request(message_id, current_flow_key, create_request))
+    {
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "SMB2_COM_CREATE_REQ: store failed\n");
+        delete create_request;
+    }
     //check if file_id is available form a durable reconnect request.
     //if present we can create a file tracker right now.
     //mostly this is the case for compound request.
@@ -246,24 +272,22 @@ void Dce2Smb2TreeTracker::process_create_request(uint64_t message_id,
         Dce2Smb2FileTracker* file_tracker = find_file(file_id);
         if (!file_tracker)
         {
-            file_tracker = new Dce2Smb2FileTracker(file_id, this);
+            file_tracker = open_file(file_id, current_flow_key);
             if (share_type == SMB2_SHARE_TYPE_DISK)
             {
-                file_tracker->set_info(file_name, name_len, 0, true);
-                create_request->reset_file_name();
+                file_tracker->set_info(file_name, name_len, 0);
             }
-            opened_files.insert(std::make_pair(file_id, file_tracker));
         }
     }
 }
 
-void Dce2Smb2TreeTracker::process_read_response(uint64_t message_id,
-    const Smb2Hdr* smb_header, const uint8_t* end)
+void Dce2Smb2TreeTracker::process_read_response(const uint64_t message_id,
+    const uint32_t current_flow_key, const Smb2Hdr* smb_header, const uint8_t* end)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     const Smb2ReadResponseHdr* read_resp_hdr = (const Smb2ReadResponseHdr*)smb_data;
 
-    Dce2Smb2RequestTracker* read_request = find_request(message_id);
+    Dce2Smb2RequestTracker* read_request = find_request(message_id, current_flow_key);
     if (!read_request)
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
@@ -272,7 +296,7 @@ void Dce2Smb2TreeTracker::process_read_response(uint64_t message_id,
         return;
     }
     uint16_t data_offset = alignedNtohs((const uint16_t*)(&(read_resp_hdr->data_offset)));
-    Dce2Smb2SessionData* current_flow = parent_session->get_current_flow();
+    Dce2Smb2SessionData* current_flow = parent_session->get_flow(current_flow_key);
     if (data_offset + (const uint8_t*)smb_header > end)
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "SMB2_COM_READ_RESP: bad offset\n");
@@ -286,15 +310,13 @@ void Dce2Smb2TreeTracker::process_read_response(uint64_t message_id,
         const uint8_t* file_data =  (const uint8_t*)read_resp_hdr +
             SMB2_READ_RESPONSE_STRUC_SIZE - 1;
         int data_size = end - file_data;
-        if (file_tracker->process_data(file_data, data_size, read_request->get_offset()))
+        if (file_tracker->process_data(current_flow_key, file_data, data_size, read_request->get_offset()))
         {
             if ((uint32_t)data_size < alignedNtohl((const uint32_t*)&(read_resp_hdr->length)))
             {
                 file_tracker->accept_raw_data_from(current_flow);
             }
         }
-        else
-            close_file(file_tracker->get_file_id());
     }
     else
     {
@@ -303,24 +325,28 @@ void Dce2Smb2TreeTracker::process_read_response(uint64_t message_id,
     }
 }
 
-void Dce2Smb2TreeTracker::process_read_request(uint64_t message_id,
-    const Smb2Hdr* smb_header)
+void Dce2Smb2TreeTracker::process_read_request(const uint64_t message_id,
+    const uint32_t current_flow_key, const Smb2Hdr* smb_header)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     const Smb2ReadRequestHdr* read_req_hdr = (const Smb2ReadRequestHdr*)smb_data;
     uint64_t file_id = alignedNtohq((const uint64_t*)(&(read_req_hdr->fileId_persistent)));
     uint64_t offset = alignedNtohq((const uint64_t*)(&(read_req_hdr->offset)));
     Dce2Smb2RequestTracker* read_request = new Dce2Smb2RequestTracker(file_id, offset);
-    store_request(message_id, read_request);
+    if (!store_request(message_id, current_flow_key, read_request))
+    {
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "SMB2_COM_READ_REQ: store failed\n");
+        delete read_request;
+    }
 }
 
-void Dce2Smb2TreeTracker::process_write_request(uint64_t message_id,
-    const Smb2Hdr* smb_header, const uint8_t* end)
+void Dce2Smb2TreeTracker::process_write_request(const uint64_t message_id,
+    const uint32_t current_flow_key, const Smb2Hdr* smb_header, const uint8_t* end)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     const Smb2WriteRequestHdr* write_req_hdr = (const Smb2WriteRequestHdr*)smb_data;
     uint64_t file_id = alignedNtohq((const uint64_t*)(&(write_req_hdr->fileId_persistent)));
-    Dce2Smb2SessionData* current_flow = parent_session->get_current_flow();
+    Dce2Smb2SessionData* current_flow = parent_session->get_flow(current_flow_key);
     if ((alignedNtohs((const uint16_t*)(&(write_req_hdr->data_offset))) +
         (const uint8_t*)smb_header > end) and current_flow)
     {
@@ -329,7 +355,11 @@ void Dce2Smb2TreeTracker::process_write_request(uint64_t message_id,
     }
     //track this request to clean up opened file in case of error response
     Dce2Smb2RequestTracker* write_request = new Dce2Smb2RequestTracker(file_id);
-    store_request(message_id, write_request);
+    if (!store_request(message_id, current_flow_key, write_request))
+    {
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "SMB2_COM_WRITE_REQ: store failed\n");
+        delete write_request;
+    }
     const uint8_t* file_data = (const uint8_t*)write_req_hdr + SMB2_WRITE_REQUEST_STRUC_SIZE - 1;
     Dce2Smb2FileTracker* file_tracker = find_file(file_id);
     if (file_tracker)
@@ -337,15 +367,13 @@ void Dce2Smb2TreeTracker::process_write_request(uint64_t message_id,
         file_tracker->set_direction(FILE_UPLOAD);
         int data_size = end - file_data;
         uint64_t offset = alignedNtohq((const uint64_t*)(&(write_req_hdr->offset)));
-        if (file_tracker->process_data(file_data, data_size, offset))
+        if (file_tracker->process_data(current_flow_key, file_data, data_size, offset))
         {
             if ((uint32_t)data_size < alignedNtohl((const uint32_t*)&(write_req_hdr->length)))
             {
                 file_tracker->accept_raw_data_from(current_flow);
             }
         }
-        else
-            close_file(file_tracker->get_file_id());
     }
     else
     {
@@ -354,8 +382,8 @@ void Dce2Smb2TreeTracker::process_write_request(uint64_t message_id,
     }
 }
 
-void Dce2Smb2TreeTracker::process_ioctl_command(uint8_t command_type, const Smb2Hdr* smb_header,
-    const uint8_t* end)
+void Dce2Smb2TreeTracker::process_ioctl_command(const uint8_t command_type,
+    const uint32_t current_flow_key, const Smb2Hdr* smb_header, const uint8_t* end)
 {
     const uint8_t* smb_data = (const uint8_t*)smb_header + SMB2_HEADER_LENGTH;
     const uint8_t structure_size = (command_type == SMB2_CMD_TYPE_REQUEST) ?
@@ -363,7 +391,7 @@ void Dce2Smb2TreeTracker::process_ioctl_command(uint8_t command_type, const Smb2
 
     const uint8_t* file_data = (const uint8_t*)smb_data + structure_size - 1;
     int data_size = end - file_data;
-    Dce2Smb2SessionData* current_flow = parent_session->get_current_flow();
+    Dce2Smb2SessionData* current_flow = parent_session->get_flow(current_flow_key);
     if (data_size > UINT16_MAX)
     {
         data_size = UINT16_MAX;
@@ -373,19 +401,22 @@ void Dce2Smb2TreeTracker::process_ioctl_command(uint8_t command_type, const Smb2
 }
 
 void Dce2Smb2TreeTracker::process(uint16_t command, uint8_t command_type,
-    const Smb2Hdr* smb_header, const uint8_t* end)
+    const Smb2Hdr* smb_header, const uint8_t* end, const uint32_t current_flow_key)
 {
-    Dce2Smb2SessionData* current_flow = parent_session->get_current_flow();
+    Dce2Smb2SessionData* current_flow = parent_session->get_flow(current_flow_key);
+    tree_tracker_mutex.lock();
+    size_t pending_requests = active_requests.size();
+    tree_tracker_mutex.unlock();
     if (SMB2_CMD_TYPE_REQUEST == command_type and current_flow and
-        active_requests.size() >= current_flow->get_max_outstanding_requests())
+        pending_requests >= current_flow->get_max_outstanding_requests())
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
             "%s_REQ: max req exceeded\n", smb2_command_string[command]);
         dce_alert(GID_DCE2, DCE2_SMB_MAX_REQS_EXCEEDED, (dce2CommonStats*)&dce2_smb_stats,
             *current_flow->get_dce2_session_data());
-
         return;
     }
+
     uint64_t message_id = Smb2Mid(smb_header);
 
     switch (command)
@@ -405,13 +436,13 @@ void Dce2Smb2TreeTracker::process(uint16_t command, uint8_t command_type,
                     "processed for ipc share\n", smb2_command_string[command]);
                 dce2_smb_stats.v2_crt_req_ipc++;
             }
-            process_create_request(message_id, smb_header, end);
+            process_create_request(message_id, current_flow_key, smb_header, end);
         }
         else if (SMB2_CMD_TYPE_RESPONSE == command_type)
-            process_create_response(message_id, smb_header);
+            process_create_response(message_id, current_flow_key, smb_header);
         break;
     case SMB2_COM_CLOSE:
-        process_close_request(smb_header);
+        process_close_request(smb_header, current_flow_key);
         break;
     case SMB2_COM_SET_INFO:
         process_set_info_request(smb_header);
@@ -422,14 +453,14 @@ void Dce2Smb2TreeTracker::process(uint16_t command, uint8_t command_type,
             dce2_smb_stats.v2_read_err_resp++;
             debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
                 "%s_RESP: error\n", smb2_command_string[command]);
-            Dce2Smb2RequestTracker* request = find_request(message_id);
+            Dce2Smb2RequestTracker* request = find_request(message_id, current_flow_key);
             if (request)
                 close_file(request->get_file_id());
         }
         else if (SMB2_CMD_TYPE_REQUEST == command_type)
-            process_read_request(message_id, smb_header);
+            process_read_request(message_id, current_flow_key, smb_header);
         else if (SMB2_CMD_TYPE_RESPONSE == command_type)
-            process_read_response(message_id, smb_header, end);
+            process_read_response(message_id, current_flow_key, smb_header, end);
         break;
     case SMB2_COM_WRITE:
         if (SMB2_CMD_TYPE_ERROR_RESPONSE == command_type)
@@ -437,12 +468,12 @@ void Dce2Smb2TreeTracker::process(uint16_t command, uint8_t command_type,
             dce2_smb_stats.v2_wrt_err_resp++;
             debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
                 "%s_RESP: error\n", smb2_command_string[command]);
-            Dce2Smb2RequestTracker* request = find_request(message_id);
+            Dce2Smb2RequestTracker* request = find_request(message_id, current_flow_key);
             if (request)
                 close_file(request->get_file_id());
         }
         else if (SMB2_CMD_TYPE_REQUEST == command_type)
-            process_write_request(message_id, smb_header, end);
+            process_write_request(message_id, current_flow_key, smb_header, end);
         break;
     case SMB2_COM_IOCTL:
         if (SMB2_CMD_TYPE_ERROR_RESPONSE == command_type)
@@ -452,18 +483,19 @@ void Dce2Smb2TreeTracker::process(uint16_t command, uint8_t command_type,
         }
         else if (SMB2_SHARE_TYPE_DISK != share_type)
         {
-            process_ioctl_command(command_type, smb_header, end);
+            process_ioctl_command(command_type, current_flow_key, smb_header, end);
         }
         break;
     }
     if (SMB2_CMD_TYPE_RESPONSE == command_type or SMB2_CMD_TYPE_ERROR_RESPONSE == command_type)
-        remove_request(message_id);
+        remove_request(message_id, current_flow_key);
 }
 
 Dce2Smb2TreeTracker::~Dce2Smb2TreeTracker(void)
 {
-    debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
-        "tree tracker %" PRIu32 " terminating\n", tree_id);
+    if (smb_module_is_up)
+        debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "tree tracker %" PRIu32 " terminating\n",
+            tree_id);
 
     if (co_tracker != nullptr)
     {
@@ -471,26 +503,32 @@ Dce2Smb2TreeTracker::~Dce2Smb2TreeTracker(void)
         snort_free((void*)co_tracker);
         co_tracker = nullptr;
     }
+
+    tree_tracker_mutex.lock();
+
     if (active_requests.size())
     {
-        debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
-            "cleanup pending requests for below MIDs:\n");
-        for (auto it_request : active_requests)
+        if (smb_module_is_up)
         {
             debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
-                "mid %" PRIu64 "\n", it_request.first);
+                "cleanup pending requests for below MIDs:\n");
+        }
+        for (auto it_request : active_requests)
+        {
+            if (smb_module_is_up)
+            {
+                debug_logf(dce_smb_trace, GET_CURRENT_PACKET, "mid %" PRIu64 "\n",
+                    it_request.first.mid);
+            }
             delete it_request.second;
         }
     }
 
-    auto it_file = opened_files.begin();
-    while (it_file != opened_files.end())
-    {
-        Dce2Smb2FileTracker* file = it_file->second;
-        it_file = opened_files.erase(it_file);
-        delete file;
-    }
+    for (auto it_file : opened_files)
+        delete it_file.second;
+
+    tree_tracker_mutex.unlock();
+
     parent_session->disconnect_tree(tree_id);
-    memory::MemoryCap::update_deallocations(sizeof(*this));
 }
 
index 67ff167c39d00c15ca18a082789238a2cbca0ab1..9cc8b46f2b91e349eb813cf5b7ac0ab043414d2b 100644 (file)
@@ -45,7 +45,6 @@ public:
     {
         debug_logf(dce_smb_trace, GET_CURRENT_PACKET,
             "tree tracker %" PRIu32 " created\n", tree_id);
-        memory::MemoryCap::update_allocations(sizeof(*this));
         if (share_type != SMB2_SHARE_TYPE_DISK)
         {
             co_tracker = (DCE2_CoTracker*)snort_calloc(sizeof(DCE2_CoTracker));
@@ -59,11 +58,11 @@ public:
 
     ~Dce2Smb2TreeTracker();
 
-    void open_file(uint64_t);
+    Dce2Smb2FileTracker* open_file(const uint64_t, const uint32_t);
     void close_file(uint64_t, bool=true);
     Dce2Smb2FileTracker* find_file(uint64_t);
-    Dce2Smb2RequestTracker* find_request(uint64_t);
-    void process(uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*);
+    Dce2Smb2RequestTracker* find_request(const uint64_t, const uint32_t);
+    void process(uint16_t, uint8_t, const Smb2Hdr*, const uint8_t*, const uint32_t);
     Dce2Smb2SessionTracker* get_parent() { return parent_session; }
     DCE2_CoTracker* get_cotracker() { return co_tracker; }
     uint32_t get_tree_id() { return tree_id; }
@@ -71,17 +70,22 @@ public:
 
 private:
     void process_set_info_request(const Smb2Hdr*);
-    void process_close_request(const Smb2Hdr*);
-    void process_create_response(uint64_t, const Smb2Hdr*);
-    void process_create_request(uint64_t, const Smb2Hdr*, const uint8_t*);
-    void process_read_response(uint64_t, const Smb2Hdr*, const uint8_t*);
-    void process_read_request(uint64_t, const Smb2Hdr*);
-    void process_write_request(uint64_t, const Smb2Hdr*, const uint8_t*);
+    void process_close_request(const Smb2Hdr*, const uint32_t);
+    void process_create_response(const uint64_t, const uint32_t, const Smb2Hdr*);
+    void process_create_request(const uint64_t, const uint32_t, const Smb2Hdr*, const uint8_t*);
+    void process_read_response(const uint64_t, const uint32_t, const Smb2Hdr*, const uint8_t*);
+    void process_read_request(const uint64_t, const uint32_t, const Smb2Hdr*);
+    void process_write_request(const uint64_t, const uint32_t, const Smb2Hdr*, const uint8_t*);
     uint64_t get_durable_file_id(const Smb2CreateRequestHdr*, const uint8_t*);
-    bool remove_request(uint64_t);
-    void process_ioctl_command(uint8_t, const Smb2Hdr*, const uint8_t*);
-    void store_request(uint64_t message_id, Dce2Smb2RequestTracker* request)
-    { active_requests.insert(std::make_pair(message_id, request)); }
+    bool remove_request(const uint64_t, const uint32_t);
+    void process_ioctl_command(const uint8_t, const uint32_t, const Smb2Hdr*, const uint8_t*);
+    bool store_request(const uint64_t message_id, const uint32_t current_flow_key,
+        Dce2Smb2RequestTracker* request)
+    {
+        Smb2MessageKey message_key = { message_id, current_flow_key, 0 };
+        std::lock_guard<std::mutex> guard(tree_tracker_mutex);
+        return active_requests.insert(std::make_pair(message_key, request)).second;
+    }
 
     uint32_t tree_id;
     uint8_t share_type;
@@ -89,6 +93,7 @@ private:
     Dce2Smb2FileTrackerMap opened_files;
     Dce2Smb2RequestTrackerMap active_requests;
     Dce2Smb2SessionTracker* parent_session;
+    std::mutex tree_tracker_mutex;
 };
 
 using Dce2Smb2TreeTrackerMap =
index 52c49b6e4f8eb58788fa5a28b40607b7057a7f79..026055772f4252b38d3da14cb1879f1486cfce88 100644 (file)
@@ -46,7 +46,14 @@ Dce2SmbFlowData::Dce2SmbFlowData(Dce2SmbSessionData* ssd_v) : FlowData(inspector
     if (dce2_smb_stats.max_concurrent_sessions < dce2_smb_stats.concurrent_sessions)
         dce2_smb_stats.max_concurrent_sessions = dce2_smb_stats.concurrent_sessions;
     ssd = ssd_v;
-    memory::MemoryCap::update_allocations(sizeof(*this));
+}
+
+void Dce2SmbFlowData::handle_expected(Packet* p)
+{
+    //we have a fd, but ssd not set, set it here in this flow
+    if (ssd)
+        delete ssd;
+    ssd = new Dce2Smb2SessionData(p, config);
 }
 
 Dce2SmbSessionData* Dce2SmbFlowData::upgrade(const Packet* p)
@@ -70,10 +77,10 @@ void Dce2SmbFlowData::handle_retransmit(Packet* p)
 
 Dce2SmbFlowData::~Dce2SmbFlowData()
 {
-    delete ssd;
+    if (ssd)
+        delete ssd;
     assert(dce2_smb_stats.concurrent_sessions > 0);
     dce2_smb_stats.concurrent_sessions--;
-    memory::MemoryCap::update_deallocations(sizeof(*this));
 }
 
 //Dce2SmbSessionData members
@@ -113,6 +120,16 @@ static inline DCE2_SmbVersion get_smb_version(const Packet* p)
     return DCE2_SMB_VERSION_NULL;
 }
 
+Dce2SmbFlowData* create_expected_smb_flow_data(const Packet* p, dce2SmbProtoConf* config)
+{
+    DCE2_SmbVersion smb_version = get_smb_version(p);
+    if (DCE2_SMB_VERSION_2 == smb_version)
+    {
+        return new Dce2SmbFlowData(config);
+    }
+    return nullptr;
+}
+
 Dce2SmbSessionData* create_new_smb_session(const Packet* p,
     dce2SmbProtoConf* config)
 {
@@ -143,10 +160,10 @@ inline FileContext* get_smb_file_context(const Packet* p)
     return file_flows ? file_flows->get_current_file_context() : nullptr;
 }
 
-FileContext* get_smb_file_context(uint64_t file_id,
+FileContext* get_smb_file_context(Flow* flow, uint64_t file_id,
     uint64_t multi_file_processing_id, bool to_create)
 {
-    FileFlows* file_flows = FileFlows::get_file_flows(DetectionEngine::get_current_packet()->flow);
+    FileFlows* file_flows = FileFlows::get_file_flows(flow);
 
     if ( !file_flows )
     {
@@ -202,11 +219,11 @@ void set_smb_reassembled_data(uint8_t* nb_ptr, uint16_t co_len)
     {
         Dce2SmbFlowData* fd = (Dce2SmbFlowData*)flow->get_flow_data(
             Dce2SmbFlowData::inspector_id);
-       if (fd)
-       {
-           Dce2SmbSessionData* smb_ssn_data = fd->get_smb_session_data();
-           smb_ssn_data->set_reassembled_data(nb_ptr, co_len);
-       }
+        if (fd)
+        {
+            Dce2SmbSessionData* smb_ssn_data = fd->get_smb_session_data();
+            smb_ssn_data->set_reassembled_data(nb_ptr, co_len);
+        }
     }
 }
 
index 63484d7be3d17024319c0f83ec5cf7dc83295359..23746949c864ec268af94e138697b573e4a79884 100644 (file)
@@ -215,6 +215,8 @@ enum Dce2SmbPduState
 
 extern THREAD_LOCAL dce2SmbStats dce2_smb_stats;
 extern THREAD_LOCAL snort::ProfileStats dce2_smb_pstat_main;
+extern bool smb_module_is_up;
+extern SnortProtocolId snort_protocol_id_smb;
 
 class Dce2SmbSessionData
 {
@@ -230,7 +232,7 @@ public:
     DCE2_SsnData* get_dce2_session_data()
     { return &sd; }
 
-    snort::Flow* get_flow()
+    snort::Flow* get_tcp_flow()
     { return tcp_flow; }
 
     int64_t get_max_file_depth()
@@ -265,6 +267,15 @@ class Dce2SmbFlowData : public snort::FlowData
 {
 public:
     Dce2SmbFlowData(Dce2SmbSessionData*);
+    Dce2SmbFlowData(dce2SmbProtoConf* cfg) : snort::FlowData(inspector_id)
+    {
+        dce2_smb_stats.concurrent_sessions++;
+        if (dce2_smb_stats.max_concurrent_sessions < dce2_smb_stats.concurrent_sessions)
+            dce2_smb_stats.max_concurrent_sessions = dce2_smb_stats.concurrent_sessions;
+        config = cfg;
+        ssd = nullptr;
+    }
+
     ~Dce2SmbFlowData() override;
 
     static void init()
@@ -278,18 +289,21 @@ public:
 
     Dce2SmbSessionData* upgrade(const snort::Packet*);
     void handle_retransmit(snort::Packet*) override;
+    void handle_expected(snort::Packet*) override;
 
 public:
     static unsigned inspector_id;
 
 private:
     Dce2SmbSessionData* ssd;
+    dce2SmbProtoConf* config;
 };
 
+Dce2SmbFlowData* create_expected_smb_flow_data(const snort::Packet*, dce2SmbProtoConf*);
 Dce2SmbSessionData* create_new_smb_session(const snort::Packet*, dce2SmbProtoConf*);
 DCE2_SsnData* get_dce2_session_data(snort::Flow*);
 snort::FileContext* get_smb_file_context(const snort::Packet*);
-snort::FileContext* get_smb_file_context(uint64_t, uint64_t, bool);
+snort::FileContext* get_smb_file_context(snort::Flow*, uint64_t, uint64_t, bool);
 char* get_smb_file_name(const uint8_t*, uint32_t, bool, uint16_t*);
 void set_smb_reassembled_data(uint8_t*, uint16_t);
 
index de9ea879f96f12d7d70b6530956c09f70732d2ac..65924ab6b8553eb03393cca489e66505efbaee70 100644 (file)
@@ -21,7 +21,6 @@
 #ifdef HAVE_CONFIG_H
 #include "config.h"
 #endif
-
 #include "dce_smb_inspector.h"
 
 #include "dce_smb_common.h"
 #include "dce_smb_utils.h"
 #include "dce_smb2_session_cache.h"
 
+#define DCE_SMB_PROTOCOL_ID "netbios-ssn"
+
 using namespace snort;
 
-Dce2Smb::Dce2Smb(const dce2SmbProtoConf& pc)
-{
-    config = pc;
-}
+bool smb_module_is_up = false;
+SnortProtocolId snort_protocol_id_smb = UNKNOWN_PROTOCOL_ID;
+
+Dce2Smb::Dce2Smb(const dce2SmbProtoConf& pc) :
+    config(pc) { }
 
 Dce2Smb::~Dce2Smb()
 {
@@ -44,6 +46,12 @@ Dce2Smb::~Dce2Smb()
     }
 }
 
+bool Dce2Smb::configure(SnortConfig* sc)
+{
+    snort_protocol_id_smb = sc->proto_ref->add(DCE_SMB_PROTOCOL_ID);
+    return true;
+}
+
 void Dce2Smb::show(const SnortConfig*) const
 {
     print_dce2_smb_conf(config);
@@ -72,12 +80,10 @@ void Dce2Smb::eval(Packet* p)
         p->endianness = new DceEndianness();
 
         smb_session_data->process();
-
         //smb_session_data may not be valid anymore in case of upgrade
         //but flow will always have updated session
         if (!dce2_detected)
             DCE2_Detect(get_dce2_session_data(p->flow));
-
         delete(p->endianness);
         p->endianness = nullptr;
     }
@@ -98,8 +104,6 @@ void Dce2Smb::clear(Packet* p)
 // api stuff
 //-------------------------------------------------------------------------
 
-size_t session_cache_size;
-
 static Module* mod_ctor()
 {
     return new Dce2SmbModule;
@@ -118,38 +122,20 @@ static void dce2_smb_init()
     DceContextData::init(DCE2_TRANS_TYPE__SMB);
 }
 
-static void dce2_smb_thread_int()
-{
-    DCE2_SmbSessionCacheInit(session_cache_size);
-}
-
-static void dce_smb_thread_term()
-{
-    delete smb2_session_cache;
-}
-
-static size_t get_max_smb_session(dce2SmbProtoConf* config)
-{
-    size_t smb_sess_storage_req = (sizeof(Dce2Smb2SessionTracker) +
-        sizeof(Dce2Smb2TreeTracker) + sizeof(Dce2Smb2RequestTracker) +
-        (sizeof(Dce2Smb2FileTracker) * SMB_AVG_FILES_PER_SESSION));
-
-    size_t max_smb_mem = DCE2_ScSmbMemcap(config);
-
-    return (max_smb_mem/smb_sess_storage_req);
-}
-
 static Inspector* dce2_smb_ctor(Module* m)
 {
     Dce2SmbModule* mod = (Dce2SmbModule*)m;
     dce2SmbProtoConf config;
     mod->get_data(config);
-    session_cache_size = get_max_smb_session(&config);
+    size_t max_smb_mem = DCE2_ScSmbMemcap(&config);
+    smb_module_is_up = true;
+    smb2_session_cache.set_max_size(max_smb_mem);
     return new Dce2Smb(config);
 }
 
 static void dce2_smb_dtor(Inspector* p)
 {
+    smb_module_is_up = false;
     delete p;
 }
 
@@ -170,11 +156,11 @@ const InspectApi dce2_smb_api =
     IT_SERVICE,
     PROTO_BIT__PDU,
     nullptr,  // buffers
-    "netbios-ssn",
+    DCE_SMB_PROTOCOL_ID,
     dce2_smb_init,
-    nullptr, // pterm
-    dce2_smb_thread_int, // tinit
-    dce_smb_thread_term, // tterm
+    nullptr,
+    nullptr, // tinit
+    nullptr, // tterm
     dce2_smb_ctor,
     dce2_smb_dtor,
     nullptr, // ssn
index fb1c8dd2e9fceef00fe70c4060e461640b0c0a79..171f036503a6cb6f007c180382ff6225cf9cd4c7 100644 (file)
@@ -32,6 +32,7 @@ public:
     Dce2Smb(const dce2SmbProtoConf&);
     ~Dce2Smb() override;
 
+    bool configure(snort::SnortConfig*) override;
     void show(const snort::SnortConfig*) const override;
     void eval(snort::Packet*) override;
     void clear(snort::Packet*) override;