]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4806: file_api: multi process snort file cache sharing crash fix
authorShilpa Nagpal (shinagpa) <shinagpa@cisco.com>
Tue, 22 Jul 2025 14:42:06 +0000 (14:42 +0000)
committerBhargava Jandhyala (bjandhya) <bjandhya@cisco.com>
Tue, 22 Jul 2025 14:42:06 +0000 (14:42 +0000)
Merge in SNORT/snort3 from ~SHINAGPA/snort3:mp_file_fix to master

Squashed commit of the following:

commit fa415933046d9d74e92d9bfc6b0f044283e6dd97
Author: Shilpa Nagpal <shinagpa@cisco.com>
Date:   Mon Jul 14 13:32:25 2025 +0530

    file_api: multi process snort file cache sharing crash fix

12 files changed:
src/file_api/CMakeLists.txt
src/file_api/file_cache.cc
src/file_api/file_cache.h
src/file_api/file_cache_share.cc
src/file_api/file_cache_share.h
src/file_api/file_flows.cc
src/file_api/file_inspect.cc
src/file_api/file_lib.cc
src/file_api/file_lib.h
src/pub_sub/CMakeLists.txt
src/pub_sub/file_events.h
src/pub_sub/file_events_ids.h

index ef732bf6dbb59933e79ae12ab77fcbd591993ad7..fd31341b45a5cb6b0da0aaa1a10470d01c7e28bd 100644 (file)
@@ -15,6 +15,8 @@ add_library ( file_api OBJECT
     file_capture.cc
     file_cache.cc
     file_cache.h
+    file_cache_share.cc
+    file_cache_share.h
     file_config.cc
     file_config.h
     file_flows.cc
@@ -35,8 +37,6 @@ add_library ( file_api OBJECT
     file_service.cc
     file_stats.cc
     file_stats.h
-    file_cache_share.cc
-    file_cache_share.h
 )
 
 install (FILES ${FILE_API_INCLUDES}
index d1cb1152833d35ff307b3496eb76a4b5cf39b66c..90589e3df2b851decca6439a947674f6664aca3f 100644 (file)
 #include "main/thread_config.h"
 #include "packet_io/active.h"
 #include "packet_io/packet_tracer.h"
-#include "time/packet_time.h"
 #include "pub_sub/file_events.h"
+#include "time/packet_time.h"
 
 #include "file_flows.h"
 #include "file_module.h"
 #include "file_service.h"
 #include "file_stats.h"
-#include "file_cache_share.h"
 
 #define DEFAULT_FILE_LOOKUP_TIMEOUT_CACHED_ITEM 3600    // 1 hour
 
@@ -163,7 +162,7 @@ FileContext* FileCache::find_add(const FileHashKey& hashKey, int64_t timeout)
     return node->file;
 }
 
-FileContext* FileCache::add(const FileHashKey& hashKey, int64_t timeout, bool &cache_full, int64_t& cache_expire, FileInspect* ins)
+FileContext* FileCache::add(const FileHashKey& hashKey, int64_t timeout, bool &cache_full, int64_t& cache_expire, bool cache_sync)
 {
     FileNode new_node;
     /*
@@ -183,11 +182,7 @@ FileContext* FileCache::add(const FileHashKey& hashKey, int64_t timeout, bool &c
 
     if (!file)
     {
-        if (ins)
-            new_node.file = new FileContext(ins);
-        else
-            new_node.file = new FileContext;
-
+        new_node.file = new FileContext;
         int ret = fileHash->insert((void*)&hashKey, &new_node);
         cache_expire = new_node.cache_expire_time.tv_sec;
 
@@ -202,7 +197,8 @@ FileContext* FileCache::add(const FileHashKey& hashKey, int64_t timeout, bool &c
                     PacketTracer::log("add:Insert failed in file cache, returning\n");
                 cache_full = true;
             }
-            if (!ins)
+
+            if (!cache_sync)
                 FILE_DEBUG(file_trace, DEFAULT_TRACE_OPTION_ID, TRACE_CRITICAL_LEVEL, GET_CURRENT_PACKET,
                     "add:Insert failed in file cache, returning\n");
             file_counts.cache_add_fails++;
@@ -214,7 +210,7 @@ FileContext* FileCache::add(const FileHashKey& hashKey, int64_t timeout, bool &c
     }
     else
     {
-        if (!ins)
+        if (!cache_sync)
             FILE_DEBUG(file_trace, DEFAULT_TRACE_OPTION_ID, TRACE_CRITICAL_LEVEL, GET_CURRENT_PACKET,
                 "add:file already found in file cache, returning\n");
         return file;
index 0b805e74f222fd856f42fe09acf62ada77f88c98..8c7ed5db17dacf8fee423caff374975bf87c2ab2 100644 (file)
@@ -59,13 +59,12 @@ public:
     void set_block_timeout(int64_t);
     void set_lookup_timeout(int64_t);
     void set_max_files(int64_t);
-    snort::FileContext* add(const FileHashKey&, int64_t timeout, bool &cache_full, int64_t& cache_expire, FileInspect* ins = nullptr);
+    snort::FileContext* add(const FileHashKey&, int64_t timeout, bool &cache_full, int64_t& cache_expire, bool cache_sync = false);
     snort::FileContext* get_file(snort::Flow*, uint64_t file_id, bool to_create, bool using_cache_entry);
     FileVerdict cached_verdict_lookup(snort::Packet*, snort::FileInfo*,
         snort::FilePolicyBase*,const uint8_t* current_data, uint32_t current_data_len);
     bool apply_verdict(snort::Packet*, snort::FileContext*, FileVerdict, bool resume,
         snort::FilePolicyBase*);
-    
 
 private: 
     snort::FileContext* find(const FileHashKey&, int64_t, int64_t& cache_expire);
index fb5ab14fdaa6c4ac399a93a13aef598fc3b531dc..7d06c51cf2f91bb9ab5613d5e8b2c1c3ee172ee7 100644 (file)
@@ -17,8 +17,9 @@
 //--------------------------------------------------------------------------
 //  file_cache_share.cc author Shilpa Nagpal <shinagpa@cisco.com>
 
-
+#ifdef HAVE_CONFIG_H
 #include "config.h"
+#endif
 
 #include "file_cache_share.h"
 #include "file_service.h"
@@ -33,14 +34,12 @@ void FileCacheShare::handle(DataEvent& de, Flow*)
     FileHashKey key = fe.get_hashkey();
     FileInfo file = fe.get_file_ctx();
 
-    LogMessage("File Cache Sharing: Received event with file_id %lu\n", key.file_id);
-
     FileCache* file_cache = FileService::get_file_cache();
     if (file_cache)
     {
         bool cache_full = false;
         int64_t cache_expire = 0;
-        FileContext* file_got = file_cache->add(key, timeout, cache_full, cache_expire, ins);
+        FileContext* file_got = file_cache->add(key, timeout, cache_full, cache_expire, true);
         if (file_got)
         {    
             *((FileInfo*)(file_got)) = file;
@@ -48,30 +47,31 @@ void FileCacheShare::handle(DataEvent& de, Flow*)
     }
 }
 
-bool serialize_file_event(snort::DataEvent* event, char*& buffer, uint16_t* len)
+bool serialize_file_event(DataEvent* event, char*& buffer, uint16_t* len)
 {
     if (!event)
         return false;
 
-    snort::FileMPEvent* file_event = static_cast<snort::FileMPEvent*>(event);
+    FileMPEvent* file_event = static_cast<FileMPEvent*>(event);
     uint16_t event_buffer_len = file_event->get_data_len();
     if (event_buffer_len == 0)
         return false;
-    
+
     buffer = new char[event_buffer_len];
     if (!buffer)
         return false;
 
-    file_event->serialize(buffer, len);
+    *len = event_buffer_len;
+    file_event->serialize(buffer, *len);
     return true;
 }
 
-bool deserialize_file_event(const char* buffer, uint16_t len, snort::DataEvent*& event)
+bool deserialize_file_event(const char* buffer, uint16_t len, DataEvent*& event)
 {
     if (!buffer || len == 0)
         return false;
 
-    snort::FileMPEvent* file_event = new snort::FileMPEvent();
+    FileMPEvent* file_event = new FileMPEvent();
     if (!file_event)
         return false;
 
index a8c5d44f611fd95a1bcf23b4edd8a6d27b5617f1..218e6f97870953e8c1d5f95bdcec9584ad81876d 100644 (file)
 #include "framework/mp_data_bus.h"
 #include "pub_sub/file_events.h"
 
+namespace snort
+{
 class FileCacheShare : public snort::DataHandler
 {
 public:
-    FileCacheShare(FileInspect* fi) : DataHandler(FILE_ID_NAME) { ins = fi; }
+    FileCacheShare() : DataHandler(FILE_ID_NAME) {}
     void handle(snort::DataEvent&, snort::Flow*) override;
-private:
-    FileInspect* ins;
 };
-
+}
 bool serialize_file_event(snort::DataEvent* event, char*& buffer, uint16_t* len);
 bool deserialize_file_event(const char* buffer, uint16_t len, snort::DataEvent*& event);
 
index de855e5bb53d035ae78a0e5dacebaaef0a7930f4..62b90d07f6c7862bd5d8589db5c83e0564d31e3c 100644 (file)
@@ -276,6 +276,13 @@ FileContext* FileFlows::get_file_context(
         FileCache* file_cache = FileService::get_file_cache();
         assert(file_cache);
         context = file_cache->get_file(flow, file_id, false, true);
+        if (context)
+        {
+            FileConfig *fc = get_file_config(SnortConfig::get_conf());  
+            
+            if (!context->get_config())
+                context->set_config(fc);
+        }
         FILE_DEBUG(file_trace, DEFAULT_TRACE_OPTION_ID, TRACE_DEBUG_LEVEL, GET_CURRENT_PACKET,
             "get_file_context:trying to get context from cache\n");
     }
index 3b6e4d73a5f07cdf6b93134513261ebd996ae091..51510ce0eb9dc1d2f50cfe8daa23c004f66261c1 100644 (file)
 #include "log/messages.h"
 
 #include "file_cache.h"
+#include "file_cache_share.h"
 #include "file_config.h"
 #include "file_flows.h"
 #include "file_module.h"
 #include "file_service.h"
-#include "file_cache_share.h"
 
 using namespace snort;
 
@@ -66,13 +66,13 @@ bool FileInspect::configure(SnortConfig* sc)
 
     FileService::set_max_file_depth(sc);
 
-    if(sc->mp_dbus)
+    if (sc->mp_dbus)
     {
         MPSerializeFunc serialize_func = serialize_file_event;
         MPDeserializeFunc deserialize_func = deserialize_file_event;
 
         MPDataBus::register_event_helpers(file_pub_key, FileMPEvents::FILE_SHARE_SYNC, serialize_func, deserialize_func);
-        MPDataBus::subscribe(file_pub_key, FileMPEvents::FILE_SHARE_SYNC, new FileCacheShare(this));
+        MPDataBus::subscribe(file_pub_key, FileMPEvents::FILE_SHARE_SYNC, new FileCacheShare());
     }
     return true;
 }
index 7de1d912f76479a4f6e6a8f64f3ff3e2fe5a2d50..127cd2a4c5ec786cee978e7318891d51dd0d12fb 100644 (file)
@@ -152,10 +152,9 @@ void FileInfo::copy(const FileInfo& other, bool clear_data)
     }
 }
 
-void FileInfo::serialize(char* buffer, uint16_tlen)
+void FileInfo::serialize(char* buffer, uint16_t buffer_len)
 {
-    int offset = *len;
-
+    uint16_t offset = 0;
     auto write_bool = [&](bool val) {
         memcpy(buffer + offset, &val, sizeof(val));
         offset += sizeof(val);
@@ -163,7 +162,7 @@ void FileInfo::serialize(char* buffer, uint16_t* len)
 
     auto write_string = [&](const std::string& str, bool is_set) {
         write_bool(is_set);
-        if (is_set)
+        if (is_set && offset < buffer_len)
         {
             uint32_t len = static_cast<uint32_t>(str.length());
             memcpy(buffer + offset, &len, sizeof(len));
@@ -172,10 +171,20 @@ void FileInfo::serialize(char* buffer, uint16_t* len)
             offset += len;
         }
     };
+    
+    if (sha256)
+        is_sha256_set = true;
 
-    memcpy(buffer + offset,(uint16_t *) sha256, SHA256_HASH_SIZE); 
-    offset += SHA256_HASH_SIZE;
-    memcpy(buffer + offset, &verdict, sizeof(verdict)); 
+    memcpy(buffer, &is_sha256_set, sizeof(is_sha256_set));
+    offset += sizeof(is_sha256_set);
+
+    if (is_sha256_set && sha256 && offset < buffer_len)
+    {
+        memcpy(buffer + offset,(uint16_t *) sha256, SHA256_HASH_SIZE);
+        offset += SHA256_HASH_SIZE;
+    }
+
+    memcpy(buffer + offset, &verdict, sizeof(verdict));
     offset += sizeof(verdict);
     memcpy(buffer + offset, &file_size, sizeof(file_size)); 
     offset += sizeof(file_size);
@@ -198,12 +207,11 @@ void FileInfo::serialize(char* buffer, uint16_t* len)
     write_bool(file_signature_enabled);
     write_bool(file_capture_enabled);
     write_bool(is_partial);
-
-    *len = offset;
 }
 
-void FileInfo::deserialize(const char* buffer, uint16_t& offset)
+void FileInfo::deserialize(const char* buffer, uint16_t buffer_len)
 {
+    uint16_t offset = 0;
     auto read_bool = [&](bool& val) {
         memcpy(&val, buffer + offset, sizeof(val));
         offset += sizeof(val);
@@ -211,7 +219,7 @@ void FileInfo::deserialize(const char* buffer, uint16_t& offset)
 
     auto read_string = [&](std::string& str, bool& is_set) {
         read_bool(is_set);
-        if (is_set)
+        if (is_set && offset < buffer_len)
         {
             uint32_t len = 0;
             memcpy(&len, buffer + offset, sizeof(len));
@@ -221,10 +229,16 @@ void FileInfo::deserialize(const char* buffer, uint16_t& offset)
         }
     };
 
-    if (!sha256)
-        sha256 = new uint8_t[SHA256_HASH_SIZE];  
-    memcpy(sha256, (const uint8_t *)(buffer + offset), SHA256_HASH_SIZE);
-    offset += SHA256_HASH_SIZE;
+    memcpy(&is_sha256_set, buffer + offset, sizeof(is_sha256_set));
+    offset += sizeof(is_sha256_set);
+
+    if (is_sha256_set && offset < buffer_len)
+    {
+        if (!sha256)
+            sha256 = new uint8_t[SHA256_HASH_SIZE];  
+        memcpy(sha256, (const uint8_t *)(buffer + offset), SHA256_HASH_SIZE);
+        offset += SHA256_HASH_SIZE;
+    }
     memcpy(&verdict, buffer + offset, sizeof(verdict)); 
     offset += sizeof(verdict);
     memcpy(&file_size, buffer + offset, sizeof(file_size)); 
@@ -476,25 +490,22 @@ UserFileDataBase* FileInfo::get_file_data() const
     return user_file_data;
 }
 
-FileContext::FileContext (FileInspect* ins)
-{
-    file_type_context = nullptr;
-    file_signature_context = nullptr;
-    file_capture = nullptr;
-    file_segments = nullptr;
-    inspector = ins;
-    ins->add_global_ref();
-    config = ins->config;
-}
-
 FileContext::FileContext ()
 {
     file_type_context = nullptr;
     file_signature_context = nullptr;
     file_capture = nullptr;
     file_segments = nullptr;
-    inspector = (FileInspect*)InspectorManager::acquire_file_inspector();
-    config = inspector->config;
+    if (SnortConfig::get_conf())
+    {
+        inspector = (FileInspect*)InspectorManager::acquire_file_inspector();
+        config = inspector->config;
+    }
+    else
+    {
+        inspector = nullptr;
+        config = nullptr;
+    }
 }
 
 FileContext::~FileContext ()
@@ -506,7 +517,9 @@ FileContext::~FileContext ()
         stop_file_capture();
 
     delete file_segments;
-    InspectorManager::release(inspector);
+
+    if (inspector)
+        InspectorManager::release(inspector);
 }
 
 /* stop file type identification */
index 8a224da6a32cfb08ccf965854e70ec2f15553893..dbb049b55934f2ff2b7c60f32a7530eabb72d7ce 100644 (file)
@@ -115,8 +115,8 @@ public:
     void set_partial_flag(bool partial);
     bool is_partial_download() const;
 
-    void serialize(char* buffer, uint16_t* offset);
-    void deserialize(const char* buffer, uint16_t& offset);
+    void serialize(char* buffer, uint16_t buffer_len);
+    void deserialize(const char* buffer, uint16_t buffer_len);
 
 protected:
     std::string file_name;
@@ -129,6 +129,7 @@ protected:
     FileDirection direction = FILE_DOWNLOAD;
     uint32_t file_type_id = SNORT_FILE_TYPE_CONTINUE;
     uint8_t* sha256 = nullptr;
+    bool is_sha256_set = false;
     uint64_t file_id = 0;
     FileCapture* file_capture = nullptr;
     bool file_type_enabled = false;
@@ -186,6 +187,8 @@ public:
     // Configuration functions
     void remove_segments();
     void reset();
+    void set_config(FileConfig* fc) { config = fc; }
+    FileConfig* get_config() { return config; }
 private:
     uint64_t processed_bytes = 0;
     void* file_type_context;
index 44db0cb218009be512bcce7d76a3a874c158f51a..4bed9857e85e6a8355c15a98577b768ea29a42b6 100644 (file)
@@ -14,6 +14,8 @@ set (PUB_SUB_INCLUDES
     eve_process_event.h
     expect_events.h
     external_event_ids.h
+    file_events.h
+    file_events_ids.h
     finalize_packet_event.h
     ftp_events.h
     http_event_ids.h
@@ -35,8 +37,6 @@ set (PUB_SUB_INCLUDES
     ssh_events.h
     ssl_events.h
     dns_events.h
-    file_events.h
-    file_events_ids.h
 )
 
 add_library( pub_sub OBJECT
index f36cfe0203e8d22bc95567f0526530b01ce2de12..a4cd98e929b07a2f2b359936d9ece37d96f69932 100644 (file)
@@ -23,9 +23,9 @@
 #ifndef FILE_MP_EVENTS_H
 #define FILE_MP_EVENTS_H
 
-#include "framework/mp_data_bus.h"
-#include "file_events_ids.h"
 #include "file_api/file_cache.h"
+#include "file_events_ids.h"
+#include "framework/mp_data_bus.h"
 #include "hash/hashes.h"
 
 namespace snort
@@ -33,25 +33,25 @@ namespace snort
 
 class SO_PUBLIC FileMPEvent : public snort::DataEvent
 {
-    public:
+public:
 
     FileMPEvent(const FileHashKey& key, int64_t tm, FileInfo& file) : timeout(tm), hashkey(key), file_ctx(file)
     { 
-        len = sizeof(timeout) + sizeof(hashkey.sip) + sizeof(hashkey.sgroup) + 
-              sizeof(hashkey.dip) + sizeof(hashkey.dgroup) +
-              sizeof(hashkey.file_id) + sizeof(hashkey.asid) + sizeof(hashkey.padding) +
-              sizeof(file) + SHA256_HASH_SIZE;
+        len = sizeof(timeout) + sizeof(hashkey.sip) + sizeof(hashkey.sgroup)
+            + sizeof(hashkey.dip) + sizeof(hashkey.dgroup)
+            + sizeof(hashkey.file_id) + sizeof(hashkey.asid) + sizeof(hashkey.padding)
+            + sizeof(file) + SHA256_HASH_SIZE;
     }
 
     FileMPEvent() : hashkey()
     {
-       timeout = 0;
+        timeout = 0;
         len = 0;
     }
 
     int64_t get_timeout()
     {
-         return timeout;
+        return timeout;
     }
 
     FileInfo get_file_ctx()
@@ -69,30 +69,31 @@ class SO_PUBLIC FileMPEvent : public snort::DataEvent
         return len;
     }
 
-    void deserialize(const char* d, uint16_t len)
+    void deserialize(const char* buffer, uint16_t len)
     {
         uint16_t offset = 0;
-        memcpy(&timeout, d, sizeof(timeout));
+        memcpy(&timeout, buffer, sizeof(timeout));
         offset += sizeof(timeout);
-        memcpy(&hashkey.sip, d + offset, sizeof(hashkey.sip));
+        memcpy(&hashkey.sip, buffer + offset, sizeof(hashkey.sip));
         offset += sizeof(hashkey.sip);
-        memcpy(&hashkey.sgroup, d + offset, sizeof(hashkey.sgroup));
+        memcpy(&hashkey.sgroup, buffer + offset, sizeof(hashkey.sgroup));
         offset += sizeof(hashkey.sgroup);
-        memcpy(&hashkey.dip, d + offset, sizeof(hashkey.dip));
+        memcpy(&hashkey.dip, buffer + offset, sizeof(hashkey.dip));
         offset += sizeof(hashkey.dip);
-        memcpy(&hashkey.dgroup, d + offset, sizeof(hashkey.dgroup));
+        memcpy(&hashkey.dgroup, buffer + offset, sizeof(hashkey.dgroup));
         offset += sizeof(hashkey.dgroup);
-        memcpy(&hashkey.file_id, d + offset, sizeof(hashkey.file_id));
+        memcpy(&hashkey.file_id, buffer + offset, sizeof(hashkey.file_id));
         offset += sizeof(hashkey.file_id);
-        memcpy(&hashkey.asid, d + offset, sizeof(hashkey.asid));
+        memcpy(&hashkey.asid, buffer + offset, sizeof(hashkey.asid));
         offset += sizeof(hashkey.asid);
-        memcpy(&hashkey.padding, d + offset, sizeof(hashkey.padding));
+        memcpy(&hashkey.padding, buffer + offset, sizeof(hashkey.padding));
         offset += sizeof(hashkey.padding);
-        file_ctx.deserialize(d, offset);
+        if (offset < len)
+            file_ctx.deserialize(buffer + offset, len - offset);
         this->len = len;
     }
 
-    void serialize(char* buffer, uint16_t* len)
+    void serialize(char* buffer, uint16_t len)
     {
         uint16_t offset = 0;
         memcpy(buffer, &timeout, sizeof(timeout));
@@ -111,15 +112,15 @@ class SO_PUBLIC FileMPEvent : public snort::DataEvent
         offset += sizeof(hashkey.asid);
         memcpy(buffer + offset, &hashkey.padding, sizeof(hashkey.padding));
         offset += sizeof(hashkey.padding);
-        file_ctx.serialize(buffer, &offset);
-        *len = offset;
+        if (offset < len)
+            file_ctx.serialize(buffer + offset, len - offset);
     }
 
-    private:
-        int64_t timeout;
-        FileHashKey hashkey;
-        FileInfo file_ctx;
-        uint16_t len;
+private:
+    int64_t timeout;
+    FileHashKey hashkey;
+    FileInfo file_ctx;
+    uint16_t len;
 };
 
 }
index c57f91f8c7a0045569026a48b59bc07b5e6f78ca..ef8dc05fc7eab0367f048e3f3bbab540685ddd45 100644 (file)
@@ -29,12 +29,13 @@ namespace snort
 {
 
 struct FileMPEvents
-{ enum : unsigned {
-    
-    FILE_SHARE = 0,
-    FILE_SHARE_SYNC,
-    num_ids
-}; };
+{
+    enum : unsigned {
+        FILE_SHARE = 0,
+        FILE_SHARE_SYNC,
+        num_ids
+    };
+};
 
 const PubKey file_pub_key { "file_mp_events", FileMPEvents::num_ids };