]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #4950: appid: add multi-stream support for DNS
authorShibin K V (shikv) <shikv@cisco.com>
Wed, 29 Oct 2025 06:13:24 +0000 (06:13 +0000)
committerShanmugam S (shanms) <shanms@cisco.com>
Wed, 29 Oct 2025 06:13:24 +0000 (06:13 +0000)
Merge in SNORT/snort3 from ~SHIKV/snort3:doh_multi_stream to master

Squashed commit of the following:

commit e46e9809c787162b84bdd9147a27cde496cd8714
Author: shibin k v <shikv@cisco.com>
Date:   Tue Oct 21 04:00:46 2025 -0500

    appid: add multi-stream support for DNS

src/flow/stream_flow.h
src/network_inspectors/appid/appid_session.cc
src/network_inspectors/appid/appid_session_api.cc
src/network_inspectors/appid/appid_session_api.h
src/network_inspectors/appid/detector_plugins/detector_dns.cc
src/network_inspectors/appid/detector_plugins/test/detector_dns_test.cc

index 9775e95bd0e207c65fcaaf69b6309563dc0a1a44..65af57f21b28a38192585c38a69c3003a3ee73b2 100644 (file)
@@ -33,11 +33,13 @@ class FlowData;
 class SO_PUBLIC StreamFlowIntf
 {
 public:
-    virtual FlowData* get_stream_flow_data(const Flow*) = 0;
-    virtual void set_stream_flow_data(Flow*, FlowData*) = 0;
+    virtual FlowData* get_stream_flow_data(const Flow*)
+    { return nullptr; }
+    virtual void set_stream_flow_data(Flow*, FlowData*) { }
     virtual void get_stream_id(const Flow*, int64_t& stream_id) = 0;
-    virtual void* get_hi_msg_section(const Flow*) = 0;
-    virtual void set_hi_msg_section(Flow*, void* section) = 0;
+    virtual void* get_hi_msg_section(const Flow*)
+    { return nullptr; }
+    virtual void set_hi_msg_section(Flow*, void*) { }
     virtual AppId get_appid_from_stream(const Flow*) { return APP_ID_NONE; }
     // Stream based flows should override this interface to return parent flow
     // when child flow is passed as input
index 2d0ed9f0dbf8717a0e58e948244d5934c61bbdc0..0d5697eabde7b561cff43728b98c671b847796e6 100644 (file)
@@ -28,6 +28,7 @@
 #include <cstring>
 
 #include "flow/flow_stash.h"
+#include "flow/stream_flow.h"
 #include "main/snort_config.h"
 #include "managers/inspector_manager.h"
 #include "profiler/profiler.h"
@@ -1195,15 +1196,45 @@ void AppIdSession::delete_all_http_sessions()
 
 AppIdDnsSession* AppIdSession::create_dns_session()
 {
-    if (api.dsession)
-        delete api.dsession;
-    api.dsession = new AppIdDnsSession();
-    return api.dsession;
+    if (flow->stream_intf)
+    {
+        int64_t stream_id;
+        flow->stream_intf->get_stream_id(flow, stream_id);
+        api.dsessions.emplace(stream_id, new AppIdDnsSession());
+        return api.dsessions[stream_id];
+    }
+    else
+    {
+        if (api.dsession)
+            delete api.dsession;
+        api.dsession = new AppIdDnsSession();
+        return api.dsession;
+    }
 }
 
 AppIdDnsSession* AppIdSession::get_dns_session() const
 {
-    return api.dsession;
+    if (flow->stream_intf)
+    {
+        int64_t stream_id;
+        flow->stream_intf->get_stream_id(flow, stream_id);
+        if (stream_id == 0xFFFFFFFF) // no stream id is processing now, pick the last processed
+        {
+            if (!api.dsessions.empty())
+                return std::prev(api.dsessions.end())->second;
+            else
+                return nullptr;
+        }
+        auto it = api.dsessions.find(stream_id);
+        if (it != api.dsessions.end())
+            return it->second;
+
+        return nullptr;
+    }
+    else if (!api.dsessions.empty()) // flow data of inspector who handles stream_intf got deleted, take the last one
+        return std::prev(api.dsessions.end())->second;
+    else
+        return api.dsession;
 }
 
 bool AppIdSession::is_tp_appid_done() const
index eaedd893a14a23664daaa1d1b4dafa17797f7382..07ea75e33bf438ee8b466a7acb8f627b54e59744 100644 (file)
@@ -26,6 +26,8 @@
 #include "appid_session_api.h"
 
 #include "flow/ha.h"
+#include "flow/stream_flow.h"
+
 #include "appid_inspector.h"
 #include "appid_peg_counts.h"
 #include "appid_session.h"
@@ -374,7 +376,28 @@ uint16_t AppIdSessionApi::get_service_port() const
 
 const AppIdDnsSession* AppIdSessionApi::get_dns_session() const
 {
-    return dsession;
+    if (asd && asd->flow->stream_intf)
+    {
+        int64_t stream_id;
+        asd->flow->stream_intf->get_stream_id(asd->flow, stream_id);
+        if (stream_id == 0xFFFFFFFF || stream_id == -1) // no stream id is processing now, pick the last processed
+        {
+            if (!dsessions.empty())
+                return std::prev(dsessions.end())->second;
+            else
+                return nullptr;
+        }
+        auto it = dsessions.find(stream_id);
+        if (it != dsessions.end())
+        {
+            return it->second;
+        }
+        return nullptr;
+    }
+    else if (!dsessions.empty()) // flow data of inspector who handles stream_intf got deleted, take the last one
+        return std::prev(dsessions.end())->second;
+    else
+        return dsession;
 }
 
 bool AppIdSessionApi::is_http_inspection_done() const
index 6444f25d2e9dc2f81753e747e55354eb34513dcc..9858d46cc4744a6f249a0316d366685213fca144 100644 (file)
@@ -22,6 +22,8 @@
 #ifndef APPID_SESSION_API_H
 #define APPID_SESSION_API_H
 
+#include <map>
+
 #include "flow/flow.h"
 #include "main/snort_types.h"
 #include "pub_sub/appid_events.h"
@@ -175,7 +177,8 @@ private:
         bool user_logged_in : 1;
     } flags = {};
     std::vector<std::unique_ptr<AppIdHttpSession>> hsessions;
-    AppIdDnsSession* dsession = nullptr;
+    std::map<int64_t, AppIdDnsSession*> dsessions; // for multi stream DNS like DoH2/DoH3/DoQ
+    AppIdDnsSession* dsession = nullptr; // for DNS over TCP/UDP
     snort::SfIp initiator_ip;
     ServiceAppDescriptor service;
     char* tls_host = nullptr;
@@ -212,6 +215,9 @@ private:
         netbios_domain = nullptr;
         snort_free(tls_sni);
         tls_sni = nullptr;
+        for (auto& pair : dsessions)
+            delete pair.second;
+        dsessions.clear();
         delete dsession;
         dsession = nullptr;
     }
index 9e6b56ecf2c6d21a077ffd9d7f6d25c5b9e4e8ec..ed107ca00b2d28630193791c906bfbac01646332 100644 (file)
@@ -25,6 +25,8 @@
 
 #include "detector_dns.h"
 
+#include "flow/stream_flow.h"
+
 #include "appid_config.h"
 #include "appid_dns_session.h"
 #include "app_info_table.h"
@@ -178,6 +180,36 @@ ServiceDNSData::~ServiceDNSData()
     free_dns_cache();
 }
 
+class ServiceDNSDoQData : public AppIdFlowData
+{
+public:
+    ServiceDNSDoQData() = default;
+    ~ServiceDNSDoQData() override;
+
+    std::unordered_map<int64_t, ServiceDNSData*> service_dns_data;
+    ServiceDNSData* get_stream_dns_data(int64_t stream_id)
+    {
+        auto it = service_dns_data.find(stream_id);
+        if (it != service_dns_data.end())
+            return it->second;
+        else
+            return nullptr;
+    }
+    ServiceDNSData* create_stream_dns_data(int64_t stream_id)
+    {
+        ServiceDNSData* dns_data = new ServiceDNSData();
+        service_dns_data[stream_id] = dns_data;
+        return dns_data;
+    }
+};
+
+ServiceDNSDoQData::~ServiceDNSDoQData()
+{
+    for (auto& it: service_dns_data)
+        delete it.second;
+    service_dns_data.clear();
+}
+
 DnsTcpServiceDetector::DnsTcpServiceDetector(ServiceDiscovery* sd)
 {
     handler = sd;
@@ -685,7 +717,8 @@ int DnsTcpServiceDetector::validate_doq(AppIdDiscoveryArgs& args)
     uint8_t* reallocated_data = nullptr;
     const uint8_t* data = args.data;
     uint16_t size = args.size;
-    ServiceDNSData* dd = static_cast<ServiceDNSData*>(data_get(args.asd));
+    ServiceDNSDoQData* dd_doq;
+    ServiceDNSData* dd;
     {
         if (!args.size)
             goto inprocess;
@@ -697,13 +730,29 @@ int DnsTcpServiceDetector::validate_doq(AppIdDiscoveryArgs& args)
             else
                 goto fail;
         }
-
-        if (!dd)
+        if (args.asd.flow->stream_intf)
         {
-            dd = new ServiceDNSData;
-            data_add(args.asd, dd);
+            dd_doq = static_cast<ServiceDNSDoQData*>(data_get(args.asd));
+            if (!dd_doq)
+            {
+                dd_doq = new ServiceDNSDoQData;
+                data_add(args.asd, dd_doq);
+            }
+            int64_t stream_id;
+            args.asd.flow->stream_intf->get_stream_id(args.asd.flow, stream_id);
+            dd = dd_doq->get_stream_dns_data(stream_id);
+            if (!dd)
+                dd = dd_doq->create_stream_dns_data(stream_id);
+        }
+        else
+        {
+            dd = static_cast<ServiceDNSData*>(data_get(args.asd));
+            if (!dd)
+            {
+                dd = new ServiceDNSData;
+                data_add(args.asd, dd);
+            }
         }
-
         if (dd->cached_data and dd->cached_len and args.dir == APP_ID_FROM_INITIATOR)
         {
             reallocated_data = static_cast<uint8_t*>(snort_calloc(dd->cached_len + args.size, sizeof(uint8_t)));
index 9c080c6e65408fce5e093c6b50764178fee48a1f..5427de3dee832cf91d456ba165b38d63793960b4 100644 (file)
 
 using namespace snort;
 
+Flow::~Flow() = default;
+class QuicStreamIntf: public snort::StreamFlowIntf
+{
+public:
+    void get_stream_id(const snort::Flow* flow, int64_t& stream_id) override
+    {
+        stream_id = 0;
+    }
+};
+static QuicStreamIntf quic_stream_intf;
+
 static ServiceDiscovery test_discovery;
 
 static AppIdDnsSession static_dns_session;
@@ -55,17 +66,27 @@ int ServiceDetector::incompatible_data(AppIdSession&, const snort::Packet*, Appi
 }
 // Stubs for AppIdInspector
 static ServiceDNSData dd;
+static bool return_null_data = false;
+ServiceDNSDoQData* dns_doq_data  = nullptr;
 AppIdConfig test_app_config;
 
 AppIdInspector::AppIdInspector(AppIdModule&) : config(&test_app_config), ctxt(test_app_config) { }
 
 void AppIdDetector::add_user(AppIdSession&, const char*, AppId, bool, AppidChangeBits&){}
 
-int AppIdDetector::data_add(AppIdSession&, AppIdFlowData*) { return 1; }
+int AppIdDetector::data_add(AppIdSession&, AppIdFlowData* dns_data)
+{
+    if (return_null_data)
+    {
+         dns_doq_data = static_cast<ServiceDNSDoQData*>(dns_data);
+         return 1;
+    }
+    return 1;
+}
 
 AppIdFlowData* AppIdDetector::data_get(const AppIdSession&)
 {
-    return &dd;
+    return return_null_data ? nullptr : &dd;
 }
 
 int ServiceDetector::fail_service(AppIdSession& asd, const Packet* pkt, AppidSessionDirection dir) { return 1; }
@@ -93,7 +114,9 @@ TEST(detector_dns_doq_tests, doq_validator_match_full_session)
     AppIdInspector test_inspector(test_module);
     dd.state = DNS_STATE_QUERY;
     AppIdSession test_asd(IpProtocol::TCP, nullptr, (uint16_t)0, test_inspector, test_odp_ctxt, (uint32_t)0, 0);
-        uint8_t dns_tcp_packet[] = {
+    Flow* flow = new Flow;
+    test_asd.flow = flow;
+    uint8_t dns_tcp_packet[] = {
         0x00, 0x1c, // TCP length (28 bytes)
         0x12, 0x34, // id
         0x01, 0x00, // flags
@@ -138,7 +161,7 @@ TEST(detector_dns_doq_tests, doq_validator_match_full_session)
 
     AppIdDiscoveryArgs response_args(dns_tcp_response, sizeof(dns_tcp_response), APP_ID_FROM_RESPONDER, test_asd, nullptr, change_bits);
     result = test_detector->validate_doq(response_args);
-
+    delete flow;
     CHECK_EQUAL(APPID_SUCCESS, result);
 }
 
@@ -149,6 +172,8 @@ TEST(detector_dns_doq_tests, doq_validator_in_process_cached)
     AppIdInspector test_inspector(test_module);
     dd.state = DNS_STATE_QUERY;
     AppIdSession test_asd(IpProtocol::TCP, nullptr, (uint16_t)0, test_inspector, test_odp_ctxt, (uint32_t)0, 0);
+    Flow* flow = new Flow;
+    test_asd.flow = flow;
     // partial data
     uint8_t dns_tcp_packet_1[] = {
         0x00, 0x1c, // TCP length (28 bytes)
@@ -173,6 +198,7 @@ TEST(detector_dns_doq_tests, doq_validator_in_process_cached)
     };
     AppIdDiscoveryArgs args2(dns_tcp_packet_2, sizeof(dns_tcp_packet_2), APP_ID_FROM_INITIATOR, test_asd, nullptr, change_bits);
     result = test_detector->validate_doq(args2);
+    delete flow;
     CHECK_EQUAL(APPID_INPROCESS, result);
 }
 
@@ -183,7 +209,10 @@ TEST(detector_dns_doq_tests, doq_validator_not_compatible)
     AppIdInspector test_inspector(test_module);
     dd.state = DNS_STATE_QUERY;
     AppIdSession test_asd(IpProtocol::TCP, nullptr, (uint16_t)0, test_inspector, test_odp_ctxt, (uint32_t)0, 0);
-        uint8_t dns_tcp_packet[] = {
+    Flow* flow = new Flow;
+    test_asd.flow = flow;
+
+    uint8_t dns_tcp_packet[] = {
         0x00, 0x01, // TCP length (1 bytes)
         0x12, 0x34, // id
         0x01, 0x00, // flags
@@ -199,6 +228,7 @@ TEST(detector_dns_doq_tests, doq_validator_not_compatible)
     AppidChangeBits change_bits;
     AppIdDiscoveryArgs args(dns_tcp_packet, sizeof(dns_tcp_packet), APP_ID_FROM_INITIATOR, test_asd, nullptr, change_bits);
     auto result = test_detector->validate_doq(args);
+    delete flow;
     CHECK_EQUAL(APPID_NOT_COMPATIBLE, result);
 
 }
@@ -210,7 +240,9 @@ TEST(detector_dns_doq_tests, doq_validator_no_match)
     AppIdInspector test_inspector(test_module);
     dd.state = DNS_STATE_QUERY;
     AppIdSession test_asd(IpProtocol::TCP, nullptr, (uint16_t)0, test_inspector, test_odp_ctxt, (uint32_t)0, 0);
-        uint8_t dns_tcp_packet[] = {
+    Flow* flow = new Flow;
+    test_asd.flow = flow;
+    uint8_t dns_tcp_packet[] = {
         0x00, 0x10, // TCP length (28 bytes)
         0x12, 0x34, // id
         0x01, 0x00, // flags
@@ -226,9 +258,44 @@ TEST(detector_dns_doq_tests, doq_validator_no_match)
     AppidChangeBits change_bits;
     AppIdDiscoveryArgs args(dns_tcp_packet, sizeof(dns_tcp_packet), APP_ID_FROM_RESPONDER, test_asd, nullptr, change_bits);
     auto result = test_detector->validate_doq(args);
+    delete flow;
     CHECK_EQUAL(APPID_NOMATCH, result);
 }
 
+TEST(detector_dns_doq_tests, doq_validator_stream_intf)
+{
+    OdpContext test_odp_ctxt(test_app_config, nullptr);
+    AppIdModule test_module;
+    AppIdInspector test_inspector(test_module);
+    dd.state = DNS_STATE_QUERY;
+    AppIdSession test_asd(IpProtocol::TCP, nullptr, (uint16_t)0, test_inspector, test_odp_ctxt, (uint32_t)0, 0);
+    Flow* flow = new Flow;
+    test_asd.flow = flow;
+    flow->stream_intf = &quic_stream_intf;
+    return_null_data = true; // Set flag to return nullptr from data_get
+    uint8_t dns_tcp_packet[] = {
+        0x00, 0x1c, // TCP length (28 bytes)
+        0x12, 0x34, // id
+        0x01, 0x00, // flags
+        0x00, 0x01, // QDCount
+        0x00, 0x00, // ANCount
+        0x00, 0x00, // NSCount
+        0x00, 0x00, // ARCount
+        0x03, 'w', 'w', 'w', 0x00, // "www"
+        0x00, 0x01, // QType A
+        0x00, 0x01  // QClass IN
+    };
+
+    AppidChangeBits change_bits;
+    AppIdDiscoveryArgs args(dns_tcp_packet, sizeof(dns_tcp_packet), APP_ID_FROM_INITIATOR, test_asd, nullptr, change_bits);
+    auto result = test_detector->validate_doq(args);
+    delete flow;
+    return_null_data = false;
+    delete dns_doq_data;
+    dns_doq_data = nullptr;
+    CHECK_EQUAL(APPID_INPROCESS, result);
+}
+
 
 
 int main(int argc, char** argv)