]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4973: Mdns deviceinfo
authorUmang Sharma (umasharm) <umasharm@cisco.com>
Thu, 4 Dec 2025 14:19:54 +0000 (14:19 +0000)
committerChris Sherwin (chsherwi) <chsherwi@cisco.com>
Thu, 4 Dec 2025 14:19:54 +0000 (14:19 +0000)
Merge in SNORT/snort3 from ~UMASHARM/snort3:mdns_deviceinfo to master

Squashed commit of the following:

commit b183f83410da9d86cef10e8bae079e9bc734c933
Author: Umang Sharma <umasharm@cisco.com>
Date:   Tue Nov 4 09:16:57 2025 -0500

    appid: mDNS TXT records parsing and deviceinfo event generation

src/network_inspectors/appid/app_info_table.cc
src/network_inspectors/appid/appid_config.h
src/network_inspectors/appid/service_plugins/service_mdns.cc
src/network_inspectors/appid/service_plugins/service_mdns.h
src/pub_sub/CMakeLists.txt
src/pub_sub/deviceinfo_events.h [new file with mode: 0644]

index 15977471a9b41242367e71258325ee0a31a7e801..4030b94a3b02381829ea8e853b077a0522b957b7 100644 (file)
@@ -402,6 +402,13 @@ void AppInfoManager::load_odp_config(OdpContext& odp_ctxt, const char* path)
                     odp_ctxt.recheck_for_portservice_appid = true;
                 }
             }
+            else if (!(strcasecmp(conf_key, "mdns_deviceinfo")))
+            {
+                if (!(strcasecmp(conf_val, "enabled")))
+                {
+                    odp_ctxt.mdns_deviceinfo = true;
+                }
+            }
             else if (!(strcasecmp(conf_key, "bittorrent_aggressiveness")))
             {
                 int aggressiveness = atoi(conf_val);
index a087010673653f30db8bbdb1225550dff8664975..f655e7b0f4ecde6f5084c56bc8c5627858f81774 100644 (file)
@@ -135,6 +135,7 @@ public:
     bool dns_host_reporting = true;
     bool referred_appId_disabled = false;
     bool mdns_user_reporting = true;
+    bool mdns_deviceinfo = false;
     bool chp_userid_disabled = false;
     bool is_host_port_app_cache_runtime = false;
     bool check_host_port_app_cache = false;
index 90fd5252fec6b8d123adf2236effd9a3527bef24..22339f39026e40369a97dd2126e010551dfcfba6 100644 (file)
 #include "appid_module.h"
 #include "protocols/packet.h"
 #include "search_engines/search_tool.h"
+#include "pub_sub/deviceinfo_events.h"
+#include "appid_inspector.h"
+#include <vector>
+#include <string>
+#include <utility>
+#include <map>
+#include <set>
 
 using namespace snort;
 
@@ -54,6 +61,15 @@ using namespace snort;
 #define SHIFT_BITS_REFERENCE_PTR  6
 #define REFERENCE_PTR_LENGTH  2
 #define MAX_LENGTH_SERVICE_NAME 256
+#define DNS_COMPRESSION_PTR_SKIP  2
+#define DNS_LABEL_LENGTH_SKIP  1
+#define DNS_QUESTION_FIXED_SIZE  4
+#define DNS_RECORD_HEADER_SIZE  10
+#define TXT_RECORD_TYPE  0x0010
+#define DNS_COMPRESSION_MASK  0xC0
+#define DNS_NULL_TERMINATOR  0x00
+#define DNS_COMPRESSION_OFFSET_MASK  0x3F
+#define DNS_RDLENGTH_SIZE  2
 
 enum MDNSState
 {
@@ -121,7 +137,7 @@ void MdnsServiceDetector::do_custom_reload()
 
 int MdnsServiceDetector::validate(AppIdDiscoveryArgs& args)
 {
-    ServiceMDNSData* fd = (ServiceMDNSData*)data_get(args.asd);
+    ServiceMDNSData* fd = static_cast<ServiceMDNSData*>(data_get(args.asd));
     if (!fd)
     {
         fd = new ServiceMDNSData();
@@ -237,10 +253,38 @@ int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char* en
             pattern_length = REFERENCE_PTR_LENGTH;
     }
 
-    if ((start_ptr + index + temp_index + pattern_length) < end_pkt)
-        *resp_endptr = start_ptr + index + temp_index + pattern_length;
+    const char* name_parser = temp_start_ptr + temp_index;
+    
+    while (name_parser < end_pkt)
+    {
+        if (((unsigned char)*name_parser & DNS_COMPRESSION_MASK) == DNS_COMPRESSION_MASK)
+        {
+            name_parser += DNS_COMPRESSION_PTR_SKIP;
+            break;
+        }
+        else if (*name_parser == DNS_NULL_TERMINATOR)
+        {
+            name_parser += DNS_LABEL_LENGTH_SKIP;
+            break;
+        }
+        else
+        {
+            uint8_t label_len = (unsigned char)*name_parser;
+            if (name_parser + DNS_LABEL_LENGTH_SKIP + label_len > end_pkt)
+                return -1;
+            name_parser += DNS_LABEL_LENGTH_SKIP + label_len;
+        }
+    }
+
+    if (name_parser < end_pkt)
+    {
+        *resp_endptr = name_parser;
+    }
     else
+    {
         return -1;
+    }
+
 
     if (*user_name_len > 0)
         return 1;
@@ -248,6 +292,157 @@ int MdnsServiceDetector::reference_pointer(const char* start_ptr, const char* en
         return 0;
 }
 
+static bool is_printable_string(const std::string& str)
+{
+    return std::all_of(str.begin(), str.end(), [](unsigned char c) {
+        return std::isprint(c);
+    }) && !str.empty();
+}
+
+static std::string clean_mdns_string(const std::string& str)
+{
+    std::string clean;
+    for (char c : str)
+    {
+        if (static_cast<unsigned char>(c) < 128 && std::isprint(c))
+            clean += c;
+    }
+    return clean;
+}
+
+void MdnsServiceDetector::process_txt_record(const snort::Packet* pkt, const char* srv_original, 
+    const char* rdata_start, uint16_t data_len, const char* packet_end, 
+    std::string& protocol_type, std::string& device_name,
+    std::vector<std::pair<std::string, std::string>>& kv_pairs)
+{
+    const char* dns_name_start = srv_original;
+    const char* name_parser = dns_name_start;
+    bool first_label = true;
+    std::set<const char*> visited_ptrs;
+
+    while (name_parser < packet_end)
+    {
+        if (((unsigned char)*name_parser & DNS_COMPRESSION_MASK) == DNS_COMPRESSION_MASK)
+        {
+            if (name_parser + 1 >= packet_end)
+                break;
+            uint16_t offset = ((name_parser[0] & DNS_COMPRESSION_OFFSET_MASK) << SHIFT_BITS) | (unsigned char)name_parser[1];
+            
+            if (offset >= (packet_end - (const char*)pkt->data) || offset < RECORD_OFFSET)
+            {
+                break;
+            }
+
+            const char* compressed_ptr = (const char*)pkt->data + offset;
+            if (compressed_ptr < packet_end and compressed_ptr >= (const char*)pkt->data)
+            {
+                if (visited_ptrs.find(compressed_ptr) != visited_ptrs.end())
+                {
+                    break;
+                }
+                visited_ptrs.insert(compressed_ptr);
+                name_parser = compressed_ptr;
+                continue;
+            }
+            else
+                break;
+        }
+        else if (*name_parser == DNS_NULL_TERMINATOR)
+            break;
+        else
+        {
+            uint8_t label_len = (unsigned char)*name_parser;
+            name_parser += DNS_LABEL_LENGTH_SKIP;
+            if (name_parser + label_len > packet_end)
+                break;
+
+            std::string label(name_parser, label_len);
+            
+            if (first_label)
+            {
+                device_name = std::move(label);
+                
+                size_t at_pos = device_name.find(PATTERN_USERNAME_1);
+                if (at_pos != std::string::npos and at_pos > 0)
+                {
+                    device_name = device_name.substr(at_pos + DNS_LABEL_LENGTH_SKIP);
+                }
+
+                size_t dot_pos = device_name.find('.');
+                if (dot_pos != std::string::npos and dot_pos > 0)
+                {
+                    protocol_type = device_name.substr(dot_pos + DNS_LABEL_LENGTH_SKIP);
+                    device_name = device_name.substr(0, dot_pos);
+                }
+
+                if (!is_printable_string(device_name))
+                {
+                    device_name.clear();
+                }
+                else
+                {
+                    device_name = clean_mdns_string(device_name);
+                }
+                
+                first_label = false;
+            }
+            else
+            {
+                if (!protocol_type.empty())
+                    protocol_type += ".";
+                protocol_type += label;
+            }
+            
+            name_parser += label_len;
+        }
+    }
+
+    const uint8_t* txt_data = (const uint8_t*)rdata_start;
+    if (rdata_start + data_len > packet_end)
+    {
+        return;
+    }
+    const uint8_t* txt_end = txt_data + data_len;
+    
+    while (txt_data < txt_end)
+    {
+        uint8_t txt_len = *txt_data++;
+        
+        if (txt_len == 0 || txt_data + txt_len > txt_end)
+        {
+            break;
+        }
+        
+        std::string txt_string((const char*)txt_data, txt_len);
+        txt_data += txt_len;
+
+        if (txt_string.empty())
+            continue;
+
+        size_t equals_pos = txt_string.find('=');
+        if (equals_pos != std::string::npos and equals_pos > 0)
+        {
+            std::string key = txt_string.substr(0, equals_pos);
+            std::string value = txt_string.substr(equals_pos + 1);
+
+            if (is_printable_string(key) && (value.empty() || is_printable_string(value)))
+            {
+                key = clean_mdns_string(key);
+                value = clean_mdns_string(value);
+                kv_pairs.emplace_back(key, value);
+            }
+        }
+        else
+        {
+            if (is_printable_string(txt_string))
+            {
+                std::string clean_key = clean_mdns_string(txt_string);
+                kv_pairs.emplace_back(clean_key, "");
+            }
+        }
+    }
+}
+
 /* Input to this Function is pkt and size
    Processing: 1. Parses Multiple MDNS response packet
                2. Calls the function which scans for pattern to identify the user
@@ -264,6 +459,10 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
     int query_val_int = (short)(query_val[0]<<SHIFT_BITS  | query_val[1]);
     const char* answers = (const char*)pkt->data + ANSWER_OFFSET;
     int ans_count =  (short)(answers[0]<< SHIFT_BITS | (answers[1] ));
+    int authority_count = (short)(answers[2]<< SHIFT_BITS | (answers[3] ));
+    int additional_count = (short)(answers[4]<< SHIFT_BITS | (answers[5] ));
+    std::map<std::pair<std::string, std::string>, std::vector<std::pair<std::string, std::string>>> device_info_map;
+
 
     if ( query_val_int == 0)
     {
@@ -273,8 +472,9 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
         const char* srv_original  = (const char*)pkt->data + RECORD_OFFSET;
         pattern_list = create_match_list(srv_original, size - RECORD_OFFSET);
         const char* end_srv_original  = (const char*)pkt->data + RECORD_OFFSET + data_size;
-        for (int processed_ans = 0; processed_ans < ans_count && data_size <= size;
-            processed_ans++ )
+        int total_records = ans_count + authority_count + additional_count;
+        for (int processed_records = 0; processed_records < total_records && data_size <= size;
+            processed_records++ )
         {
             // Call Decode Reference pointer function if referenced value instead of direct value
             uint8_t user_name_len = 0;
@@ -320,12 +520,34 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
             // Find the  length to Jump to the next response
             if ((resp_endptr  + NEXT_MESSAGE_OFFSET) < packet_end)
             {
-                const uint8_t* data_len_str = (const uint8_t*)(resp_endptr+ LENGTH_OFFSET);
-                uint16_t data_len =  (short)( data_len_str[0]<< SHIFT_BITS | ( data_len_str[1] ));
+                if (((unsigned char)resp_endptr[0] & DNS_COMPRESSION_MASK) == DNS_COMPRESSION_MASK)
+                    resp_endptr += DNS_COMPRESSION_PTR_SKIP;
+                uint16_t record_type = ((uint8_t)resp_endptr[0] << SHIFT_BITS) | (uint8_t)resp_endptr[1];
+                
+                const uint8_t* rdlength_ptr = (const uint8_t*)(resp_endptr + LENGTH_OFFSET);
+                if (rdlength_ptr + DNS_RDLENGTH_SIZE > (const uint8_t*)packet_end)
+                    return -1;
+
+                uint16_t data_len = (rdlength_ptr[0] << SHIFT_BITS) | rdlength_ptr[1];
 
-                if (data_len > size - (srv_original - (const char*)pkt->data))
+                const char* rdata_start = resp_endptr + NEXT_MESSAGE_OFFSET;
+                if (rdata_start + data_len > packet_end)
                     return -1;
 
+                if (record_type == TXT_RECORD_TYPE and data_len > 0 and asd.get_odp_ctxt().mdns_deviceinfo)
+                {
+                    std::string protocol_type, device_name;
+                    std::vector<std::pair<std::string, std::string>> kv_pairs;
+                    const char* dns_name_ptr = srv_original;
+                    process_txt_record(pkt, dns_name_ptr, rdata_start, data_len, packet_end,
+                                     protocol_type, device_name, kv_pairs);
+                    if (!protocol_type.empty() || !device_name.empty())
+                    {
+                        auto device_key = std::make_pair(protocol_type, device_name);
+                        device_info_map[device_key] = std::move(kv_pairs);
+                    }
+                }
+
                 data_size = data_size - (resp_endptr  + NEXT_MESSAGE_OFFSET + data_len -
                     srv_original);
                 /* Check if user name is available in the Domain Name field */
@@ -340,7 +562,7 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
                     user_original = (const char*)memchr((const uint8_t*)srv_original, PATTERN_USERNAME_1,
                         data_len);
 
-                    if (user_original )
+                    if ( user_original )
                     {
                         user_name_len = user_original - srv_original - start_index;
                         const char* user_name_bkp = srv_original + start_index;
@@ -375,7 +597,6 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
                                 user_name_len - user_index);
                             user_name[ user_name_len - user_index ] = '\0';
                             add_user(asd, user_name, APP_ID_MDNS, true, change_bits);
-                            return 1;
                         }
                         else
                             return 0;
@@ -395,6 +616,11 @@ int MdnsServiceDetector::analyze_user(AppIdSession& asd, const Packet* pkt, uint
     else
         return 0;
 
+    if (!device_info_map.empty() and asd.get_odp_ctxt().mdns_deviceinfo)
+    {
+        DeviceInfoEvent event(pkt, device_info_map);
+        DataBus::publish(DataBus::get_id(deviceinfo_pub_key), DeviceInfoEventIds::DEVICEINFO, event);
+    }
     return 1;
 }
 
index 4f8c51452cba76dace46816c5283dabba992bdb1..a8da85d8cb41411e3429aed94e2fbbc26e2e40fd 100644 (file)
@@ -49,6 +49,9 @@ private:
         AppidChangeBits& change_bits, MatchedPatterns*& pattern_list);
     int reference_pointer(const char* start_ptr, const char* end_pkt, const char** resp_endptr, int* start_index,
         uint16_t data_size, uint8_t* user_name_len, unsigned size, MatchedPatterns*& pattern_list);
+    void process_txt_record(const snort::Packet* pkt, const char* srv_original, const char* rdata_start, 
+        uint16_t data_len, const char* packet_end, std::string& protocol_type, std::string& device_name,
+        std::vector<std::pair<std::string, std::string>>& kv_pairs);
 
     snort::SearchTool matcher;
 };
index 8301e10e02b41df22295ef63833f0b4f1a178fc1..49bcbd70020326546eb4396ba09058a576e2ff5d 100644 (file)
@@ -40,6 +40,7 @@ set (PUB_SUB_INCLUDES
     ssl_events.h
     stream_event_ids.h
     dns_payload_event.h
+    deviceinfo_events.h
 )
 
 add_library( pub_sub OBJECT
diff --git a/src/pub_sub/deviceinfo_events.h b/src/pub_sub/deviceinfo_events.h
new file mode 100644 (file)
index 0000000..6b972d3
--- /dev/null
@@ -0,0 +1,113 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2014-2025 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// deviceinfo_events.h author Umang Sharma <umasharm@cisco.com>
+
+#ifndef DEVICEINFO_EVENTS_H
+#define DEVICEINFO_EVENTS_H
+
+// The DeviceInfoEvent class is used to store device information extracted from
+// network protocols such as MDNS (multicast DNS)
+// Device information includes device names and key-value attribute pairs
+// that describe device characteristics like model, manufacturer, services, etc.
+// Subscribers can register handlers to receive and process these events for network analysis purposes.
+
+#include "framework/data_bus.h"
+#include <vector>
+#include <string>
+#include <utility>
+#include <map>
+
+namespace snort
+{
+
+// Event IDs for device information events published via DataBus
+// DEVICEINFO: Primary event type for device information extracted from network protocols
+struct DeviceInfoEventIds { enum : unsigned { DEVICEINFO, num_ids }; };
+
+const PubKey deviceinfo_pub_key { "deviceinfo", DeviceInfoEventIds::num_ids };
+
+// DataEvent that contains device identification data including protocol type, device name, and attributes
+class DeviceInfoEvent : public DataEvent
+{
+public:
+    // Composite key for unique device identification consisting of protocol type and device name
+    // The protocol type identifies the network protocol (e.g., "_airplay._tcp.local", "_http._tcp.local")
+    // The device name identifies the specific device instance (e.g., "John's iPhone", "Office Printer")
+    using DeviceKey = std::pair<std::string, std::string>;
+
+    // Collection of device attributes extracted from network protocols as key-value pairs
+    // Contains device characteristics like model, manufacturer, version, services, etc.
+    // Example: [("model", "iPhone12"), ("manufacturer", "Apple"), ("os", "iOS 15.0")]
+    using KeyValueVector = std::vector<std::pair<std::string, std::string>>;
+
+    // Maps device identifiers to their corresponding attribute collections
+    // Allows multiple devices to be tracked within a single event, each with their own attributes
+    // Key: (protocol_type, device_name), Value: vector of device attribute key-value pairs
+    using DeviceInfoMap = std::map<DeviceKey, KeyValueVector>;
+
+    // Constructor for creating an event containing multiple devices with their attributes
+    // Used when a single network packet or protocol exchange reveals information about multiple devices
+    // For example, a network scan response that contains information about several discovered devices
+    DeviceInfoEvent(const snort::Packet* p, const DeviceInfoMap& device_map)
+        : pkt(p), device_info_map(device_map) { }
+
+    // Constructor for creating an event containing a single device with its attributes
+    // Used when network protocol analysis identifies a specific device and its characteristics
+    // The device is uniquely identified by protocol type and device name combination
+    DeviceInfoEvent(const snort::Packet* p, const std::string& protocol_type,
+                   const std::string& device_name, const KeyValueVector& kv_pairs)
+        : pkt(p)
+    {
+        device_info_map[std::make_pair(protocol_type, device_name)] = kv_pairs;
+    }
+
+    const Packet* get_packet() const override
+    { return pkt; }
+
+    const DeviceInfoMap& get_device_info_map() const
+    { return device_info_map; }
+
+    // Retrieve device attributes for a specific device identified by protocol type and device name
+    // Returns nullptr if the specified device is not found in this event
+    // Used by subscribers to extract specific device information from the event
+    const KeyValueVector* get_key_value_pairs(const std::string& protocol_type,
+                                             const std::string& device_name) const
+    {
+        auto it = device_info_map.find(std::make_pair(protocol_type, device_name));
+        return (it != device_info_map.end()) ? &it->second : nullptr;
+    }
+
+    size_t get_device_count() const
+    { return device_info_map.size(); }
+
+    size_t get_total_kv_count() const
+    {
+        size_t total = 0;
+        for (const auto& entry : device_info_map)
+            total += entry.second.size();
+        return total;
+    }
+
+private:
+    const Packet* pkt;
+    DeviceInfoMap device_info_map;
+};
+
+}
+
+#endif