]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4679: DNS: Handle multi trans_IDs in single DNS-UDP flow
authorWei Wang (weiwa) <weiwa@cisco.com>
Thu, 3 Apr 2025 18:56:15 +0000 (18:56 +0000)
committerSteve Chew (stechew) <stechew@cisco.com>
Thu, 3 Apr 2025 18:56:15 +0000 (18:56 +0000)
Merge in SNORT/snort3 from ~WEIWA/snort3:weiwa-dns-udp-flow-multi-tx to master

Squashed commit of the following:

commit bd686ccda796712e9545afa72fbcce4e31e50af1
Author: Wei Wang <weiwa@cisco.com>
Date:   Thu Apr 3 22:33:06 2025 +0530

    DNS: Handle multi trans_IDs in single DNS-UDP flow

src/service_inspectors/dns/dns.cc
src/service_inspectors/dns/dns.h

index 36078aedc8939e1d3e7c4061a9c1beeabcb2678f..3bc1df31930f64b5951581c33e620b28489d4e4b 100644 (file)
@@ -77,6 +77,10 @@ DnsFlowData::~DnsFlowData()
     dnsstats.concurrent_sessions--;
 }
 
+unsigned DnsUdpFlowData::inspector_id = 0;
+
+DnsUdpFlowData::DnsUdpFlowData() : FlowData(inspector_id) {}
+
 bool DNSData::publish_response() const
 {
     return (dns_config->publish_response and state == DNS_RESP_STATE_ANS_RR);
@@ -1149,6 +1153,53 @@ StreamSplitter* Dns::get_splitter(bool c2s)
     return new DnsSplitter(c2s);
 }
 
+// Get the DNS transaction ID from a UDP packet's data field
+static inline uint16_t get_udp_trans_id(Packet* p)
+{
+    // The length of packet's data field should have already been validated
+    return (static_cast<uint16_t>(p->data[0]) << 8) | static_cast<uint16_t>(p->data[1]);
+}
+
+// Add DNS transaction ID to the UDP packet's flow data object
+static void add_to_udp_flow(Packet* p, uint16_t trans_id)
+{
+    DnsUdpFlowData* udp_flow_data = (DnsUdpFlowData*)((p->flow)->get_flow_data(DnsUdpFlowData::inspector_id));
+    if (!udp_flow_data)
+    {
+        udp_flow_data = new DnsUdpFlowData();
+        p->flow->set_flow_data(udp_flow_data);
+    }
+    udp_flow_data->trans_ids.emplace(trans_id);
+}
+
+// Check if the DNS transaction ID is found in the UDP packet's flow data object
+static bool is_in_udp_flow(Packet* p, uint16_t trans_id)
+{
+    bool found = false;
+    DnsUdpFlowData* udp_flow_data = (DnsUdpFlowData*)((p->flow)->get_flow_data(DnsUdpFlowData::inspector_id));
+    if (udp_flow_data)
+        found = udp_flow_data->trans_ids.find(trans_id) != udp_flow_data->trans_ids.end();
+    return found;
+}
+
+// Remove DNS transaction ID from the UDP packet's flow data object
+static void rm_from_udp_flow(Packet* p, uint16_t trans_id)
+{
+    DnsUdpFlowData* udp_flow_data = (DnsUdpFlowData*)((p->flow)->get_flow_data(DnsUdpFlowData::inspector_id));
+    bool should_close = true;
+    if (udp_flow_data)
+    {
+        udp_flow_data->trans_ids.erase(trans_id);
+        should_close = udp_flow_data->trans_ids.empty();
+    }
+    if (should_close)
+    {
+        // Mark the UDP flow as "closed" only when all trans_ids are matched
+        // and removed by DNS-reply packets, or if the flow data object is not found
+        p->flow->session_state |= STREAM_STATE_CLOSED;
+    }
+}
+
 static void snort_dns(Packet* p, const DnsConfig* dns_config)
 {
     // cppcheck-suppress unreadVariable
@@ -1191,27 +1242,44 @@ static void snort_dns(Packet* p, const DnsConfig* dns_config)
     dnsSessionData->dns_config = dns_config;
     if ( from_server )
     {
-        bool needNextPacket = false;
-        ParseDNSResponseMessage(p, dnsSessionData, needNextPacket);
-
-        if (!dnsSessionData->valid_dns(dnsSessionData->hdr))
+        uint16_t trans_id = 0;
+        // Always parse the response packet for TCP flows
+        bool should_parse_response = true;
+        if (p->is_udp())
         {
-            dnsSessionData->flags |= DNS_FLAG_NOT_DNS;
-            return;
+            // If this is a DNS-over-UDP flow then parse the response packet and publish events
+            // only when the response packet's DNS transaction-ID is found in the flow data object
+            trans_id = get_udp_trans_id(p);
+            should_parse_response = is_in_udp_flow(p, trans_id);
         }
 
-        if (!needNextPacket and dnsSessionData->has_events())
-            DataBus::publish(Dns::get_pub_id(), DnsEventIds::DNS_RESPONSE_DATA, dnsSessionData->dns_events);
+        if (should_parse_response)
+        {
+            bool needNextPacket = false;
+            ParseDNSResponseMessage(p, dnsSessionData, needNextPacket);
+            trans_id = dnsSessionData->hdr.id;
 
-        DnsResponseEvent dns_response_event(*dnsSessionData);
-        DataBus::publish(Dns::get_pub_id(), DnsEventIds::DNS_RESPONSE, dns_response_event, p->flow);
+            if (!dnsSessionData->valid_dns(dnsSessionData->hdr))
+            {
+                dnsSessionData->flags |= DNS_FLAG_NOT_DNS;
+                return;
+            }
+
+            if (!needNextPacket and dnsSessionData->has_events())
+                DataBus::publish(Dns::get_pub_id(), DnsEventIds::DNS_RESPONSE_DATA, dnsSessionData->dns_events);
+
+            DnsResponseEvent dns_response_event(*dnsSessionData);
+            DataBus::publish(Dns::get_pub_id(), DnsEventIds::DNS_RESPONSE, dns_response_event, p->flow);
+        }
 
-        if (p->type() == PktType::UDP)
-            p->flow->session_state |= STREAM_STATE_CLOSED;
+        if (p->is_udp())
+            rm_from_udp_flow(p, trans_id);
     }
     else
     {
         dnsstats.requests++;
+        if (p->is_udp())
+            add_to_udp_flow(p, get_udp_trans_id(p));
     }
 }
 
@@ -1228,6 +1296,7 @@ static void mod_dtor(Module* m)
 static void dns_init()
 {
     DnsFlowData::init();
+    DnsUdpFlowData::init();
 }
 
 static Inspector* dns_ctor(Module* m)
index 9b09961bf2cd2c8a1d231e802d73a9ffe7c680a6..6bf2fba6027171550732c781ebffa780b930daab 100644 (file)
@@ -22,6 +22,8 @@
 #ifndef DNS_H
 #define DNS_H
 
+#include <set>
+
 #include "flow/flow.h"
 
 #include "protocols/packet.h"
@@ -233,6 +235,7 @@ private:
     uint16_t type = 0;
 };
 
+// Flow data class for DNS over TCP
 class DnsFlowData : public snort::FlowData
 {
 public:
@@ -247,5 +250,19 @@ public:
     DNSData session;
 };
 
+// Flow data class for DNS over UDP
+class DnsUdpFlowData : public snort::FlowData
+{
+public:
+    DnsUdpFlowData();
+
+    static void init()
+    { inspector_id = snort::FlowData::create_flow_data_id(); }
+
+public:
+    static unsigned inspector_id;
+    std::set<uint16_t> trans_ids;
+};
+
 #endif