]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
appid: two way ssl cache data
authorDaniil Kolomiiets <dkolomii@cisco.com>
Mon, 6 Apr 2026 15:08:01 +0000 (18:08 +0300)
committerGitHub <noreply@github.com>
Mon, 6 Apr 2026 15:08:01 +0000 (11:08 -0400)
Co-authored-by: Daniil Kolomiiets <dkolomii>
src/network_inspectors/appid/service_plugins/service_ssl.cc

index bf687f00025872e06234fbea626f8c412092bf2e..ead06a9a7ca739e33af6ac7122fd8be97361b3ff 100644 (file)
@@ -66,9 +66,10 @@ public:
     SSLV3ServerCertData server_cert = {};
     int in_certs = 0;         // Currently collecting certificates?
     int certs_curr_len = 0;   // Current amount of collected certificate data.
-    uint8_t* cached_data = nullptr;
-    uint16_t cached_len = 0;
-    bool cached_client_data = false;
+    uint8_t* cached_client_data = nullptr;
+    uint16_t cached_client_data_len = 0;
+    uint8_t* cached_server_data = nullptr;
+    uint16_t cached_server_data_len = 0;
 };
 
 #pragma pack(1)
@@ -192,7 +193,8 @@ ServiceSSLData::~ServiceSSLData()
 {
     client_hello.clear();
     server_cert.clear();
-    ssl_cache_free(cached_data, cached_len);
+    ssl_cache_free(cached_client_data, cached_client_data_len);
+    ssl_cache_free(cached_server_data, cached_server_data_len);
 }
 
 static ParseResult parse_client_initiation(const uint8_t* data, uint16_t size, ServiceSSLData* ss)
@@ -216,14 +218,15 @@ static ParseResult parse_client_initiation(const uint8_t* data, uint16_t size, S
     return parse_client_hello_data(data, size, &ss->client_hello);
 }
 
-static void save_ssl_cache(ServiceSSLData* ss, uint16_t size, const uint8_t* data)
+static void save_ssl_cached_data(uint8_t*& cached_data, uint16_t& cached_data_len,
+    uint16_t size, const uint8_t* data)
 {
-    if(size == 0)
+    if (size == 0)
         return;
 
-    ss->cached_data = (uint8_t*)snort_calloc(size, sizeof(uint8_t));
-    memcpy(ss->cached_data, data, size);
-    ss->cached_len = size;
+    cached_data = (uint8_t*)snort_calloc(size, sizeof(uint8_t));
+    memcpy(cached_data, data, size);
+    cached_data_len = size;
 }
 
 int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
@@ -254,43 +257,45 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
          and args.dir == APP_ID_FROM_INITIATOR
          and !(args.asd.scan_flags & SCAN_CERTVIZ_ENABLED_FLAG))
     {
-        if (ss->cached_data)
+        if (ss->cached_client_data)
         {
-            reallocated_data = (uint8_t*)snort_calloc(ss->cached_len + size, sizeof(uint8_t));
+            reallocated_data = (uint8_t*)snort_calloc(ss->cached_client_data_len + size, sizeof(uint8_t));
             if (reallocated_data == nullptr)
                 goto inprocess;
             memcpy(reallocated_data, args.data, args.size);
-            memcpy(reallocated_data + args.size, ss->cached_data, ss->cached_len);
-            size = ss->cached_len + args.size;
-            ssl_cache_free(ss->cached_data, ss->cached_len);
+            memcpy(reallocated_data + args.size, ss->cached_client_data, ss->cached_client_data_len);
+            size = ss->cached_client_data_len + args.size;
+            ssl_cache_free(ss->cached_client_data, ss->cached_client_data_len);
             data = reallocated_data;
         }
         else
         {
-            save_ssl_cache(ss, size, data);
-            ss->cached_client_data = true;
+            save_ssl_cached_data(ss->cached_client_data, ss->cached_client_data_len, size, data);
             goto inprocess;
         }
     }
-
-    if (ss->cached_data)
+    else
     {
-        if ( (ss->cached_client_data and (args.dir == APP_ID_FROM_INITIATOR)) or (!ss->cached_client_data and (args.dir == APP_ID_FROM_RESPONDER)) )
+        uint8_t*& cached_data = (args.dir == APP_ID_FROM_INITIATOR) ? ss->cached_client_data : ss->cached_server_data;
+        uint16_t& cached_data_len = (args.dir == APP_ID_FROM_INITIATOR) ? ss->cached_client_data_len : ss->cached_server_data_len;
+
+        if (cached_data)
         {
-            reallocated_data = (uint8_t*)snort_calloc(ss->cached_len + size, sizeof(uint8_t));
+            reallocated_data = (uint8_t*)snort_calloc(cached_data_len + size, sizeof(uint8_t));
             if (reallocated_data == nullptr)
                 goto inprocess;
-            memcpy(reallocated_data, ss->cached_data, ss->cached_len);
-            memcpy(reallocated_data + ss->cached_len, args.data, args.size);
-            size = ss->cached_len + args.size;
-            ssl_cache_free(ss->cached_data, ss->cached_len);
+            memcpy(reallocated_data, cached_data, cached_data_len);
+            memcpy(reallocated_data + cached_data_len, args.data, args.size);
+            size = cached_data_len + args.size;
+            ssl_cache_free(cached_data, cached_data_len);
             data = reallocated_data;
         }
     }
+    
     /* Start off with a Client Hello from client to server. */
-    if (ss->state == SSL_STATE_INITIATE)
+    if (ss->state == SSL_STATE_INITIATE || (reallocated_data && args.dir == APP_ID_FROM_INITIATOR))
     {
-        ss->state = SSL_STATE_CONNECTION;
+        ss->state = reallocated_data ? ss->state : SSL_STATE_CONNECTION;
 
         if (!(args.asd.scan_flags & SCAN_CERTVIZ_ENABLED_FLAG) and
             args.dir == APP_ID_FROM_INITIATOR)
@@ -298,9 +303,7 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
             auto parse_status = parse_client_initiation(data, size, ss);
             if (parse_status == ParseResult::FRAGMENTED_PACKET)
             {
-                save_ssl_cache(ss, size, data);
-                ss->cached_client_data = true;
-                ss->state = SSL_STATE_INITIATE;
+                save_ssl_cached_data(ss->cached_client_data, ss->cached_client_data_len, size, data);
                 goto inprocess;
             }
             else if (parse_status == ParseResult::FAILURE)
@@ -396,8 +399,7 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
                 {
                     if (size < sizeof(ServiceSSLV3Hdr))
                     {
-                        save_ssl_cache(ss, size, data);
-                        ss->cached_client_data = false;
+                        save_ssl_cached_data(ss->cached_server_data, ss->cached_server_data_len, size, data);
                         goto inprocess;
                     }
 
@@ -422,8 +424,7 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
 
                 if (size < offsetof(ServiceSSLV3Record, version))
                 {
-                    save_ssl_cache(ss, size, data);
-                    ss->cached_client_data = false;
+                    save_ssl_cached_data(ss->cached_server_data, ss->cached_server_data_len, size, data);
                     goto inprocess;
                 }
 
@@ -572,30 +573,29 @@ success:
     args.asd.set_session_flags(APPID_SESSION_SSL_SESSION);
     if (!args.asd.tsession)
             args.asd.tsession = new TlsSession();
-    if (ss->client_hello.host_name || ss->server_cert.common_name || ss->server_cert.org_unit)
-    {
-        /* TLS Host */
-        if (ss->client_hello.host_name)
-        {
-            args.asd.tsession->set_tls_sni(ss->client_hello.host_name, 0);
-            args.asd.scan_flags |= SCAN_SSL_HOST_FLAG;
-        }
 
-        /* TLS Common Name */
-        if (ss->server_cert.common_name)
-        {
-            args.asd.tsession->set_tls_cname(ss->server_cert.common_name, 0);
-            args.asd.scan_flags |= SCAN_SSL_CERTIFICATE_FLAG;
-        }
-        /* TLS Org Unit */
-        if (ss->server_cert.org_unit)
-        {
-            args.asd.tsession->set_tls_org_unit(ss->server_cert.org_unit, 0);
-            args.asd.scan_flags |= SCAN_SSL_ORG_UNIT_FLAG;
-        }   
+    /* TLS Host */
+    if (ss->client_hello.host_name)
+    {
+        args.asd.tsession->set_tls_sni(ss->client_hello.host_name, 0);
+        args.asd.scan_flags |= SCAN_SSL_HOST_FLAG;
+    }
 
-        ss->client_hello.host_name = ss->server_cert.common_name = ss->server_cert.org_unit = nullptr;
+    /* TLS Common Name */
+    if (ss->server_cert.common_name)
+    {
+        args.asd.tsession->set_tls_cname(ss->server_cert.common_name, 0);
+        args.asd.scan_flags |= SCAN_SSL_CERTIFICATE_FLAG;
     }
+    /* TLS Org Unit */
+    if (ss->server_cert.org_unit)
+    {
+        args.asd.tsession->set_tls_org_unit(ss->server_cert.org_unit, 0);
+        args.asd.scan_flags |= SCAN_SSL_ORG_UNIT_FLAG;
+    }   
+
+    ss->client_hello.host_name = ss->server_cert.common_name = ss->server_cert.org_unit = nullptr;
+    
     args.asd.tsession->set_tls_handshake_done();
     return add_service(args.change_bits, args.asd, args.pkt, args.dir,
         getSslServiceAppId(args.pkt->ptrs.sp));