]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Merge pull request #1186 in SNORT/snort3 from service_disco_state to master
authorMike Stepanek (mstepane) <mstepane@cisco.com>
Thu, 12 Apr 2018 18:01:15 +0000 (14:01 -0400)
committerMike Stepanek (mstepane) <mstepane@cisco.com>
Thu, 12 Apr 2018 18:01:15 +0000 (14:01 -0400)
Squashed commit of the following:

commit 3719339c89b9ba9cfd56393da18a8895a3e6c290
Author: Masud Hasan <mashasan@cisco.com>
Date:   Sat Apr 7 12:00:47 2018 -0400

    appid: Fixing service discovery states

15 files changed:
src/network_inspectors/appid/appid_detector.cc
src/network_inspectors/appid/appid_detector.h
src/network_inspectors/appid/appid_discovery.cc
src/network_inspectors/appid/appid_session.cc
src/network_inspectors/appid/client_plugins/client_discovery.cc
src/network_inspectors/appid/host_port_app_cache.cc
src/network_inspectors/appid/service_plugins/service_detector.cc
src/network_inspectors/appid/service_plugins/service_discovery.cc
src/network_inspectors/appid/service_plugins/service_discovery.h
src/network_inspectors/appid/service_state.cc
src/network_inspectors/appid/service_state.h
src/network_inspectors/appid/test/CMakeLists.txt
src/network_inspectors/appid/test/appid_detector_test.cc
src/network_inspectors/appid/test/appid_mock_definitions.h
src/network_inspectors/appid/test/service_state_test.cc [new file with mode: 0644]

index b825d92fba42457df16e17988e1e21a941eac3bb..e06626fe6452db00a6d93dee2125e9fff812fcfc 100644 (file)
 
 #include "appid_detector.h"
 
+#include "protocols/packet.h"
+
+#include "app_info_table.h"
 #include "appid_config.h"
 #include "appid_http_session.h"
-#include "app_info_table.h"
 #include "lua_detector_api.h"
-#include "protocols/packet.h"
 
 using namespace snort;
 
@@ -97,3 +98,31 @@ void AppIdDetector::add_app(AppIdSession& asd, AppId service_id, AppId client_id
     asd.client_inferred_service_id = service_id;
     asd.client.set_id(client_id);
 }
+
+const char* AppIdDetector::get_code_string(APPID_STATUS_CODE code) const
+{
+    switch (code)
+    {
+    case APPID_SUCCESS:
+        return "success";
+    case APPID_INPROCESS:
+        return "inprocess";
+    case APPID_NEED_REASSEMBLY:
+        return "need-reassembly";
+    case APPID_NOT_COMPATIBLE:
+        return "not-compatible";
+    case APPID_INVALID_CLIENT:
+        return "invalid-client";
+    case APPID_REVERSED:
+        return "appid-reversed";
+    case APPID_NOMATCH:
+        return "no-match";
+    case APPID_ENULL:
+        return "error-null";
+    case APPID_EINVALID:
+        return "error-invalid";
+    case APPID_ENOMEM:
+        return "error-memory";
+    }
+    return "unknown code";
+}
index a12db6c1ff538ba7af35ce4bafc9c4b1e5504d1d..0bce0d1ee428a28b80419262b4eea4a8ec54e3ae 100644 (file)
 #define APPID_DETECTOR_H
 
 #include <vector>
+
+#include "flow/flow.h"
+
 #include "appid_discovery.h"
-#include "application_ids.h"
 #include "appid_session.h"
+#include "application_ids.h"
 #include "service_state.h"
-#include "flow/flow.h"
 
 class AppIdConfig;
 class LuaStateDescriptor;
@@ -85,6 +87,8 @@ public:
     const AppIdConfig* config = nullptr;
 };
 
+// These numbers are what Lua (VDB/ODP) gives us. If these numbers are ever changed,
+// we need to change get_code_string() code to avoid misinterpretations.
 enum APPID_STATUS_CODE
 {
     APPID_SUCCESS = 0,
@@ -117,6 +121,7 @@ public:
     virtual void add_payload(AppIdSession&, AppId);
     virtual void add_app(AppIdSession&, AppId, AppId, const char*);
     virtual void finalize() {}
+    const char* get_code_string(APPID_STATUS_CODE) const;
 
     const std::string& get_name() const
     { return name; }
index 6c9392679d55cd08f9729e990d9c23ab634d334a..b879be612d4c1ed6ccf5a9c8f29a2706240db930 100644 (file)
 #include "protocols/packet.h"
 #include "protocols/tcp.h"
 
+#include "app_forecast.h"
 #include "appid_config.h"
 #include "appid_debug.h"
 #include "appid_detector.h"
-#include "app_forecast.h"
 #include "appid_dns_session.h"
 #include "appid_http_session.h"
 #include "appid_inspector.h"
@@ -678,7 +678,13 @@ void AppIdDiscovery::do_application_discovery(Packet* p, AppIdInspector& inspect
     if ( !asd || asd->common.flow_type == APPID_FLOW_TYPE_TMP )
     {
         asd = AppIdSession::allocate_session(p, protocol, direction, inspector);
-        if (appidDebug->is_active())
+        if (p->flow->get_session_flags() & SSNFLAG_MIDSTREAM)
+        {
+            asd->set_session_flags(APPID_SESSION_MID);
+            if (appidDebug->is_active())
+                LogMessage("AppIdDbg %s New AppId mid-stream session\n", appidDebug->get_debug_session());
+        }
+        else if (appidDebug->is_active())
             LogMessage("AppIdDbg %s New AppId session\n", appidDebug->get_debug_session());
     }
 
@@ -708,6 +714,7 @@ void AppIdDiscovery::do_application_discovery(Packet* p, AppIdInspector& inspect
         return;
     }
 
+    // FIXIT-H - Bring APPID_SESSION_OOO related changes from snort2 ASAP for performance reason
     if (p->packet_flags & PKT_STREAM_ORDER_BAD)
         asd->set_session_flags(APPID_SESSION_OOO);
     else if ( p->is_tcp() && p->ptrs.tcph )
index b7a025003277194e5a429aee09fb73a2d932f8af..7ddec94a52c38c01507661c04e6df780ccc5491b 100644 (file)
@@ -117,8 +117,8 @@ AppIdSession::~AppIdSession()
             stats_mgr->update(*this);
 
         // fail any service detection that is in process for this flow
-        if (flow &&
-            !get_session_flags(APPID_SESSION_SERVICE_DETECTED | APPID_SESSION_UDP_REVERSED) )
+        if (!get_session_flags(APPID_SESSION_SERVICE_DETECTED | APPID_SESSION_UDP_REVERSED |
+            APPID_SESSION_MID | APPID_SESSION_OOO) and flow)
         {
             ServiceDiscoveryState* sds =
                 AppIdServiceState::get(&service_ip, protocol, service_port, is_decrypted());
index 68d0480173293b93220dd1c8ad32900b9ef6d435..0bc299c546b3c958dd131921e3480c48aa9ab990 100644 (file)
@@ -295,8 +295,9 @@ int ClientDiscovery::exec_client_detectors(AppIdSession& asd, Packet* p, int dir
         AppIdDiscoveryArgs disco_args(p->data, p->dsize, direction, asd, p);
         ret = asd.client_detector->validate(disco_args);
         if (appidDebug->is_active())
-            LogMessage("AppIdDbg %s %s client detector returned %d\n",
-                appidDebug->get_debug_session(), asd.client_detector->get_name().c_str(), ret);
+            LogMessage("AppIdDbg %s %s client detector %s (%d)\n",
+                appidDebug->get_debug_session(), asd.client_detector->get_name().c_str(),
+                asd.client_detector->get_code_string((APPID_STATUS_CODE)ret), ret);
     }
     else
     {
@@ -305,8 +306,9 @@ int ClientDiscovery::exec_client_detectors(AppIdSession& asd, Packet* p, int dir
             AppIdDiscoveryArgs disco_args(p->data, p->dsize, direction, asd, p);
             int result = kv->second->validate(disco_args);
             if (appidDebug->is_active())
-                LogMessage("AppIdDbg %s %s client detector returned %d\n",
-                    appidDebug->get_debug_session(), kv->second->get_name().c_str(), result);
+                LogMessage("AppIdDbg %s %s client candidate %s (%d)\n",
+                    appidDebug->get_debug_session(), kv->second->get_name().c_str(),
+                    kv->second->get_code_string((APPID_STATUS_CODE)result), result);
 
             if (result == APPID_SUCCESS)
             {
index da1f69dfacb0fd0158495ac04867287458051817..1ac61974449698d0ff602e435755f870542e6a8c 100644 (file)
@@ -82,7 +82,7 @@ void HostPortCache::terminate()
 {
     if (host_port_cache)
     {
-        host_port_cache->empty();
+        host_port_cache->clear();
         delete host_port_cache;
         host_port_cache = nullptr;
     }
index 0d8438789696cd148c2360ba1350a369590d0eb3..09cb9087e554cd7e7ece76e7105f89cd6908fd70 100644 (file)
 
 #include "service_detector.h"
 
-#include "appid_config.h"
+#include "log/messages.h"
+#include "protocols/packet.h"
+#include "sfip/sf_ip.h"
+
 #include "app_info_table.h"
+#include "appid_config.h"
 #include "appid_session.h"
 #include "lua_detector_api.h"
 
-#include "protocols/packet.h"
-#include "log/messages.h"
-#include "sfip/sf_ip.h"
-
 using namespace snort;
 
 static THREAD_LOCAL unsigned service_module_index = 0;
@@ -126,10 +126,7 @@ int ServiceDetector::update_service_data(AppIdSession& asd, const Packet* pkt, i
 
     asd.service_ip = *ip;
     asd.service_port = port;
-    ServiceDiscoveryState* sds = AppIdServiceState::get(ip, asd.protocol, port,
-        asd.is_decrypted());
-    if ( !sds )
-        sds = AppIdServiceState::add(ip, asd.protocol, port, asd.is_decrypted());
+    ServiceDiscoveryState* sds = AppIdServiceState::add(ip, asd.protocol, port, asd.is_decrypted());
     sds->set_service_id_valid(this);
 
     return APPID_SUCCESS;
index ede8479c2a8420aac4b773391f31dbc7c35fe0e5..04fcd36ed2028d53cbefa3f75613371eff962607 100644 (file)
@@ -346,8 +346,7 @@ void ServiceDiscovery::get_port_based_services(IpProtocol protocol, uint16_t por
  * been specified (service_detector).  Basically, this function handles going
  * through the main port/pattern search (and returning which detector to add
  * next to the list of detectors to try (even if only 1)). */
-void ServiceDiscovery::get_next_service(const Packet* p, const int dir,
-    AppIdSession& asd, ServiceDiscoveryState* sds)
+void ServiceDiscovery::get_next_service(const Packet* p, const int dir, AppIdSession& asd)
 {
     auto proto = asd.protocol;
 
@@ -365,8 +364,8 @@ void ServiceDiscovery::get_next_service(const Packet* p, const int dir,
          * first with UDP reversed services before moving onto pattern matches. */
         if (dir == APP_ID_FROM_INITIATOR)
         {
-            if ( !asd.get_session_flags(APPID_SESSION_ADDITIONAL_PACKET)
-                && (proto == IpProtocol::UDP) && !asd.tried_reverse_service )
+            if ( (proto == IpProtocol::UDP) and !asd.tried_reverse_service and
+                 !asd.get_session_flags(APPID_SESSION_ADDITIONAL_PACKET) )
             {
                 asd.tried_reverse_service = true;
                 ServiceDiscoveryState* rsds = AppIdServiceState::get(p->ptrs.ip_api.get_src(),
@@ -389,7 +388,7 @@ void ServiceDiscovery::get_next_service(const Packet* p, const int dir,
         else
         {
             match_by_pattern(asd, p, proto);
-            sds->set_state(SERVICE_ID_STATE::SEARCHING_BRUTE_FORCE);
+            asd.service_search_state = SESSION_SERVICE_SEARCH_STATE::PENDING;
             return;
         }
     }
@@ -397,10 +396,10 @@ void ServiceDiscovery::get_next_service(const Packet* p, const int dir,
 
 int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, int dir)
 {
-    const SfIp* ip = nullptr;
-    int ret = APPID_NOMATCH;
-    uint16_t port = 0;
-    bool got_incompatible_services = false;
+    ServiceDiscoveryState* sds = nullptr;
+    bool got_brute_force = false;
+    const SfIp* ip;
+    uint16_t port;
 
     /* Get packet info. */
     auto proto = asd.protocol;
@@ -421,64 +420,65 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, int dir)
             ip   = p->ptrs.ip_api.get_dst();
             port = p->ptrs.dp;
         }
+        asd.service_ip = *ip;
+        asd.service_port = port;
     }
 
-    ServiceDiscoveryState* sds = AppIdServiceState::get(ip, proto, port, asd.is_decrypted());
-    if ( !sds )
-        sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
-
     if ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::START )
     {
         asd.service_search_state = SESSION_SERVICE_SEARCH_STATE::PORT;
+        sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+        sds->set_reset_time(0);
+        SERVICE_ID_STATE sds_state = sds->get_state();
 
-        if ( sds->get_state() == SERVICE_ID_STATE::FAILED )
+        if ( sds_state == SERVICE_ID_STATE::FAILED )
         {
             if (appidDebug->is_active())
-                LogMessage("AppIdDbg %s Failed state, no service match\n", appidDebug->get_debug_session());
-            fail_service(asd, p, dir, nullptr);
+                LogMessage("AppIdDbg %s No service match, failed state\n", appidDebug->get_debug_session());
+            fail_service(asd, p, dir, nullptr, sds);
             return APPID_NOMATCH;
         }
 
         if ( !asd.service_detector )
         {
             /* If a valid service already exists in host tracker, give it a try. */
-            if ( sds->get_state() == SERVICE_ID_STATE::VALID )
+            if ( sds_state == SERVICE_ID_STATE::VALID )
                 asd.service_detector = sds->get_service();
-
-            // FIXIT-H: The following logic sets asd.service_detector to sds.service even if
-            // (state != SEARCHING_BRUTE_FORCE && state != VALID). Need to verify if this is really
-            // intended as this is diverged from Snort2 logic. Also, when the walking of brute-force
-            // list is done, we should not do port-pattern again -- which is what this implementation
-            // is doing! We should do port-pattern only if (!bruteForceDone). See Snort 2.9.11-125 logic.
-
             /* If we've gotten to brute force, give next detector a try. */
-            else if ( asd.service_candidates.empty() )
+            else if ( sds_state == SERVICE_ID_STATE::SEARCHING_BRUTE_FORCE and
+                      asd.service_candidates.empty() )
             {
                 asd.service_detector = sds->select_detector_by_brute_force(proto);
+                got_brute_force = true;
             }
         }
     }
 
+    int ret = APPID_NOMATCH;
+    bool got_incompatible_service = false;
+    bool got_fail_service = false;
     AppIdDiscoveryArgs args(p->data, p->dsize, dir, asd, p);
     /* If we already have a service to try, then try it out. */
     if ( asd.service_detector )
     {
         ret = asd.service_detector->validate(args);
         if (ret == APPID_NOT_COMPATIBLE)
-            got_incompatible_services = true;
+            got_incompatible_service = true;
+        asd.service_search_state = SESSION_SERVICE_SEARCH_STATE::PENDING;
         if (appidDebug->is_active())
-            LogMessage("AppIdDbg %s %s returned %d\n", appidDebug->get_debug_session(),
-                asd.service_detector->get_name().c_str(), ret);
+            LogMessage("AppIdDbg %s %s service detector %s (%d)\n",
+                appidDebug->get_debug_session(), asd.service_detector->get_name().c_str(),
+                asd.service_detector->get_code_string((APPID_STATUS_CODE)ret), ret);
     }
     /* Try to find detectors based on ports and patterns. */
-    else
+    else if (!got_brute_force)
     {
         /* See if we've got more detector(s) to add to the candidate list. */
         if ( ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::PORT )
             || ( ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::PATTERN )
             && (dir == APP_ID_FROM_RESPONDER ) ) )
         {
-            get_next_service(p, dir, asd, sds);
+            get_next_service(p, dir, asd);
         }
 
         /* Run all of the detectors that we currently have. */
@@ -490,52 +490,52 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, int dir)
             int result;
 
             result = service->validate(args);
-            if ( result == APPID_NOT_COMPATIBLE )
-                got_incompatible_services = true;
             if ( appidDebug->is_active() )
-                LogMessage("AppIdDbg %s %s returned %d\n",
-                    appidDebug->get_debug_session(), service->get_name().c_str(), result);
+                LogMessage("AppIdDbg %s %s service candidate %s (%d)\n",
+                    appidDebug->get_debug_session(), service->get_name().c_str(),
+                    service->get_code_string((APPID_STATUS_CODE)result), result);
 
             if ( result == APPID_SUCCESS )
             {
                 ret = APPID_SUCCESS;
                 asd.service_detector = service;
-                asd.service_candidates.empty();
+                asd.service_candidates.clear();
                 break;    /* done */
             }
-            else if (result != APPID_INPROCESS)    /* fail */
-                asd.service_candidates.erase(it);
             else
-                ++it;
-        }
-
-        /* If we tried everything and found nothing, then fail. */
-        if ( ret != APPID_SUCCESS )
-        {
-            if ( ( asd.service_candidates.empty() )
-                && ( sds->get_state() == SERVICE_ID_STATE::SEARCHING_BRUTE_FORCE ) )
             {
-                fail_service(asd, p, dir, nullptr);
-                ret = APPID_NOMATCH;
+                if ( result == APPID_NOT_COMPATIBLE )
+                    got_incompatible_service = true;
+                if (result != APPID_INPROCESS)    /* fail */
+                    it = asd.service_candidates.erase(it);
+                else
+                    ++it;
             }
         }
-    }
 
-    if ( asd.service_detector )
-    {
-        sds->set_reset_time(0);
+        /* If we tried everything and found nothing, then fail. */
+        if ( asd.service_candidates.empty() and ret != APPID_SUCCESS and
+             ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::PENDING ) )
+            got_fail_service = true;
     }
-    else if ( dir == APP_ID_FROM_RESPONDER )    // bidirectional exchange unknown service
+
+    /* Failed all candidates, or no detector identified after seeing bidirectional exchange */
+    if ( got_fail_service or ( ( ret != APPID_INPROCESS ) and
+         !asd.service_detector and ( dir == APP_ID_FROM_RESPONDER ) ) )
     {
+        if (!sds)
+            sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
         if (appidDebug->is_active())
-            LogMessage("AppIdDbg %s No service detector\n", appidDebug->get_debug_session());
-
-        fail_service(asd, p, dir, nullptr);
+            LogMessage("AppIdDbg %s No service %s\n", appidDebug->get_debug_session(),
+                got_fail_service? "candidate" : "detector");
+        got_fail_service = true;
+        fail_service(asd, p, dir, nullptr, sds);
         ret = APPID_NOMATCH;
     }
 
     /* Handle failure exception cases in states. */
-    if ( ( ret != APPID_INPROCESS ) && ( ret != APPID_SUCCESS ) )
+    if ( ( ( got_fail_service and !got_brute_force ) or got_incompatible_service ) and
+         ( ret != APPID_INPROCESS ) and ( ret != APPID_SUCCESS ) )
     {
         const SfIp* tmp_ip;
         if (dir == APP_ID_FROM_RESPONDER)
@@ -543,7 +543,10 @@ int ServiceDiscovery::identify_service(AppIdSession& asd, Packet* p, int dir)
         else
             tmp_ip = p->ptrs.ip_api.get_src();
 
-        if (got_incompatible_services)
+        if (!sds)
+            sds = AppIdServiceState::add(ip, proto, port, asd.is_decrypted());
+
+        if (got_incompatible_service)
             sds->update_service_incompatiable(tmp_ip);
 
         sds->set_service_id_failed(asd, tmp_ip);
@@ -586,7 +589,7 @@ bool ServiceDiscovery::do_service_discovery(AppIdSession& asd, Packet* p, int di
             }
             else
             {
-                asd.set_session_flags(APPID_SESSION_MID | APPID_SESSION_SERVICE_DETECTED);
+                asd.set_session_flags(APPID_SESSION_SERVICE_DETECTED);
                 asd.service_disco_state = APPID_DISCO_STATE_FINISHED;
             }
         }
@@ -707,12 +710,8 @@ bool ServiceDiscovery::do_service_discovery(AppIdSession& asd, Packet* p, int di
 int ServiceDiscovery::incompatible_data(AppIdSession& asd, const Packet* pkt, int dir,
     ServiceDetector* service)
 {
-    const SfIp* ip = pkt->ptrs.ip_api.get_src();
-    uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp;
-    ServiceDiscoveryState* sds = AppIdServiceState::get(ip, asd.protocol, port,
-        asd.is_decrypted());
-
-    asd.free_flow_data_by_id(service->get_flow_data_index());
+    if (service)
+        asd.free_flow_data_by_id(service->get_flow_data_index());
 
     // ignore fails while searching with port/pattern selected detectors
     if ( !asd.service_detector && !asd.service_candidates.empty() )
@@ -731,25 +730,28 @@ int ServiceDiscovery::incompatible_data(AppIdSession& asd, const Packet* pkt, in
         return APPID_SUCCESS;
     }
 
+    const SfIp* ip = pkt->ptrs.ip_api.get_src();
+    uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp;
+    ServiceDiscoveryState* sds = AppIdServiceState::add(ip, asd.protocol, port,
+        asd.is_decrypted());
+    sds->set_service(service);
+    sds->set_reset_time(0);
     if ( !asd.service_ip.is_set() )
     {
         asd.service_ip = *ip;
-        if (!asd.service_port)
-            asd.service_port = port;
+        asd.service_port = port;
     }
-    sds->set_reset_time(0);
     return APPID_SUCCESS;
 }
 
 int ServiceDiscovery::fail_service(AppIdSession& asd, const Packet* pkt, int dir,
-    ServiceDetector* service)
+    ServiceDetector* service, ServiceDiscoveryState* sds)
 {
-    const SfIp* ip = pkt->ptrs.ip_api.get_src();
-    uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp;
-
     if ( service )
         asd.free_flow_data_by_id(service->get_flow_data_index());
 
+    /* If we're still working on a port/pattern list of detectors, then ignore
+     * individual fails until we're done looking at everything. */
     if ( !asd.service_detector && !asd.service_candidates.empty() )
         return APPID_SUCCESS;
 
@@ -757,7 +759,7 @@ int ServiceDiscovery::fail_service(AppIdSession& asd, const Packet* pkt, int dir
     asd.set_service_detected();
     asd.clear_session_flags(APPID_SESSION_CONTINUE);
 
-    /* detectors should be careful in marking session UDP_REVERSED otherwise the same detector
+    /* Detectors should be careful in marking session UDP_REVERSED otherwise the same detector
      * gets all future flows. UDP_REVERSE should be marked only when detector positively
      * matches opposite direction patterns. */
     if ( asd.get_session_flags(APPID_SESSION_IGNORE_HOST | APPID_SESSION_UDP_REVERSED) )
@@ -771,15 +773,14 @@ int ServiceDiscovery::fail_service(AppIdSession& asd, const Packet* pkt, int dir
         return APPID_SUCCESS;
     }
 
+    const SfIp* ip = pkt->ptrs.ip_api.get_src();
+    uint16_t port = asd.service_port ? asd.service_port : pkt->ptrs.sp;
     if (!asd.service_ip.is_set())
     {
         asd.service_ip = *ip;
-        if (!asd.service_port)
-            asd.service_port = port;
+        asd.service_port = port;
     }
 
-    ServiceDiscoveryState* sds = AppIdServiceState::get(ip, asd.protocol, port,
-        asd.is_decrypted());
     if ( !sds )
     {
         sds = AppIdServiceState::add(ip, asd.protocol, port, asd.is_decrypted());
index b62d3c4df41c3aec578c19bd4100a1f7723fb5cf..ce368077454f531f200c3fe888ef6d6fcd9b4029 100644 (file)
 
 #include "appid_discovery.h"
 
-#include <map>
+#include <unordered_map>
 #include <vector>
 
-#include "utils/sflsq.h"
 #include "flow/flow.h"
 #include "log/messages.h"
+#include "utils/sflsq.h"
 
 class AppIdConfig;
 class AppIdSession;
@@ -77,20 +77,21 @@ public:
 
     bool do_service_discovery(AppIdSession&, snort::Packet*, int);
     int identify_service(AppIdSession&, snort::Packet*, int dir);
-    int fail_service(AppIdSession&, const snort::Packet*, int dir, ServiceDetector*);
+    int fail_service(AppIdSession&, const snort::Packet*, int dir, ServiceDetector*,
+        ServiceDiscoveryState* sds = nullptr);
     int incompatible_data(AppIdSession&, const snort::Packet*, int dir, ServiceDetector*);
     static int add_ftp_service_state(AppIdSession&);
 
 private:
     ServiceDiscovery(AppIdInspector& ins);
     void initialize() override;
-    void get_next_service(const snort::Packet*, const int dir, AppIdSession&, ServiceDiscoveryState*);
+    void get_next_service(const snort::Packet*, const int dir, AppIdSession&);
     void get_port_based_services(IpProtocol, uint16_t port, AppIdSession&);
     void match_by_pattern(AppIdSession&, const snort::Packet*, IpProtocol);
 
-    std::map<uint16_t, std::vector<ServiceDetector*> > tcp_services;
-    std::map<uint16_t, std::vector<ServiceDetector*> > udp_services;
-    std::map<uint16_t, std::vector<ServiceDetector*> > udp_reversed_services;
+    std::unordered_map<uint16_t, std::vector<ServiceDetector*> > tcp_services;
+    std::unordered_map<uint16_t, std::vector<ServiceDetector*> > udp_services;
+    std::unordered_map<uint16_t, std::vector<ServiceDetector*> > udp_reversed_services;
 };
 
 #endif
index 5d3366edc50553bc1c08c13ec42d2cca978f1253..847778c1f55878f697d64a33d6663f160f7a4d72 100644 (file)
@@ -47,23 +47,35 @@ ServiceDiscoveryState::ServiceDiscoveryState()
 
 ServiceDiscoveryState::~ServiceDiscoveryState()
 {
-    if ( brute_force_mgr )
-        delete brute_force_mgr;
+    delete tcp_brute_force_mgr;
+    delete udp_brute_force_mgr;
 }
 
 ServiceDetector* ServiceDiscoveryState::select_detector_by_brute_force(IpProtocol proto)
 {
-    if ( state == SERVICE_ID_STATE::SEARCHING_BRUTE_FORCE )
+    if (proto == IpProtocol::TCP)
     {
+        if ( !tcp_brute_force_mgr )
+            tcp_brute_force_mgr = new AppIdDetectorList(IpProtocol::TCP);
+        service = tcp_brute_force_mgr->next();
         if (appidDebug->is_active())
-            LogMessage("AppIdDbg %s Brute-force state\n", appidDebug->get_debug_session());
-        if ( !brute_force_mgr )
-            brute_force_mgr = new AppIdDetectorList(proto);
-
-        service = brute_force_mgr->next();
-        if ( !service )
-            state = SERVICE_ID_STATE::FAILED;
+            LogMessage("AppIdDbg %s Brute-force state %s\n", appidDebug->get_debug_session(),
+                service? "" : "failed - no more TCP detectors");
+    }
+    else if (proto == IpProtocol::UDP)
+    {
+        if ( !udp_brute_force_mgr )
+            udp_brute_force_mgr = new AppIdDetectorList(IpProtocol::UDP);
+        service = udp_brute_force_mgr->next();
+        if (appidDebug->is_active())
+            LogMessage("AppIdDbg %s Brute-force state %s\n", appidDebug->get_debug_session(),
+                service? "" : "failed - no more UDP detectors");
     }
+    else
+        service = nullptr;
+
+    if ( !service )
+        state = SERVICE_ID_STATE::FAILED;
 
     return service;
 }
@@ -80,13 +92,13 @@ void ServiceDiscoveryState::set_service_id_valid(ServiceDetector* sd)
 
     if ( !valid_count )
     {
+        valid_count = 1;
         detract_count = 0;
         last_detract.clear();
         invalid_client_count = 0;
         last_invalid_client.clear();
     }
-
-    if ( valid_count < STATE_ID_MAX_VALID_COUNT)
+    else if ( valid_count < STATE_ID_MAX_VALID_COUNT)
         valid_count++;
 }
 
@@ -148,11 +160,15 @@ void ServiceDiscoveryState::set_service_id_failed(AppIdSession& asd, const SfIp*
             }
         }
     }
-    else if ( ( state == SERVICE_ID_STATE::SEARCHING_PORT_PATTERN ) &&
-        ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::PENDING ) &&
-        ( asd.service_candidates.empty() ) )
+    else if ( ( state == SERVICE_ID_STATE::SEARCHING_PORT_PATTERN ) and
+        ( asd.service_search_state == SESSION_SERVICE_SEARCH_STATE::PENDING ) and
+        asd.service_candidates.empty() and
+        !asd.get_session_flags(APPID_SESSION_MID | APPID_SESSION_OOO) )
     {
-        state = SEARCHING_BRUTE_FORCE;
+        if ( ( asd.protocol == IpProtocol::TCP ) or ( asd.protocol == IpProtocol::UDP ) )
+            state = SEARCHING_BRUTE_FORCE;
+        else
+            state = FAILED;
     }
 }
 
@@ -227,7 +243,7 @@ void AppIdServiceState::clean()
         for ( auto& kv : *service_state_cache )
             delete kv.second;
 
-        service_state_cache->empty();
+        service_state_cache->clear();
         delete service_state_cache;
         service_state_cache = nullptr;
     }
@@ -313,6 +329,7 @@ void AppIdServiceState::check_reset(AppIdSession& asd, const SfIp* ip, uint16_t
         else if ( ( packet_time() - sds->get_reset_time() ) >= 60 )
         {
             AppIdServiceState::remove(ip, IpProtocol::TCP, port, asd.is_decrypted());
+            // FIXIT-L - Remove if this flag not used anywhere
             asd.set_session_flags(APPID_SESSION_SERVICE_DELETED);
         }
     }
index 0a4876350bf08fac6ff9e04f70feff634a62bc16..a56d3c1af1e04a9ea163056c53eb2ab5272fe29d 100644 (file)
 
 #include <mutex>
 
+#include "protocols/protocol_ids.h"
 #include "sfip/sf_ip.h"
+
 #include "service_plugins/service_discovery.h"
-#include "protocols/protocol_ids.h"
 #include "utils/util.h"
 
 class ServiceDetector;
@@ -114,7 +115,8 @@ public:
 private:
     SERVICE_ID_STATE state;
     ServiceDetector* service = nullptr;
-    AppIdDetectorList* brute_force_mgr = nullptr;
+    AppIdDetectorList* tcp_brute_force_mgr = nullptr;
+    AppIdDetectorList* udp_brute_force_mgr = nullptr;
     unsigned valid_count = 0;
     unsigned detract_count = 0;
     snort::SfIp last_detract;
index f8c9f47e7c5a24ab924552e21ba6da3de6fe3d21..02ce122b3780acf0267175e94cf6168564b614b6 100644 (file)
@@ -26,3 +26,7 @@ add_cpputest( appid_debug_test
     SOURCES $<TARGET_OBJECTS:appid_cpputest_deps>
 )
 
+add_cpputest( service_state_test
+    SOURCES $<TARGET_OBJECTS:appid_cpputest_deps>
+)
+
index adc6cb433dc8b46938758354a8649c4501161619..459fd00a46a4656325c572d1a4bf98ebf6cc41f8 100644 (file)
@@ -90,6 +90,24 @@ TEST(appid_detector_tests, add_user)
     delete ad;
 }
 
+TEST(appid_detector_tests, get_code_string)
+{
+    AppIdDetector* ad = new TestDetector;
+    STRCMP_EQUAL(ad->get_code_string(APPID_SUCCESS), "success");
+    STRCMP_EQUAL(ad->get_code_string(APPID_INPROCESS), "inprocess");
+    STRCMP_EQUAL(ad->get_code_string(APPID_NEED_REASSEMBLY), "need-reassembly");
+    STRCMP_EQUAL(ad->get_code_string(APPID_NOT_COMPATIBLE), "not-compatible");
+    STRCMP_EQUAL(ad->get_code_string(APPID_INVALID_CLIENT), "invalid-client");
+    STRCMP_EQUAL(ad->get_code_string(APPID_REVERSED), "appid-reversed");
+    STRCMP_EQUAL(ad->get_code_string(APPID_NOMATCH), "no-match");
+    STRCMP_EQUAL(ad->get_code_string(APPID_ENULL), "error-null");
+    STRCMP_EQUAL(ad->get_code_string(APPID_EINVALID), "error-invalid");
+    STRCMP_EQUAL(ad->get_code_string(APPID_ENOMEM), "error-memory");
+    STRCMP_EQUAL(ad->get_code_string(APPID_SUCCESS), "success");
+    STRCMP_EQUAL(ad->get_code_string((APPID_STATUS_CODE)123), "unknown code");
+    delete ad;
+}
+
 int main(int argc, char** argv)
 {
     mock_init_appid_pegs();
index 7d070c0ebc29c2660d590f0b11766f43743e95aa..c6f8dcb898ed09d420e8e07865fbe524d98df04f 100644 (file)
@@ -92,7 +92,8 @@ int ServiceDiscovery::incompatible_data(AppIdSession&, const Packet*, int, Servi
   return 0;
 }
 
-int ServiceDiscovery::fail_service(AppIdSession&, const Packet*, int, ServiceDetector*)
+int ServiceDiscovery::fail_service(AppIdSession&, const Packet*, int, ServiceDetector*,
+    ServiceDiscoveryState*)
 {
   return 0;
 }
diff --git a/src/network_inspectors/appid/test/service_state_test.cc b/src/network_inspectors/appid/test/service_state_test.cc
new file mode 100644 (file)
index 0000000..09269cc
--- /dev/null
@@ -0,0 +1,196 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2018-2018 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// service_state.cc author Masud Hasan <mashasan@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "network_inspectors/appid/service_state.cc"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+
+// Stubs for AppIdDebug
+THREAD_LOCAL AppIdDebug* appidDebug = nullptr;
+void AppIdDebug::activate(const Flow*, const AppIdSession*, bool) { active = true; }
+
+// Stubs for logs
+char test_log[256];
+void LogMessage(const char* format,...)
+{
+    va_list args;
+    va_start(args, format);
+    vsprintf(test_log, format, args);
+    va_end(args);
+}
+void ErrorMessage(const char*,...) {}
+void LogLabel(const char*, FILE*) {}
+THREAD_LOCAL AppIdStats appid_stats;
+
+// Stubs for utils
+char* snort_strdup(const char* str)
+{
+    assert(str);
+    size_t n = strlen(str) + 1;
+    char* p = (char*)snort_alloc(n);
+    memcpy(p, str, n);
+    return p;
+}
+time_t packet_time() { return std::time(0); }
+
+// Stubs for appid classes
+class AppIdInspector{};
+FlowData::FlowData(unsigned, Inspector*) {}
+FlowData::~FlowData() = default;
+AppIdSession::AppIdSession(IpProtocol, const SfIp*, uint16_t, AppIdInspector& inspector)
+    : FlowData(0), inspector(inspector) {}
+AppIdSession::~AppIdSession() = default;
+AppIdDiscovery::AppIdDiscovery(AppIdInspector& ins) : inspector(ins) {}
+AppIdDiscovery::~AppIdDiscovery() {}
+void AppIdDiscovery::register_detector(const std::string&, AppIdDetector*,  IpProtocol) {}
+void AppIdDiscovery::add_pattern_data(AppIdDetector*, SearchTool*, int, const uint8_t* const,
+    unsigned, unsigned) {}
+void AppIdDiscovery::register_tcp_pattern(AppIdDetector*, const uint8_t* const, unsigned,
+    int, unsigned) {}
+void AppIdDiscovery::register_udp_pattern(AppIdDetector*, const uint8_t* const, unsigned,
+    int, unsigned) {}
+int AppIdDiscovery::add_service_port(AppIdDetector*,
+    const ServiceDetectorPort&) { return APPID_EINVALID; }
+void ServiceDiscovery::initialize() {}
+void ServiceDiscovery::finalize_service_patterns() {}
+void ServiceDiscovery::match_by_pattern(AppIdSession&, const Packet*, IpProtocol) {}
+void ServiceDiscovery::get_port_based_services(IpProtocol, uint16_t, AppIdSession&) {}
+void ServiceDiscovery::get_next_service(const Packet*, const int, AppIdSession&) {}
+int ServiceDiscovery::identify_service(AppIdSession&, Packet*, int) { return 0; }
+int ServiceDiscovery::add_ftp_service_state(AppIdSession&) { return 0; }
+bool ServiceDiscovery::do_service_discovery(AppIdSession&, Packet*, int) { return 0; }
+int ServiceDiscovery::incompatible_data(AppIdSession&, const Packet*,int,
+    ServiceDetector*) { return 0; }
+int ServiceDiscovery::fail_service(AppIdSession&, const Packet*, int,
+    ServiceDetector*, ServiceDiscoveryState*) { return 0; }
+int ServiceDiscovery::add_service_port(AppIdDetector*,
+    const ServiceDetectorPort&) { return APPID_EINVALID; }
+ServiceDiscovery::ServiceDiscovery(AppIdInspector& ins)
+    : AppIdDiscovery(ins) {}
+bool new_manager_test = true;
+ServiceDiscovery& ServiceDiscovery::get_instance(AppIdInspector* ins)
+{
+    static THREAD_LOCAL ServiceDiscovery* discovery_manager = nullptr;
+    if (!new_manager_test)
+    {
+        delete discovery_manager;
+        discovery_manager = nullptr;
+    }
+    else if (!discovery_manager)
+        discovery_manager = new ServiceDiscovery(*ins);
+    return *discovery_manager;
+}
+
+TEST_GROUP(service_state_tests)
+{
+    void setup() override
+    {
+        appidDebug = new AppIdDebug();
+        appidDebug->activate(nullptr, nullptr, 0);
+    }
+
+    void teardown() override
+    {
+        delete appidDebug;
+    }
+};
+
+TEST(service_state_tests, select_detector_by_brute_force)
+{
+    ServiceDiscoveryState sds;
+    new_manager_test = true;
+    AppIdInspector ins;
+    ServiceDiscovery::get_instance(&ins);
+
+    // Testing end of brute-force walk for supported and unsupported protocols
+    test_log[0] = '\0';
+    sds.select_detector_by_brute_force(IpProtocol::TCP);
+    STRCMP_EQUAL(test_log, "AppIdDbg  Brute-force state failed - no more TCP detectors\n");
+
+    test_log[0] = '\0';
+    sds.select_detector_by_brute_force(IpProtocol::UDP);
+    STRCMP_EQUAL(test_log, "AppIdDbg  Brute-force state failed - no more UDP detectors\n");
+
+    test_log[0] = '\0';
+    sds.select_detector_by_brute_force(IpProtocol::IP);
+    STRCMP_EQUAL(test_log, "");
+
+    new_manager_test = false;
+    delete &ServiceDiscovery::get_instance();
+}
+
+TEST(service_state_tests, set_service_id_failed)
+{
+    ServiceDiscoveryState sds;
+    AppIdInspector inspector;
+    AppIdSession asd(IpProtocol::PROTO_NOT_SET, nullptr, 0, inspector);
+    SfIp client_ip;
+    new_manager_test = true;
+    AppIdInspector ins;
+    ServiceDiscovery::get_instance(&ins);
+
+    // Testing 3+ failures to exceed STATE_ID_NEEDED_DUPE_DETRACT_COUNT with valid_count = 0
+    client_ip.set("1.2.3.4");
+    sds.set_state(SERVICE_ID_STATE::VALID);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    CHECK_TRUE(sds.get_state() == SERVICE_ID_STATE::SEARCHING_PORT_PATTERN);
+
+    new_manager_test = false;
+    delete &ServiceDiscovery::get_instance();
+}
+
+
+TEST(service_state_tests, set_service_id_failed_with_valid)
+{
+    ServiceDiscoveryState sds;
+    AppIdInspector inspector;
+    AppIdSession asd(IpProtocol::PROTO_NOT_SET, nullptr, 0, inspector);
+    SfIp client_ip;
+    new_manager_test = true;
+    AppIdInspector ins;
+    ServiceDiscovery::get_instance(&ins);
+
+    // Testing 3+ failures to exceed STATE_ID_NEEDED_DUPE_DETRACT_COUNT with valid_count > 1
+    client_ip.set("1.2.3.4");
+    sds.set_state(SERVICE_ID_STATE::VALID);
+    sds.set_service_id_valid(0);
+    sds.set_service_id_valid(0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    sds.set_service_id_failed(asd, &client_ip, 0);
+    CHECK_TRUE(sds.get_state() == SERVICE_ID_STATE::VALID);
+
+    new_manager_test = false;
+    delete &ServiceDiscovery::get_instance();
+}
+
+int main(int argc, char** argv)
+{
+    int rc = CommandLineTestRunner::RunAllTests(argc, argv);
+    return rc;
+}