]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #1742 in SNORT/snort3 from ~CLJUDGE/snort3:set_payload_unknown_if_...
authorShravan Rangarajuvenkata (shrarang) <shrarang@cisco.com>
Fri, 11 Oct 2019 18:58:19 +0000 (14:58 -0400)
committerShravan Rangarajuvenkata (shrarang) <shrarang@cisco.com>
Fri, 11 Oct 2019 18:58:19 +0000 (14:58 -0400)
Squashed commit of the following:

commit f06c11626ed3bc09d801b4b589d4c6b9ed51f00a
Author: cljudge <cljudge@cisco.com>
Date:   Thu Sep 12 03:13:54 2019 -0400

    appid: for ssl sessions, set payload id to unknown after ssl handshake is done if the payload id was not not found

src/network_inspectors/appid/appid_api.cc
src/network_inspectors/appid/appid_session.cc
src/network_inspectors/appid/appid_session.h
src/network_inspectors/appid/service_plugins/service_ssl.cc

index ea43658f4af04f4b9eb8256cb3ece25cc4561182..67b3206d4ed9b3b75cda7dde7821e99cb4de5d87 100644 (file)
@@ -61,25 +61,25 @@ const char* AppIdApi::get_application_name(const Flow& flow, bool from_client)
     const char* app_name = nullptr;
     AppId appid = APP_ID_NONE;
     AppIdSession* asd = get_appid_session(flow);
-    if ( asd )
+    if (asd)
     {
         appid = asd->pick_payload_app_id();
-        if (  !appid )
+        if (appid <= APP_ID_NONE)
             appid = asd->pick_misc_app_id();
-        if (  !appid and from_client)
+        if (!appid and from_client)
         {
             appid = asd->pick_client_app_id();
-            if ( !appid)
+            if (!appid)
                 appid = asd->pick_service_app_id();
         }
-        else if ( !appid )
+        else if (!appid)
         {
             appid = asd->pick_service_app_id();
-            if ( !appid)
+            if (!appid)
                 appid = asd->pick_client_app_id();
         }
     }
-    if (appid  > APP_ID_NONE && appid < SF_APPID_MAX)
+    if (appid > APP_ID_NONE && appid < SF_APPID_MAX)
         app_name = AppInfoManager::get_instance().get_app_name(appid);
 
     return app_name;
@@ -100,14 +100,14 @@ uint32_t AppIdApi::produce_ha_state(const Flow& flow, uint8_t* buf)
     assert(buf);
     AppIdSessionHA* appHA = (AppIdSessionHA*)buf;
     AppIdSession* asd = get_appid_session(flow);
-    if ( asd and ( asd->common.flow_type == APPID_FLOW_TYPE_NORMAL ) )
+    if (asd and (asd->common.flow_type == APPID_FLOW_TYPE_NORMAL))
     {
         appHA->flags = APPID_HA_FLAGS_APP;
-        if ( asd->is_tp_appid_available() )
+        if (asd->is_tp_appid_available())
             appHA->flags |= APPID_HA_FLAGS_TP_DONE;
-        if ( asd->is_service_detected() )
+        if (asd->is_service_detected())
             appHA->flags |= APPID_HA_FLAGS_SVC_DONE;
-        if ( asd->get_session_flags(APPID_SESSION_HTTP_SESSION) )
+        if (asd->get_session_flags(APPID_SESSION_HTTP_SESSION))
             appHA->flags |= APPID_HA_FLAGS_HTTP;
         appHA->appId[0] = asd->get_tp_app_id();
         appHA->appId[1] = asd->service.get_id();
@@ -142,11 +142,11 @@ uint32_t AppIdApi::consume_ha_state(Flow& flow, const uint8_t* buf, uint8_t, IpP
                 asd = new AppIdSession(proto, ip, port, *inspector);
                 flow.set_flow_data(asd);
                 asd->service.set_id(appHA->appId[1]);
-                if ( asd->service.get_id() == APP_ID_FTP_CONTROL )
+                if (asd->service.get_id() == APP_ID_FTP_CONTROL)
                 {
                     asd->set_session_flags(APPID_SESSION_CLIENT_DETECTED |
                             APPID_SESSION_NOT_A_SERVICE | APPID_SESSION_SERVICE_DETECTED);
-                    if ( !ServiceDiscovery::add_ftp_service_state(*asd) )
+                    if (!ServiceDiscovery::add_ftp_service_state(*asd))
                         asd->set_session_flags(APPID_SESSION_CONTINUE);
 
                     asd->service_disco_state = APPID_DISCO_STATE_STATEFUL;
@@ -162,12 +162,12 @@ uint32_t AppIdApi::consume_ha_state(Flow& flow, const uint8_t* buf, uint8_t, IpP
             }
         }
 
-        if ( !asd )
+        if (!asd)
         {
             return sizeof(*appHA);
         }
 
-        if( (appHA->flags & APPID_HA_FLAGS_TP_DONE) && asd->tpsession )
+        if((appHA->flags & APPID_HA_FLAGS_TP_DONE) && asd->tpsession)
         {
 #ifdef ENABLE_APPID_THIRD_PARTY
             asd->tpsession->set_state(TP_STATE_TERMINATED);
index 6a86904be183c8be955256678c9934968c63ea1a..367c01458814d25d3bde4c1d5d3fd1eec72e1def 100644 (file)
@@ -79,7 +79,8 @@ AppIdSession* AppIdSession::allocate_session(const Packet* p, IpProtocol proto,
 
     const SfIp* ip = (direction == APP_ID_FROM_INITIATOR)
         ? p->ptrs.ip_api.get_src() : p->ptrs.ip_api.get_dst();
-    if ( ( proto == IpProtocol::TCP || proto == IpProtocol::UDP ) && ( p->ptrs.sp != p->ptrs.dp ) )
+    if ((proto == IpProtocol::TCP || proto == IpProtocol::UDP) &&
+        (p->ptrs.sp != p->ptrs.dp))
         port = (direction == APP_ID_FROM_INITIATOR) ? p->ptrs.sp : p->ptrs.dp;
 
     AppIdSession* asd = new AppIdSession(proto, ip, port, *inspector);
@@ -111,18 +112,19 @@ AppIdSession::AppIdSession(IpProtocol proto, const SfIp* ip, uint16_t port,
 
 AppIdSession::~AppIdSession()
 {
-    if ( !in_expected_cache )
+    if (!in_expected_cache)
     {
-        if ( config->mod_config->stats_logging_enabled )
+        if (config->mod_config->stats_logging_enabled)
             AppIdStatistics::get_stats_manager()->update(*this);
 
         // fail any service detection that is in process for this flow
-        if (!get_session_flags(APPID_SESSION_SERVICE_DETECTED | APPID_SESSION_UDP_REVERSED |
-            APPID_SESSION_MID | APPID_SESSION_OOO) and flow)
+        if (!get_session_flags(APPID_SESSION_SERVICE_DETECTED |
+            APPID_SESSION_UDP_REVERSED | APPID_SESSION_MID |
+            APPID_SESSION_OOO) and flow)
         {
             ServiceDiscoveryState* sds =
                 AppIdServiceState::get(&service_ip, protocol, service_port, is_decrypted());
-            if ( sds )
+            if (sds)
             {
                 if (flow->server_ip.fast_eq6(service_ip))
                     sds->set_service_id_failed(*this, &flow->client_ip,
@@ -187,8 +189,8 @@ AppIdSession* AppIdSession::create_future_session(const Packet* ctrlPkt, const S
     AppIdSession* asd = new AppIdSession(proto, cliIp, 0, *inspector);
     asd->common.policyId = asd->config->appIdPolicyId;
 
-    if ( Stream::set_snort_protocol_id_expected(ctrlPkt, type, proto, cliIp, cliPort, srvIp,
-        srvPort, snort_protocol_id, asd) )
+    if (Stream::set_snort_protocol_id_expected(ctrlPkt, type, proto, cliIp,
+        cliPort, srvIp, srvPort, snort_protocol_id, asd))
     {
         if (appidDebug->is_active())
         {
@@ -222,17 +224,17 @@ void AppIdSession::reinit_session_data(AppidChangeBits& change_bits)
     misc_app_id = APP_ID_NONE;
 
     //data
-    if ( is_service_over_ssl(tp_app_id) )
+    if (is_service_over_ssl(tp_app_id))
     {
         payload.reset();
         referred_payload_app_id = tp_payload_app_id = APP_ID_NONE;
         clear_session_flags(APPID_SESSION_CONTINUE);
-        if ( hsession )
+        if (hsession)
             hsession->set_field(MISC_URL_FID, nullptr, change_bits);
     }
 
     //service
-    if ( !get_session_flags(APPID_SESSION_STICKY_SERVICE) )
+    if (!get_session_flags(APPID_SESSION_STICKY_SERVICE))
     {
         service.reset();
         tp_app_id = APP_ID_NONE;
@@ -259,8 +261,9 @@ void AppIdSession::reinit_session_data(AppidChangeBits& change_bits)
     resp_tpPackets = 0;
 
     scan_flags &= ~SCAN_HTTP_HOST_URL_FLAG;
-    clear_session_flags(APPID_SESSION_SERVICE_DETECTED |APPID_SESSION_CLIENT_DETECTED |
-        APPID_SESSION_SSL_SESSION|APPID_SESSION_HTTP_SESSION | APPID_SESSION_APP_REINSPECT);
+    clear_session_flags(APPID_SESSION_SERVICE_DETECTED |
+        APPID_SESSION_CLIENT_DETECTED | APPID_SESSION_SSL_SESSION |
+        APPID_SESSION_HTTP_SESSION | APPID_SESSION_APP_REINSPECT);
 }
 
 void AppIdSession::sync_with_snort_protocol_id(AppId newAppId, Packet* p)
@@ -299,19 +302,21 @@ void AppIdSession::sync_with_snort_protocol_id(AppId newAppId, Packet* p)
         }
 
         AppInfoTableEntry* entry = app_info_mgr->get_app_info_entry(newAppId);
-        if ( entry )
+        if (entry)
         {
             SnortProtocolId tmp_snort_protocol_id = entry->snort_protocol_id;
             // A particular APP_ID_xxx may not be assigned a service_snort_key value
             // in the rna_app.yaml file entry; so ignore the snort_protocol_id ==
             // UNKNOWN_PROTOCOL_ID case.
-            if ( tmp_snort_protocol_id == UNKNOWN_PROTOCOL_ID && (newAppId == APP_ID_HTTP2))
+            if (tmp_snort_protocol_id == UNKNOWN_PROTOCOL_ID &&
+                (newAppId == APP_ID_HTTP2))
                 tmp_snort_protocol_id = snortId_for_http2;
 
-            if ( tmp_snort_protocol_id != snort_protocol_id )
+            if (tmp_snort_protocol_id != snort_protocol_id)
             {
                 snort_protocol_id = tmp_snort_protocol_id;
-                if (appidDebug->is_active() && tmp_snort_protocol_id == snortId_for_http2)
+                if (appidDebug->is_active() &&
+                    tmp_snort_protocol_id == snortId_for_http2)
                     LogMessage("AppIdDbg %s Telling Snort that it's HTTP/2\n",
                         appidDebug->get_debug_session());
 
@@ -343,8 +348,7 @@ void AppIdSession::check_app_detection_restart(AppidChangeBits& change_bits)
         encrypted.referred_id = pick_referred_payload_app_id();
         reinit_session_data(change_bits);
         if (appidDebug->is_active())
-            LogMessage("AppIdDbg %s SSL decryption is available, restarting app detection\n",
-                appidDebug->get_debug_session());
+            LogMessage("AppIdDbg %s SSL decryption is available, restarting app detection\n", appidDebug->get_debug_session());
 
         // APPID_SESSION_ENCRYPTED is set upon receiving a command which upgrades the session to
         // SSL. Next packet after the command will have encrypted traffic.  In the case of a
@@ -360,9 +364,10 @@ void AppIdSession::update_encrypted_app_id(AppId service_id)
     switch (service_id)
     {
     case APP_ID_HTTP:
-        if (misc_app_id == APP_ID_NSIIOPS || misc_app_id == APP_ID_DDM_SSL
-            || misc_app_id == APP_ID_MSFT_GC_SSL
-            || misc_app_id == APP_ID_SF_APPLIANCE_MGMT)
+        if (misc_app_id == APP_ID_NSIIOPS ||
+            misc_app_id == APP_ID_DDM_SSL ||
+            misc_app_id == APP_ID_MSFT_GC_SSL ||
+            misc_app_id == APP_ID_SF_APPLIANCE_MGMT)
         {
             break;
         }
@@ -446,6 +451,13 @@ void AppIdSession::examine_ssl_metadata(Packet* p, AppidChangeBits& change_bits)
         }
         tsession->set_tls_org_unit(nullptr, 0);
     }
+    if (tsession->get_tls_handshake_done() and
+        payload.get_id() == APP_ID_NONE)
+    {
+        if (appidDebug->is_active())
+            LogMessage("AppIdDbg %s End of SSL/TLS handshake detected with no payloadAppId, so setting to unknown\n", appidDebug->get_debug_session());
+        payload.set_id(APP_ID_UNKNOWN);
+    }
 }
 
 void AppIdSession::examine_rtmp_metadata(AppidChangeBits& change_bits)
@@ -456,20 +468,19 @@ void AppIdSession::examine_rtmp_metadata(AppidChangeBits& change_bits)
     AppId referred_payload_id = APP_ID_NONE;
     char* version = nullptr;
 
-    if ( !hsession )
+    if (!hsession)
         hsession = new AppIdHttpSession(*this);
 
-    if ( const char* url = hsession->get_cfield(MISC_URL_FID) )
+    if (const char* url = hsession->get_cfield(MISC_URL_FID))
     {
         HttpPatternMatchers* http_matchers = HttpPatternMatchers::get_instance();
         const char* referer = hsession->get_cfield(REQ_REFERER_FID);
-        if ( ( ( http_matchers->get_appid_from_url(nullptr, url, &version,
-            referer, &client_id, &service_id,
-            &payload_id, &referred_payload_id, true) )
-            ||
-            ( http_matchers->get_appid_from_url(nullptr, url, &version,
-            referer, &client_id, &service_id,
-            &payload_id, &referred_payload_id, false) ) ) )
+        if (((http_matchers->get_appid_from_url(nullptr, url, &version,
+            referer, &client_id, &service_id, &payload_id,
+            &referred_payload_id, true)) ||
+            (http_matchers->get_appid_from_url(nullptr, url, &version,
+            referer, &client_id, &service_id, &payload_id,
+            &referred_payload_id, false))))
         {
             /* do not overwrite a previously-set client or service */
             if (client.get_id() <= APP_ID_NONE)
@@ -486,14 +497,14 @@ void AppIdSession::examine_rtmp_metadata(AppidChangeBits& change_bits)
 
 void AppIdSession::set_client_appid_data(AppId id, AppidChangeBits& change_bits, char* version)
 {
-    if ( id <= APP_ID_NONE || id == APP_ID_HTTP )
+    if (id <= APP_ID_NONE || id == APP_ID_HTTP)
         return;
 
     AppId cur_id = client.get_id();
-    if ( id != cur_id )
+    if (id != cur_id)
     {
-        if ( cur_id )
-            if ( app_info_mgr->get_priority(cur_id) > app_info_mgr->get_priority(id) )
+        if (cur_id)
+            if (app_info_mgr->get_priority(cur_id) > app_info_mgr->get_priority(id))
                 return;
 
         client.set_id(id);
@@ -516,10 +527,10 @@ void AppIdSession::set_referred_payload_app_id_data(AppId id, AppidChangeBits& c
 
 void AppIdSession::set_payload_appid_data(AppId id, AppidChangeBits& change_bits, char* version)
 {
-    if ( id <= APP_ID_NONE )
+    if (id <= APP_ID_NONE)
         return;
 
-    if ( app_info_mgr->get_priority(payload.get_id()) > app_info_mgr->get_priority(id) )
+    if (app_info_mgr->get_priority(payload.get_id()) > app_info_mgr->get_priority(id))
         return;
     payload.set_id(id);
     payload.set_version(version, change_bits);
@@ -543,7 +554,7 @@ void AppIdSession::set_service_appid_data(AppId id, AppidChangeBits& change_bits
 
 void AppIdSession::free_tls_session_data()
 {
-    if ( tsession )
+    if (tsession)
     {
         tsession->free_data();
         snort_free(tsession);
@@ -560,7 +571,7 @@ void AppIdSession::delete_session_data()
     snort_free(netbios_domain);
 
     AppIdServiceSubtype* rna_ss = subtype;
-    while ( rna_ss )
+    while (rna_ss)
     {
         subtype = rna_ss->next;
         snort_free(const_cast<char*>(rna_ss->service));
@@ -578,7 +589,7 @@ void AppIdSession::delete_session_data()
 int AppIdSession::add_flow_data(void* data, unsigned id, AppIdFreeFCN fcn)
 {
     AppIdFlowDataIter it = flow_data.find(id);
-    if ( it != flow_data.end() )
+    if (it != flow_data.end())
         return -1;
 
     AppIdFlowData* fd = new AppIdFlowData(data, id, fcn);
@@ -589,7 +600,7 @@ int AppIdSession::add_flow_data(void* data, unsigned id, AppIdFreeFCN fcn)
 void* AppIdSession::get_flow_data(unsigned id)
 {
     AppIdFlowDataIter it = flow_data.find(id);
-    if ( it != flow_data.end() )
+    if (it != flow_data.end())
         return it->second->fd_data;
     else
         return nullptr;
@@ -600,7 +611,7 @@ void* AppIdSession::remove_flow_data(unsigned id)
     void* data = nullptr;
 
     AppIdFlowDataIter it = flow_data.find(id);
-    if ( it != flow_data.end() )
+    if (it != flow_data.end())
     {
         data = it->second->fd_data;
         delete it->second;
@@ -612,7 +623,9 @@ void* AppIdSession::remove_flow_data(unsigned id)
 
 void AppIdSession::free_flow_data()
 {
-    for ( AppIdFlowDataIter it = flow_data.cbegin(); it != flow_data.cend(); ++it )
+    for (AppIdFlowDataIter it = flow_data.cbegin();
+         it != flow_data.cend();
+         ++it)
         delete it->second;
 
     flow_data.clear();
@@ -621,7 +634,7 @@ void AppIdSession::free_flow_data()
 void AppIdSession::free_flow_data_by_id(unsigned id)
 {
     AppIdFlowDataIter it = flow_data.find(id);
-    if ( it != flow_data.end() )
+    if (it != flow_data.end())
     {
         delete it->second;
         flow_data.erase(it);
@@ -630,8 +643,8 @@ void AppIdSession::free_flow_data_by_id(unsigned id)
 
 void AppIdSession::free_flow_data_by_mask(unsigned mask)
 {
-    for ( AppIdFlowDataIter it = flow_data.cbegin(); it != flow_data.cend(); )
-        if ( !mask || ( it->second->fd_id & mask ) )
+    for (AppIdFlowDataIter it = flow_data.cbegin(); it != flow_data.cend();)
+        if (!mask || (it->second->fd_id & mask))
         {
             delete it->second;
             it = flow_data.erase(it);
@@ -664,8 +677,8 @@ void AppIdSession::stop_rna_service_inspection(Packet* p, AppidSessionDirection
 
     service_disco_state = APPID_DISCO_STATE_FINISHED;
 
-    if ( payload.get_id() == APP_ID_NONE and
-        ( is_tp_appid_available() or get_session_flags(APPID_SESSION_NO_TPI) ) )
+    if (payload.get_id() == APP_ID_NONE and
+        (is_tp_appid_available() or get_session_flags(APPID_SESSION_NO_TPI)))
         payload.set_id(APP_ID_UNKNOWN);
 
     set_session_flags(APPID_SESSION_SERVICE_DETECTED);
@@ -676,10 +689,10 @@ AppId AppIdSession::pick_service_app_id()
 {
     AppId rval = APP_ID_NONE;
 
-    if ( common.flow_type != APPID_FLOW_TYPE_NORMAL )
+    if (common.flow_type != APPID_FLOW_TYPE_NORMAL)
         return APP_ID_NONE;
 
-    if ( is_service_detected() )
+    if (is_service_detected())
     {
         bool deferred = service.get_deferred() || tp_app_id_deferred;
 
@@ -715,7 +728,7 @@ AppId AppIdSession::pick_service_app_id()
 
 AppId AppIdSession::pick_only_service_app_id()
 {
-    if ( common.flow_type != APPID_FLOW_TYPE_NORMAL )
+    if (common.flow_type != APPID_FLOW_TYPE_NORMAL)
         return APP_ID_NONE;
 
     bool deferred = service.get_deferred() || tp_app_id_deferred;
@@ -736,7 +749,7 @@ AppId AppIdSession::pick_only_service_app_id()
 
 AppId AppIdSession::pick_misc_app_id()
 {
-    if ( common.flow_type != APPID_FLOW_TYPE_NORMAL )
+    if (common.flow_type != APPID_FLOW_TYPE_NORMAL)
         return APP_ID_NONE;
     if (misc_app_id > APP_ID_NONE)
         return misc_app_id;
@@ -745,7 +758,7 @@ AppId AppIdSession::pick_misc_app_id()
 
 AppId AppIdSession::pick_client_app_id()
 {
-    if ( common.flow_type != APPID_FLOW_TYPE_NORMAL )
+    if (common.flow_type != APPID_FLOW_TYPE_NORMAL)
         return APP_ID_NONE;
     if (client.get_id() > APP_ID_NONE)
         return client.get_id();
@@ -754,23 +767,34 @@ AppId AppIdSession::pick_client_app_id()
 
 AppId AppIdSession::pick_payload_app_id()
 {
-    if ( common.flow_type != APPID_FLOW_TYPE_NORMAL )
+    if (common.flow_type != APPID_FLOW_TYPE_NORMAL)
         return APP_ID_NONE;
 
     if (tp_payload_app_id_deferred)
         return tp_payload_app_id;
-    else if (payload.get_id() > APP_ID_NONE)
+    
+    if (payload.get_id() > APP_ID_NONE)
         return payload.get_id();
-    else if (tp_payload_app_id > APP_ID_NONE)
+
+    if (tp_payload_app_id > APP_ID_NONE)
         return tp_payload_app_id;
-    return encrypted.payload_id;
+
+    if (encrypted.payload_id > APP_ID_NONE)
+        return encrypted.payload_id;
+
+    /* APP_ID_UNKNOWN is valid only for HTTP type services */
+    if (payload.get_id() == APP_ID_UNKNOWN and
+        is_svc_http_type(service.get_id()))
+        return APP_ID_UNKNOWN;
+
+    return APP_ID_NONE;
 }
 
 AppId AppIdSession::pick_referred_payload_app_id()
 {
-    if ( common.flow_type != APPID_FLOW_TYPE_NORMAL )
+    if (common.flow_type != APPID_FLOW_TYPE_NORMAL)
         return APP_ID_NONE;
-    if ( referred_payload_app_id > APP_ID_NONE )
+    if (referred_payload_app_id > APP_ID_NONE)
         return referred_payload_app_id;
     return encrypted.referred_id;
 }
@@ -860,7 +884,7 @@ void AppIdSession::reset_session_data()
 
 bool AppIdSession::is_payload_appid_set()
 {
-    return ( payload.get_id() || tp_payload_app_id );
+    return (payload.get_id() || tp_payload_app_id);
 }
 
 void AppIdSession::clear_http_flags()
@@ -884,14 +908,14 @@ void AppIdSession::clear_http_data()
 
 AppIdHttpSession* AppIdSession::get_http_session()
 {
-    if ( !hsession )
+    if (!hsession)
         hsession = new AppIdHttpSession(*this);
     return hsession;
 }
 
 AppIdDnsSession* AppIdSession::get_dns_session()
 {
-    if ( !dsession )
+    if (!dsession)
         dsession = new AppIdDnsSession();
     return dsession;
 }
@@ -899,14 +923,14 @@ AppIdDnsSession* AppIdSession::get_dns_session()
 bool AppIdSession::is_tp_appid_done() const
 {
 #ifdef ENABLE_APPID_THIRD_PARTY
-    if ( TPLibHandler::have_tp() )
+    if (TPLibHandler::have_tp())
     {
         if (!tpsession)
             return false;
 
         unsigned state = tpsession->get_state();
-        return (state  == TP_STATE_CLASSIFIED || state == TP_STATE_TERMINATED
-               || state == TP_STATE_HA);
+        return (state == TP_STATE_CLASSIFIED || state == TP_STATE_TERMINATED ||
+            state == TP_STATE_HA);
     }
 #endif
 
@@ -928,15 +952,15 @@ bool AppIdSession::is_tp_processing_done() const
 bool AppIdSession::is_tp_appid_available() const
 {
 #ifdef ENABLE_APPID_THIRD_PARTY
-    if ( TPLibHandler::have_tp() )
+    if (TPLibHandler::have_tp())
     {
         if (!tpsession)
             return false;
 
         unsigned state = tpsession->get_state();
 
-        return (state == TP_STATE_CLASSIFIED || state == TP_STATE_TERMINATED
-               || state == TP_STATE_MONITORING);
+        return (state == TP_STATE_CLASSIFIED || state == TP_STATE_TERMINATED ||
+            state == TP_STATE_MONITORING);
     }
 #endif
 
index ed87858feada2c261d7f3e946fd7bf2347806fdd..19e4bb451c2e090e1e3b7840fcf85baece30577f 100644 (file)
@@ -101,7 +101,7 @@ public:
 
     ~AppIdFlowData()
     {
-        if ( fd_data && fd_free )
+        if (fd_data && fd_free)
             fd_free(fd_data);
     }
 
@@ -135,6 +135,8 @@ struct TlsSession
 
     char* get_tls_org_unit() { return tls_org_unit; }
 
+    bool get_tls_handshake_done() { return tls_handshake_done; }
+
     // Duplicate only if len > 0, otherwise simply set (i.e., own the argument)
     void set_tls_host(const char* new_tls_host, uint32_t len, AppidChangeBits& change_bits)
     {
@@ -165,6 +167,8 @@ struct TlsSession
             const_cast<char*>(new_tls_org_unit);
     }
 
+    void set_tls_handshake_done() { tls_handshake_done = true; }
+
     void free_data()
     {
         if (tls_host)
@@ -174,12 +178,14 @@ struct TlsSession
         if (tls_org_unit)
             snort_free(tls_org_unit);
         tls_host = tls_cname = tls_org_unit = nullptr;
+        tls_handshake_done = false;
     }
 
 private:
     char* tls_host = nullptr;
     char* tls_cname = nullptr;
     char* tls_org_unit = nullptr;
+    bool tls_handshake_done = false;
 };
 
 class AppIdSession : public snort::FlowData
@@ -335,14 +341,16 @@ public:
     void set_tp_payload_app_id(snort::Packet& p, AppidSessionDirection dir, AppId app_id, AppidChangeBits& change_bits);
 
     inline void set_tp_app_id(AppId app_id) {
-        if(tp_app_id != app_id) {
+        if (tp_app_id != app_id)
+        {
             tp_app_id = app_id;
             tp_app_id_deferred = app_info_mgr->get_app_info_flags(tp_app_id, APPINFO_FLAG_DEFER);
         }
     }
 
     inline void set_tp_payload_app_id(AppId app_id) {
-        if(tp_payload_app_id != app_id) {
+        if (tp_payload_app_id != app_id)
+        {
             tp_payload_app_id = app_id;
             tp_payload_app_id_deferred = app_info_mgr->get_app_info_flags(tp_payload_app_id, APPINFO_FLAG_DEFER_PAYLOAD);
         }
@@ -373,5 +381,24 @@ private:
     AppId tp_payload_app_id = APP_ID_NONE;
 };
 
+static inline bool is_svc_http_type(AppId serviceId)
+{
+    switch(serviceId)
+    {
+        case APP_ID_HTTP:
+        case APP_ID_HTTPS:
+        case APP_ID_FTPS:
+        case APP_ID_IMAPS:
+        case APP_ID_IRCS:
+        case APP_ID_LDAPS:
+        case APP_ID_NNTPS:
+        case APP_ID_POP3S:
+        case APP_ID_SMTPS:
+        case APP_ID_SSHELL:
+        case APP_ID_SSL:
+            return true;
+    }
+    return false;
+}
 #endif
 
index 32b142d55e9d2365a6d0ed87d7d522d535bf28c8..e3f2c5f2a9821a1fb51273407f813bf04cba567a 100644 (file)
@@ -81,10 +81,9 @@ struct MatchedSSLPatterns
 
 enum SSLState
 {
-    SSL_STATE_INITIATE,      /* Client initiates. */
-    SSL_STATE_CONNECTION,    /* Server responds... */
+    SSL_STATE_INITIATE,    // Client initiates.
+    SSL_STATE_CONNECTION,  // Server responds...
     SSL_STATE_HEADER,
-    SSL_STATE_DONE
 };
 
 struct ServiceSSLData
@@ -97,10 +96,10 @@ struct ServiceSSLData
     char* host_name;
     int host_name_strlen;
     /* While collecting certificates: */
-    unsigned certs_len;     /* (Total) length of certificate(s). */
-    uint8_t* certs_data;    /* Certificate(s) data (each proceeded by length (3 bytes)). */
-    int in_certs;           /* Currently collecting certificates? */
-    int certs_curr_len;     /* Current amount of collected certificate data. */
+    unsigned certs_len;   // (Total) length of certificate(s).
+    uint8_t* certs_data;  // Certificate(s) data (each proceeded by length (3 bytes)).
+    int in_certs;         // Currently collecting certificates?
+    int certs_curr_len;   // Current amount of collected certificate data.
     /* Data collected from certificates afterwards: */
     char* common_name;
     int common_name_strlen;
@@ -121,14 +120,16 @@ struct ServiceSSLCertificate
 
 #pragma pack(1)
 
-struct ServiceSSLV3Hdr    /* Actually a TLS Record. */
+/* Usually referred to as a TLS Record. */
+struct ServiceSSLV3Hdr
 {
     uint8_t type;
     uint16_t version;
     uint16_t len;
 };
 
-struct ServiceSSLV3Record    /* Actually a Handshake. */
+/* Usually referred to as a TLS Handshake. */
+struct ServiceSSLV3Record
 {
     uint8_t type;
     uint8_t length_msb;
@@ -141,12 +142,13 @@ struct ServiceSSLV3Record    /* Actually a Handshake. */
     } random;
 };
 
-struct ServiceSSLV3CertsRecord    /* Actually a Certificate(s) Handshake. */
+/* Usually referred to as a Certificate Handshake. */
+struct ServiceSSLV3CertsRecord
 {
     uint8_t type;
     uint8_t length_msb;
     uint16_t length;
-    uint8_t certs_len[3];    /* 3-byte length, network byte order. */
+    uint8_t certs_len[3];  // 3-byte length, network byte order.
     /* Certificate(s) follow.
      * For each:
      *  - Length: 3 bytes
@@ -243,7 +245,7 @@ static int ssl_detector_create_matcher(SearchTool** matcher, DetectorSSLCertPatt
 
     patternIndex = &size;
 
-    /* Add patterns from Lua API */
+    /* Add patterns from Lua API. */
     for (element = list; element; element = element->next)
     {
         (*matcher)->add(element->dpattern->pattern,
@@ -330,8 +332,8 @@ SslServiceDetector::SslServiceDetector(ServiceDiscovery* sd)
     handler->register_detector(name, this, proto);
 }
 
-
-static void ssl_free(void* ss)    /* AppIdFreeFCN */
+/* AppIdFreeFCN */
+static void ssl_free(void* ss)
 {
     ServiceSSLData* ss_tmp = (ServiceSSLData*)ss;
     snort_free(ss_tmp->certs_data);
@@ -419,7 +421,7 @@ static void parse_client_initiation(const uint8_t* data, uint16_t size, ServiceS
     if (size < length)
         return;
 
-    // We need at least type (2 bytes) and length (2 bytes) fields in the extension
+    /* We need at least type (2 bytes) and length (2 bytes) in the extension. */
     while (length >= 4)
     {
         const ServiceSSLV3ExtensionServerName* ext = (const ServiceSSLV3ExtensionServerName*)data;
@@ -436,7 +438,7 @@ static void parse_client_initiation(const uint8_t* data, uint16_t size, ServiceS
             const uint8_t* str = data
                 + offsetof(ServiceSSLV3ExtensionServerName, string_length)
                 + sizeof(ext->string_length);
-            ss->host_name = (char*)snort_alloc(len + 1);     /* Plus nullptr term. */
+            ss->host_name = (char*)snort_alloc(len + 1);  //Plus nullptr term.
             memcpy(ss->host_name, str, len);
             ss->host_name[len] = '\0';
             ss->host_name_strlen = len;
@@ -477,8 +479,9 @@ static bool parse_certificates(ServiceSSLData* ss)
                 success = false;
                 break;
             }
+            /* d2i_X509() increments the data ptr for us. */
             X509* cert = d2i_X509(nullptr, (const unsigned char**)&data, cert_len);
-            len -= cert_len;    /* Above call increments data pointer already. */
+            len -= cert_len;
             if (!cert)
             {
                 success = false;
@@ -528,20 +531,21 @@ static bool parse_certificates(ServiceSSLData* ss)
             }
         }
 
-        if ( success )
+        if (success)
         {
             char* common_name = nullptr;
             if (common_name_tot_len)
             {
-                common_name_tot_len += num_certs;    /* Space between each and terminator at end.
-                                                        */
+                /* Add a space for each and the terminator at the end. */
+                common_name_tot_len += num_certs;
                 common_name = (char*)snort_calloc(common_name_tot_len);
             }
 
             char* org_name = nullptr;
             if (org_name_tot_len)
             {
-                org_name_tot_len += num_certs;    /* Space between each and terminator at end. */
+                /* Add a space for each and the terminator at the end. */
+                org_name_tot_len += num_certs;
                 org_name = (char*)snort_calloc(org_name_tot_len);
             }
 
@@ -583,9 +587,9 @@ static bool parse_certificates(ServiceSSLData* ss)
                 *org_name_ptr     = '\0';
             }
             ss->common_name        = common_name;
-            ss->common_name_strlen = common_name_tot_len - 1;    /* Minus terminator. */
+            ss->common_name_strlen = common_name_tot_len - 1;  // Minus terminator.
             ss->org_name           = org_name;
-            ss->org_name_strlen    = org_name_tot_len - 1;       /* Minus terminator. */
+            ss->org_name_strlen    = org_name_tot_len - 1;  // Minus terminator.
         }
 
         while (certs_head)
@@ -597,7 +601,7 @@ static bool parse_certificates(ServiceSSLData* ss)
             snort_free(certs_curr);
         }
 
-        /* No longer need entire certificates.  We have what we came for. */
+        /* No longer need entire certificates. We have what we came for. */
         snort_free(ss->certs_data);
         ss->certs_data = nullptr;
         ss->certs_len  = 0;
@@ -648,34 +652,34 @@ int SslServiceDetector::validate(AppIdDiscoveryArgs& args)
     switch (ss->state)
     {
     case SSL_STATE_CONNECTION:
-        ss->state = SSL_STATE_DONE;
         pct = (const ServiceSSLPCTHdr*)data;
         hdr2 = (const ServiceSSLV2Hdr*)data;
         hdr3 = (const ServiceSSLV3Hdr*)data;
+
+        /* SSL PCT header? */
         if (size >= sizeof(ServiceSSLPCTHdr) && pct->len >= 0x80 &&
             pct->type == PCT_SERVER_HELLO && ntohs(pct->version) == 0x8001)
         {
             goto success;
         }
-        if (size >= sizeof(ServiceSSLV2Hdr) && hdr2->len >= 0x80 &&
-            hdr2->type == SSL2_SERVER_HELLO && !(hdr2->cert & 0xFE))
+
+        /* SSL v2 header? */
+        if (size >= sizeof(ServiceSSLV2Hdr) &&
+            hdr2->len >= 0x80 &&
+            hdr2->type == SSL2_SERVER_HELLO &&
+            !(hdr2->cert & 0xFE))
         {
-            switch (ntohs(hdr2->version))
+            uint16_t h2v = ntohs(hdr2->version);
+            if ((h2v == 0x0002 || h2v == 0x0300 ||
+                h2v == 0x0301 || h2v == 0x0303) &&
+                !(hdr2->cipher_len % 3))
             {
-            case 0x0002:
-            case 0x0300:
-            case 0x0301:
-            case 0x0303:
-                break;
-            default:
-                goto not_v2;
+                goto success;
             }
-            if (hdr2->cipher_len % 3)
-                goto not_v2;
-
-            goto success;
-not_v2:     ;
         }
+
+        /* Its probably an SSLv3, TLS 1.2, or TLS 1.3 header.
+           First record must be a handshake (type 22). */
         if (size < sizeof(ServiceSSLV3Hdr) ||
             hdr3->type != SSL_HANDSHAKE ||
             (ntohs(hdr3->version) != 0x0300 &&
@@ -701,8 +705,6 @@ not_v2:     ;
         ss->tot_length = ntohs(hdr3->len);
         ss->length = ntohs(rec->length) +
             offsetof(ServiceSSLV3Record, version);
-        if (size == ss->length)
-            goto success;                        /* Just a Server Hello. */
         if (ss->tot_length < ss->length)
             goto fail;
         ss->tot_length -= ss->length;
@@ -714,32 +716,36 @@ not_v2:     ;
         ss->pos = 0;
     /* fall through */
     case SSL_STATE_HEADER:
-        ss->state = SSL_STATE_DONE;
         while (size > 0)
         {
             if (!ss->pos)
             {
                 /* Need to move onto (and past) next header (i.e., record) if
-                 * previous was completely consumed. */
+                   previous was completely consumed. */
                 if (ss->tot_length == 0)
                 {
                     hdr3 = (const ServiceSSLV3Hdr*)data;
                     ver = ntohs(hdr3->version);
                     if (size < sizeof(ServiceSSLV3Hdr) ||
                         (hdr3->type != SSL_HANDSHAKE &&
-                        hdr3->type != SSL_CHANGE_CIPHER ) ||
+                         hdr3->type != SSL_CHANGE_CIPHER &&
+                         hdr3->type != SSL_APPLICATION_DATA) ||
                         (ver != 0x0300 &&
-                        ver != 0x0301 &&
-                        ver != 0x0302 &&
-                        ver != 0x0303))
+                         ver != 0x0301 &&
+                         ver != 0x0302 &&
+                         ver != 0x0303))
                     {
                         goto fail;
                     }
-                    if (hdr3->type == SSL_CHANGE_CIPHER)
-                        goto success;
                     data += sizeof(ServiceSSLV3Hdr);
                     size -= sizeof(ServiceSSLV3Hdr);
                     ss->tot_length = ntohs(hdr3->len);
+
+                    if (hdr3->type == SSL_CHANGE_CIPHER ||
+                        hdr3->type == SSL_APPLICATION_DATA)
+                    {
+                        goto success;
+                    }
                 }
 
                 rec = (const ServiceSSLV3Record*)data;
@@ -761,7 +767,7 @@ not_v2:     ;
                         {
                             /* Will have to get more next time around. */
                             ss->in_certs = 1;
-                            // Skip over header to data
+                            /* Skip over header to data */
                             ss->certs_curr_len = size - sizeof(ServiceSSLV3CertsRecord);
                             memcpy(ss->certs_data, data + sizeof(ServiceSSLV3CertsRecord),
                                 ss->certs_curr_len);
@@ -773,7 +779,7 @@ not_v2:     ;
                             ss->certs_curr_len = ss->certs_len;
                             memcpy(ss->certs_data, data + sizeof(ServiceSSLV3CertsRecord),
                                 ss->certs_curr_len);
-                            goto success;    /* We got everything we need. */
+                            break;
                         }
                     }
                 /* fall through */
@@ -781,8 +787,6 @@ not_v2:     ;
                 case SSL_SERVER_CERT_REQ:
                     ss->length = ntohs(rec->length) +
                         offsetof(ServiceSSLV3Record, version);
-                    if (size == ss->length)
-                        goto success;
                     if (ss->tot_length < ss->length)
                         goto fail;
                     ss->tot_length -= ss->length;
@@ -797,7 +801,6 @@ not_v2:     ;
                         size -= ss->length;
                         ss->pos = 0;
                     }
-                    ss->state = SSL_STATE_HEADER;
                     break;
                 case SSL_SERVER_HELLO_DONE:
                     if (rec->length)
@@ -817,18 +820,18 @@ not_v2:     ;
                     if (size < (ss->certs_len - ss->certs_curr_len))
                     {
                         /* Will have to get more next time around. */
-                        memcpy(ss->certs_data + ss->certs_curr_len, data, size);
-                        ss->in_certs        = 1;
+                        memcpy(ss->certs_data + ss->certs_curr_len,
+                               data, size);
+                        ss->in_certs = 1;
                         ss->certs_curr_len += size;
                     }
                     else
                     {
                         /* Can get it all this time. */
-                        memcpy(ss->certs_data + ss->certs_curr_len, data, ss->certs_len -
-                            ss->certs_curr_len);
-                        ss->in_certs       = 0;
+                        memcpy(ss->certs_data + ss->certs_curr_len,
+                               data, ss->certs_len - ss->certs_curr_len);
+                        ss->in_certs = 0;
                         ss->certs_curr_len = ss->certs_len;
-                        goto success;    /* We got everything we need. */
                     }
                 }
 
@@ -843,7 +846,6 @@ not_v2:     ;
                     size -= ss->length - ss->pos;
                     ss->pos = 0;
                 }
-                ss->state = SSL_STATE_HEADER;
             }
         }
         break;
@@ -888,9 +890,8 @@ success:
         }
         else if (ss->common_name)
         {
-            // use common name (from server) if we didn't see host name (from client)
-            args.asd.tsession->set_tls_host(ss->common_name, ss->common_name_strlen,
-                args.change_bits);
+            /* Use common name (from server) if we didn't get host name (from client). */
+            args.asd.tsession->set_tls_host(ss->common_name, ss->common_name_strlen, args.change_bits);
             args.asd.scan_flags |= SCAN_SSL_HOST_FLAG;
         }
 
@@ -903,6 +904,7 @@ success:
             args.asd.tsession->set_tls_org_unit(ss->org_name, 0);
 
         ss->host_name = ss->common_name = ss->org_name = nullptr;
+        args.asd.tsession->set_tls_handshake_done();
     }
     return add_service(args.change_bits, args.asd, args.pkt, args.dir,
         getSslServiceAppId(args.pkt->ptrs.sp));
@@ -922,7 +924,7 @@ AppId getSslServiceAppId(short srcPort)
         return APP_ID_SMTPS;
     case 563:
         return APP_ID_NNTPS;
-    case 585:  /*Currently 585 is de-registered at IANA but old implementation may still use it. */
+    case 585:  // Currently 585 is de-registered at IANA but old implementation may still use it.
     case 993:
         return APP_ID_IMAPS;
     case 614:
@@ -991,14 +993,15 @@ static int ssl_scan_patterns(SearchTool* matcher, const uint8_t* data, size_t si
     best_match = nullptr;
     while (mp)
     {
-        //  Only patterns that match start of payload,
-        //  or patterns starting with '.'
-        //  or patterns following '.' in payload are considered a match.
+        /*  Only patterns that match start of payload,
+            or patterns starting with '.'
+            or patterns following '.' in payload are considered a match. */
         if (mp->match_start_pos == 0 ||
             *mp->mpattern->pattern == '.' ||
             data[mp->match_start_pos-1] == '.')
         {
-            if (!best_match || mp->mpattern->pattern_size > best_match->pattern_size)
+            if (!best_match ||
+                mp->mpattern->pattern_size > best_match->pattern_size)
             {
                 best_match = mp->mpattern;
             }
@@ -1031,14 +1034,14 @@ static int ssl_scan_patterns(SearchTool* matcher, const uint8_t* data, size_t si
 
 int ssl_scan_hostname(const uint8_t* hostname, size_t size, AppId& client_id, AppId& payload_id)
 {
-    return ssl_scan_patterns(service_ssl_config.ssl_host_matcher, hostname, size, client_id,
-        payload_id);
+    return ssl_scan_patterns(service_ssl_config.ssl_host_matcher,
+                             hostname, size, client_id, payload_id);
 }
 
 int ssl_scan_cname(const uint8_t* common_name, size_t size, AppId& client_id, AppId& payload_id)
 {
-    return ssl_scan_patterns(service_ssl_config.ssl_cname_matcher, common_name, size, client_id,
-        payload_id);
+    return ssl_scan_patterns(service_ssl_config.ssl_cname_matcher,
+                             common_name, size, client_id, payload_id);
 }
 
 void service_ssl_clean()
@@ -1077,16 +1080,14 @@ static int ssl_add_pattern(DetectorSSLCertPattern** list, uint8_t* pattern_str,
 
 int ssl_add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id)
 {
-    return ssl_add_pattern(&service_ssl_config.DetectorSSLCertPatternList, pattern_str,
-        pattern_size,
-        type, app_id);
+    return ssl_add_pattern(&service_ssl_config.DetectorSSLCertPatternList,
+                           pattern_str, pattern_size, type, app_id);
 }
 
 int ssl_add_cname_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id)
 {
-    return ssl_add_pattern(&service_ssl_config.DetectorSSLCnamePatternList, pattern_str,
-        pattern_size,
-        type, app_id);
+    return ssl_add_pattern(&service_ssl_config.DetectorSSLCnamePatternList,
+                           pattern_str, pattern_size, type, app_id);
 }
 
 static void ssl_patterns_free(DetectorSSLCertPattern** list)
@@ -1120,8 +1121,8 @@ bool setSSLSquelch(Packet* p, int type, AppId appId)
     const SfIp* dip = p->ptrs.ip_api.get_dst();
     const SfIp* sip = p->ptrs.ip_api.get_src();
 
-    // FIXIT-H: Passing appId to create_future_session() is incorrect. We
-    // need to pass the snort_protocol_id associated with appId.
+    /* FIXIT-H: Passing appId to create_future_session() is incorrect. We
+       need to pass the snort_protocol_id associated with appId. */
     AppIdSession* asd = AppIdSession::create_future_session(
         p, sip, 0, dip, p->ptrs.dp, IpProtocol::TCP, appId, 0);
 
@@ -1132,16 +1133,13 @@ bool setSSLSquelch(Packet* p, int type, AppId appId)
         case 1:
             asd->payload.set_id(appId);
             break;
-
         case 2:
             asd->client.set_id(appId);
             asd->client_disco_state = APPID_DISCO_STATE_FINISHED;
             break;
-
         default:
             return false;
         }
-
         return true;
     }
     else