]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #2070 in SNORT/snort3 from ~SHRARANG/snort3:appid_mdns_tsan to...
authorShravan Rangarajuvenkata (shrarang) <shrarang@cisco.com>
Wed, 11 Mar 2020 15:51:59 +0000 (15:51 +0000)
committerShravan Rangarajuvenkata (shrarang) <shrarang@cisco.com>
Wed, 11 Mar 2020 15:51:59 +0000 (15:51 +0000)
Squashed commit of the following:

commit 33e1910c3dfc27f1c28507c29cc743fb07cf33b4
Author: Shravan Rangaraju <shrarang@cisco.com>
Date:   Tue Mar 10 12:17:26 2020 -0400

    appid: fix thread-safety issues in mdns detector

src/network_inspectors/appid/service_plugins/service_mdns.cc
src/network_inspectors/appid/service_plugins/service_mdns.h

index 551340a48cc3fc3077a4a5cb123c592769adc07b..c05520087a47bc190c564b5c299c578d95f155fa 100644 (file)
@@ -114,7 +114,7 @@ MdnsServiceDetector::MdnsServiceDetector(ServiceDiscovery* sd)
 
 MdnsServiceDetector::~MdnsServiceDetector()
 {
-    destory_matcher();
+    delete matcher;
 }
 
 int MdnsServiceDetector::validate(AppIdDiscoveryArgs& args)
@@ -136,8 +136,9 @@ int MdnsServiceDetector::validate(AppIdDiscoveryArgs& args)
         {
             if (args.ctxt.get_odp_ctxt().mdns_user_reporting)
             {
-                analyze_user(args.asd, args.pkt, args.size);
-                destroy_match_list();
+                MatchedPatterns* pattern_list = nullptr;
+                analyze_user(args.asd, args.pkt, args.size, pattern_list);
+                destroy_match_list(pattern_list);
                 goto success;
             }
             goto success;
@@ -180,8 +181,8 @@ int MdnsServiceDetector::validate_reply(const uint8_t* data, uint16_t size)
    Returns 0 or 1 for successful/unsuccessful hit for pattern '@'
    Returns -1 for invalid address pointer or past the data_size */
 int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** resp_endptr,
-    int* start_index,
-    uint16_t data_size, uint8_t* user_name_len, unsigned size)
+    int* start_index, uint16_t data_size, uint8_t* user_name_len, unsigned size,
+    MatchedPatterns*& pattern_list)
 {
     int index = 0;
     int pattern_length = 0;
@@ -198,7 +199,7 @@ int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** r
 
     // FIXIT-M - This code needs review to ensure it works correctly with the new semantics of the
     //           index returned by the SearchTool find_all pattern matching function
-    scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length);
+    scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length, pattern_list);
     /* Contains reference pointer */
     while ((index < data_size) && !(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>
         SHIFT_BITS_REFERENCE_PTR  != PATTERN_REFERENCE_PTR))
@@ -210,7 +211,7 @@ int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** r
             break;
         }
         index++;
-        scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length);
+        scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length, pattern_list);
     }
     if (index >= data_size)
         *user_name_len = 0;
@@ -224,7 +225,7 @@ int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** r
         {
             index++;
             scan_matched_patterns(start_ptr,  size - data_size + index, resp_endptr,
-                &pattern_length);
+                &pattern_length, pattern_list);
         }
         if (index >= data_size)
             *user_name_len = 0;
@@ -250,7 +251,8 @@ int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** r
                2. Calls the function which scans for pattern to identify the user
                3. Calls the function which does the Username reporting along with the host
   MDNS User Analysis*/
-int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint16_t size)
+int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint16_t size,
+    MatchedPatterns*& pattern_list)
 {
     int start_index = 0;
     uint16_t data_size = size;
@@ -267,7 +269,7 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
         const char* user_original;
 
         const char* srv_original  = (const char*)pkt->data + RECORD_OFFSET;
-        create_match_list(srv_original, size - RECORD_OFFSET);
+        pattern_list = create_match_list(srv_original, size - RECORD_OFFSET);
         const char* end_srv_original  = (const char*)pkt->data + RECORD_OFFSET + data_size;
         for (int processed_ans = 0; processed_ans < ans_count && data_size <= size && size > 0;
             processed_ans++ )
@@ -275,7 +277,7 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
             // Call Decode Reference pointer function if referenced value instead of direct value
             uint8_t user_name_len = 0;
             int ret_value = reference_pointer(srv_original, &resp_endptr,  &start_index, data_size,
-                &user_name_len, size);
+                &user_name_len, size, pattern_list);
             int user_index =0;
 
             if (ret_value == -1)
@@ -415,56 +417,44 @@ static int mdns_pattern_match(void* id, void*, int match_end_pos, void* data, vo
     return 0;
 }
 
-unsigned MdnsServiceDetector::create_match_list(const char* data, uint16_t dataSize)
+MatchedPatterns* MdnsServiceDetector::create_match_list(const char* data, uint16_t dataSize)
 {
-    matcher->find_all((const char*)data, dataSize, mdns_pattern_match, false, (void*)&patternList);
+    MatchedPatterns* pattern_list = nullptr;
+    matcher->find_all((const char*)data, dataSize, mdns_pattern_match, false, (void*)&pattern_list);
 
-    if (patternList)
-        return 1;
-    return 0;
+    return pattern_list;
 }
 
 void MdnsServiceDetector::scan_matched_patterns(const char* dataPtr, uint16_t index, const
-    char** resp_endptr,
-    int* pattern_length)
+    char** resp_endptr, int* pattern_length, MatchedPatterns*& pattern_list)
 {
-    while (patternList)
+    while (pattern_list)
     {
-        if (patternList->match_start_pos == index)
+        if (pattern_list->match_start_pos == index)
         {
             *resp_endptr = dataPtr;
-            *pattern_length = patternList->mpattern->length;
+            *pattern_length = pattern_list->mpattern->length;
             return;
         }
 
-        if (patternList->match_start_pos > index)
+        if (pattern_list->match_start_pos > index)
             break;
 
-        MatchedPatterns* element = patternList;
-        patternList = patternList->next;
+        MatchedPatterns* element = pattern_list;
+        pattern_list = pattern_list->next;
         snort_free(element);
     }
     *resp_endptr = nullptr;
     *pattern_length = 0;
 }
 
-void MdnsServiceDetector::destroy_match_list()
+void MdnsServiceDetector::destroy_match_list(MatchedPatterns*& pattern_list)
 {
-    while (patternList)
+    while (pattern_list)
     {
-        MatchedPatterns* element = patternList;
-        patternList = patternList->next;
+        MatchedPatterns* element = pattern_list;
+        pattern_list = pattern_list->next;
 
         snort_free(element);
     }
 }
-
-void MdnsServiceDetector::destory_matcher()
-{
-    if (matcher)
-        delete matcher;
-    matcher = nullptr;
-
-    destroy_match_list();
-}
-
index 2aca729b250a3159cdf2a2f065d6bf9175ec76d7..1860a1eaf0fac1c291daa6273d791104a169ff14 100644 (file)
@@ -40,18 +40,17 @@ public:
     int validate(AppIdDiscoveryArgs&) override;
 
 private:
-    unsigned create_match_list(const char* data, uint16_t dataSize);
+    MatchedPatterns* create_match_list(const char* data, uint16_t dataSize);
     void scan_matched_patterns(const char* dataPtr, uint16_t index, const char** resp_endptr,
-        int* pattern_length);
-    void destroy_match_list();
-    void destory_matcher();
+        int* pattern_length, MatchedPatterns*& pattern_list);
+    void destroy_match_list(MatchedPatterns*& pattern_list);
     int validate_reply(const uint8_t* data, uint16_t size);
-    int analyze_user(AppIdSession&, const snort::Packet*, uint16_t size);
+    int analyze_user(AppIdSession&, const snort::Packet*, uint16_t size,
+        MatchedPatterns*& pattern_list);
     int reference_pointer(const char* start_ptr, const char** resp_endptr, int* start_index,
-        uint16_t data_size, uint8_t* user_name_len, unsigned size);
+        uint16_t data_size, uint8_t* user_name_len, unsigned size, MatchedPatterns*& pattern_list);
 
     snort::SearchTool* matcher = nullptr;
-    MatchedPatterns* patternList = nullptr;
 };
 #endif