]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4863: appid: fix multiple mdns issues
authorBohdan Hryniv -X (bhryniv - SOFTSERVE INC at Cisco) <bhryniv@cisco.com>
Thu, 14 Aug 2025 22:06:58 +0000 (22:06 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Thu, 14 Aug 2025 22:06:58 +0000 (22:06 +0000)
Merge in SNORT/snort3 from ~BHRYNIV/snort3:multiple_mdns_fixes to master

Squashed commit of the following:

commit 3852ed0f166c5f4d69fa73912f3a6a46f91a2c96
Author: Bohdan Hryniv <bhryniv@cisco>
Date:   Wed Jul 23 10:57:12 2025 -0400

    appid: fix multiple mdns issues

src/network_inspectors/appid/service_plugins/service_mdns.cc
src/network_inspectors/appid/service_plugins/service_mdns.h

index a7b07b6079cc88373ef8e206206d96c9ce3492ab..90fd5252fec6b8d123adf2236effd9a3527bef24 100644 (file)
@@ -160,14 +160,19 @@ int MdnsServiceDetector::validate_reply(const uint8_t* data, uint16_t size)
 {
     int ret_val;
 
+    // Minimum MDNS packet size for header fields
+    // (query_val, ans_count, srv_original)
+    if (size < RECORD_OFFSET)
+        return 0;
+
     /* Check for the pattern match*/
-    if (size >= 6 && memcmp(data, MDNS_PATTERN1, sizeof(MDNS_PATTERN1)-1) == 0)
+    if (memcmp(data, MDNS_PATTERN1, sizeof(MDNS_PATTERN1)-1) == 0)
         ret_val = 1;
-    else if (size >= 6 && memcmp(data, MDNS_PATTERN2,  sizeof(MDNS_PATTERN2)-1) == 0)
+    else if (memcmp(data, MDNS_PATTERN2,  sizeof(MDNS_PATTERN2)-1) == 0)
         ret_val = 1;
-    else if (size >= 6 && memcmp(data,MDNS_PATTERN3, sizeof(MDNS_PATTERN3)-1) == 0)
+    else if (memcmp(data,MDNS_PATTERN3, sizeof(MDNS_PATTERN3)-1) == 0)
         ret_val = 1;
-    else if (size >= 4 && memcmp(data,MDNS_PATTERN4, sizeof(MDNS_PATTERN4)-1) == 0)
+    else if (memcmp(data,MDNS_PATTERN4, sizeof(MDNS_PATTERN4)-1) == 0)
         ret_val = 1;
     else
         ret_val = 0;
@@ -179,63 +184,61 @@ int MdnsServiceDetector::validate_reply(const uint8_t* data, uint16_t size)
    Output is resp_endptr, start_index and user_name_len
    Returns 0 or 1 for successful/unsuccessful hit for pattern '@'
    Returns -1 for invalid address pointer or past the data_size */
-int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char** resp_endptr,
-    int* start_index, uint16_t data_size, uint8_t* user_name_len, unsigned size,
-    MatchedPatterns*& pattern_list)
+int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char* end_pkt,
+   const char** resp_endptr, int* start_index, uint16_t data_size,
+   uint8_t* user_name_len, unsigned size, MatchedPatterns*& pattern_list)
 {
     int index = 0;
     int pattern_length = 0;
 
-    while (index< data_size &&  (start_ptr[index] == ' ' ))
+    while ((start_ptr + index) < end_pkt && (start_ptr[index] == ' '))
         index++;
 
-    if (index >= data_size)
+    if ((start_ptr + index) >= end_pkt)
         return -1;
     *start_index = index;
 
-    const char* temp_start_ptr;
-    temp_start_ptr  = start_ptr+index;
+    const char* temp_start_ptr = start_ptr + index;
+
+    int temp_index = 0;
 
-    // FIXIT-M - This code needs review to ensure it works correctly with the new semantics of the
-    //           index returned by the SearchTool find_all pattern matching function
     scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length, pattern_list);
-    /* Contains reference pointer */
-    while ((index < data_size) && !(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>
-        SHIFT_BITS_REFERENCE_PTR  != PATTERN_REFERENCE_PTR))
+
+    while ((temp_start_ptr + temp_index) < end_pkt && !(*resp_endptr) &&
+            ((uint8_t)temp_start_ptr[temp_index] >> SHIFT_BITS_REFERENCE_PTR != PATTERN_REFERENCE_PTR))
     {
-        if (temp_start_ptr[index] == PATTERN_USERNAME_1)
+        if (temp_start_ptr[temp_index] == PATTERN_USERNAME_1)
         {
-            *user_name_len = index - *start_index;
-            index++;
+            *user_name_len = temp_index;
+            temp_index++;
             break;
         }
-        index++;
-        scan_matched_patterns(start_ptr, size - data_size + index, resp_endptr, &pattern_length, pattern_list);
-    }
-    if (index >= data_size)
+        temp_index++;
+        scan_matched_patterns(start_ptr, size - data_size + index + temp_index, resp_endptr, &pattern_length, pattern_list);
+       }
+
+    if ((temp_start_ptr + temp_index) >= end_pkt)
         *user_name_len = 0;
-    else if ((uint8_t )temp_start_ptr[index]  >> SHIFT_BITS_REFERENCE_PTR == PATTERN_REFERENCE_PTR)
+    else if ((uint8_t)temp_start_ptr[temp_index] >> SHIFT_BITS_REFERENCE_PTR == PATTERN_REFERENCE_PTR)
         pattern_length = REFERENCE_PTR_LENGTH;
-    else if (!(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>SHIFT_BITS_REFERENCE_PTR !=
-        PATTERN_REFERENCE_PTR ))
+    else if (!(*resp_endptr) &&
+            ((uint8_t)temp_start_ptr[temp_index] >> SHIFT_BITS_REFERENCE_PTR != PATTERN_REFERENCE_PTR))
     {
-        while ((index < data_size) && !(*resp_endptr) && ((uint8_t )temp_start_ptr[index]  >>
-            SHIFT_BITS_REFERENCE_PTR != PATTERN_REFERENCE_PTR))
+        while ((temp_start_ptr + temp_index) < end_pkt && !(*resp_endptr) &&
+                ((uint8_t)temp_start_ptr[temp_index] >> SHIFT_BITS_REFERENCE_PTR != PATTERN_REFERENCE_PTR))
         {
-            index++;
-            scan_matched_patterns(start_ptr,  size - data_size + index, resp_endptr,
-                &pattern_length, pattern_list);
+            temp_index++;
+            scan_matched_patterns(start_ptr, size - data_size + index + temp_index, resp_endptr, &pattern_length, pattern_list);
         }
-        if (index >= data_size)
+
+        if ((temp_start_ptr + temp_index) >= end_pkt)
             *user_name_len = 0;
-        else if ((uint8_t )temp_start_ptr[index]  >> SHIFT_BITS_REFERENCE_PTR ==
-            PATTERN_REFERENCE_PTR)
+        else if ((uint8_t)temp_start_ptr[temp_index] >> SHIFT_BITS_REFERENCE_PTR == PATTERN_REFERENCE_PTR)
             pattern_length = REFERENCE_PTR_LENGTH;
     }
 
-    /* Add reference pointer bytes */
-    if ( index+ pattern_length < data_size)
-        *resp_endptr = start_ptr + index+ pattern_length;
+    if ((start_ptr + index + temp_index + pattern_length) < end_pkt)
+        *resp_endptr = start_ptr + index + temp_index + pattern_length;
     else
         return -1;
 
@@ -270,26 +273,32 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
         const char* srv_original  = (const char*)pkt->data + RECORD_OFFSET;
         pattern_list = create_match_list(srv_original, size - RECORD_OFFSET);
         const char* end_srv_original  = (const char*)pkt->data + RECORD_OFFSET + data_size;
-        for (int processed_ans = 0; processed_ans < ans_count && data_size <= size && size > 0;
+        for (int processed_ans = 0; processed_ans < ans_count && data_size <= size;
             processed_ans++ )
         {
             // Call Decode Reference pointer function if referenced value instead of direct value
             uint8_t user_name_len = 0;
-            int ret_value = reference_pointer(srv_original, &resp_endptr,  &start_index, data_size,
-                &user_name_len, size, pattern_list);
+            const char* packet_end = (const char*)pkt->data + size;
+            int ret_value = reference_pointer(srv_original, packet_end, &resp_endptr, &start_index,
+                                             data_size, &user_name_len, size, pattern_list);
             int user_index =0;
 
             if (ret_value == -1)
                 return -1;
             else if (ret_value)
             {
-                while (start_index < data_size && (!isprint(srv_original[start_index])  ||
-                    srv_original[start_index] == '"' || srv_original[start_index] =='\''))
+                while ((srv_original + start_index) < packet_end && start_index < data_size &&
+                       (!isprint(srv_original[start_index]) ||
+                        srv_original[start_index] == '"' || srv_original[start_index] =='\''))
                 {
                     start_index++;
                     user_index++;
                 }
-                user_name_len -=user_index;
+
+                if (user_index <= user_name_len)
+                    user_name_len -= user_index;
+                 else
+                    user_name_len = 0;
 
                 char user_name[MAX_LENGTH_SERVICE_NAME] = "";
                 memcpy(user_name, srv_original + start_index, user_name_len);
@@ -309,10 +318,14 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
             }
 
             // Find the  length to Jump to the next response
-            if ((resp_endptr  + NEXT_MESSAGE_OFFSET  ) < (srv_original + data_size))
+            if ((resp_endptr  + NEXT_MESSAGE_OFFSET) < packet_end)
             {
                 const uint8_t* data_len_str = (const uint8_t*)(resp_endptr+ LENGTH_OFFSET);
                 uint16_t data_len =  (short)( data_len_str[0]<< SHIFT_BITS | ( data_len_str[1] ));
+
+                if (data_len > size - (srv_original - (const char*)pkt->data))
+                    return -1;
+
                 data_size = data_size - (resp_endptr  + NEXT_MESSAGE_OFFSET + data_len -
                     srv_original);
                 /* Check if user name is available in the Domain Name field */
@@ -331,9 +344,12 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
                     {
                         user_name_len = user_original - srv_original - start_index;
                         const char* user_name_bkp = srv_original + start_index;
-                        /* Non-Printable characters in the beginning */
 
-                        while (user_index < user_name_len)
+                        if (user_name_bkp + user_name_len > packet_end)
+                            return 0;
+
+                        /* Non-Printable characters in the beginning */
+                        while (user_index < user_name_len && (user_name_bkp + user_index) < packet_end)
                         {
                             if (isprint(user_name_bkp[user_index]))
                                 break;
index 840dd699e74f9ac1a1eb0445af0a23499ee778b4..4f8c51452cba76dace46816c5283dabba992bdb1 100644 (file)
@@ -47,10 +47,9 @@ private:
     int validate_reply(const uint8_t* data, uint16_t size);
     int analyze_user(AppIdSession&, const snort::Packet*, uint16_t size,
         AppidChangeBits& change_bits, MatchedPatterns*& pattern_list);
-    int reference_pointer(const char* start_ptr, const char** resp_endptr, int* start_index,
+    int reference_pointer(const char* start_ptr, const char* end_pkt, const char** resp_endptr, int* start_index,
         uint16_t data_size, uint8_t* user_name_len, unsigned size, MatchedPatterns*& pattern_list);
 
     snort::SearchTool matcher;
 };
 #endif
-