]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3751: appid: use packet thread's odp context instead of inspector's...
authorSreeja Athirkandathil Narayanan (sathirka) <sathirka@cisco.com>
Tue, 31 Jan 2023 21:51:15 +0000 (21:51 +0000)
committerSreeja Athirkandathil Narayanan (sathirka) <sathirka@cisco.com>
Tue, 31 Jan 2023 21:51:15 +0000 (21:51 +0000)
Merge in SNORT/snort3 from ~SATHIRKA/snort3:reload_fixes to master

Squashed commit of the following:

commit fb0d3790437f4b3974552ca94aa68b186b282fd2
Author: Sreeja Athirkandathil Narayanan <sathirka@cisco.com>
Date:   Fri Jan 20 10:24:30 2023 -0500

    appid: use packet thread's odp context instead of inspector's context for packet processing

src/network_inspectors/appid/appid_api.cc
src/network_inspectors/appid/appid_ha.cc
src/network_inspectors/appid/appid_session.h
src/network_inspectors/appid/lua_detector_api.cc
src/network_inspectors/appid/test/appid_api_test.cc

index f06316261361bbaaad384091b0109fd757cdffd5..cf33b6b32f714bdd54e14058df425e77b3d2681c 100644 (file)
@@ -124,14 +124,16 @@ bool AppIdApi::ssl_app_group_id_lookup(Flow* flow, const char* server_name,
     client_id = APP_ID_NONE;
     payload_id = APP_ID_NONE;
 
+    if (!pkt_thread_odp_ctxt)
+        return false;
+
     if (flow)
         asd = get_appid_session(*flow);
 
     if (asd)
     {
         // Skip detection for sessions using old odp context after odp reload
-        if (!pkt_thread_odp_ctxt or
-            pkt_thread_odp_ctxt->get_version() != asd->get_odp_ctxt_version())
+        if (pkt_thread_odp_ctxt->get_version() != asd->get_odp_ctxt_version())
             return false;
 
         AppidChangeBits change_bits;
@@ -213,24 +215,20 @@ bool AppIdApi::ssl_app_group_id_lookup(Flow* flow, const char* server_name,
     }
     else
     {
-        AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true);
-        if (inspector)
-        {
-            SslPatternMatchers& ssl_matchers = inspector->get_ctxt().get_odp_ctxt().get_ssl_matchers();
+        SslPatternMatchers& ssl_matchers = pkt_thread_odp_ctxt->get_ssl_matchers();
 
-            if (server_name and !sni_mismatch)
-                ssl_matchers.scan_hostname((const uint8_t*)server_name, strlen(server_name),
-                    client_id, payload_id);
-            if (first_alt_name and client_id == APP_ID_NONE and payload_id == APP_ID_NONE)
-                ssl_matchers.scan_hostname((const uint8_t*)first_alt_name, strlen(first_alt_name),
-                    client_id, payload_id);
-            if (common_name and client_id == APP_ID_NONE and payload_id == APP_ID_NONE)
-                ssl_matchers.scan_cname((const uint8_t*)common_name, strlen(common_name), client_id,
-                    payload_id);
-            if (org_unit and client_id == APP_ID_NONE and payload_id == APP_ID_NONE)
-                ssl_matchers.scan_cname((const uint8_t*)org_unit, strlen(org_unit), client_id,
-                    payload_id);
-        }
+        if (server_name and !sni_mismatch)
+            ssl_matchers.scan_hostname((const uint8_t*)server_name, strlen(server_name),
+                client_id, payload_id);
+        if (first_alt_name and client_id == APP_ID_NONE and payload_id == APP_ID_NONE)
+            ssl_matchers.scan_hostname((const uint8_t*)first_alt_name, strlen(first_alt_name),
+                client_id, payload_id);
+        if (common_name and client_id == APP_ID_NONE and payload_id == APP_ID_NONE)
+            ssl_matchers.scan_cname((const uint8_t*)common_name, strlen(common_name), client_id,
+                payload_id);
+        if (org_unit and client_id == APP_ID_NONE and payload_id == APP_ID_NONE)
+            ssl_matchers.scan_cname((const uint8_t*)org_unit, strlen(org_unit), client_id,
+                payload_id);
     }
 
     if (client_id != APP_ID_NONE or payload_id != APP_ID_NONE)
index 0b1a8ddfac66005c27780ca61f74dc2ff4f1e17b..196a0f4abe87616ec1ae4a5d1e64c1abc9b633a5 100644 (file)
@@ -49,7 +49,7 @@ static AppIdSession* create_appid_session(Flow& flow, const FlowKey* key,
     AppIdSession* asd = new AppIdSession(static_cast<IpProtocol>(key->ip_protocol),
         flow.flags.client_initiated ? &flow.client_ip : &flow.server_ip,
         flow.flags.client_initiated ? flow.client_port : flow.server_port, inspector,
-        inspector.get_ctxt().get_odp_ctxt(), key->addressSpaceId);
+        *pkt_thread_odp_ctxt, key->addressSpaceId);
     if (appidDebug->is_active())
         LogMessage("AppIdDbg %s high-avail - New AppId session created in consume\n",
             appidDebug->get_debug_session());
@@ -70,7 +70,7 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
     AppIdInspector* inspector =
         static_cast<AppIdInspector*>(
             InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type));
-    if (!inspector)
+    if (!inspector or !pkt_thread_odp_ctxt)
         return false;
 
     AppIdSession* asd = (AppIdSession*)(flow->get_flow_data(AppIdSession::inspector_id));
@@ -107,7 +107,7 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
             asd->service_disco_state = APPID_DISCO_STATE_FINISHED;
 
         asd->client_disco_state = APPID_DISCO_STATE_FINISHED;
-        if (asd->get_tp_appid_ctxt())
+        if (asd->get_tp_appid_ctxt() and !ThirdPartyAppIdContext::get_tp_reload_in_progress())
         {
             const TPLibHandler* tph = TPLibHandler::get();
             TpAppIdCreateSession tpsf = tph->tpsession_factory();
@@ -232,7 +232,7 @@ bool AppIdHAHttpClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
     AppIdInspector* inspector =
         static_cast<AppIdInspector*>(
             InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type));
-    if (!inspector)
+    if (!inspector or !pkt_thread_odp_ctxt)
         return false;
 
     AppIdSession* asd = appid_api.get_appid_session(*flow);
@@ -322,7 +322,7 @@ bool AppIdHATlsHostClient::consume(Flow*& flow, const FlowKey* key, HAMessage& m
     AppIdInspector* inspector =
         static_cast<AppIdInspector*>(
             InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type));
-    if (!inspector)
+    if (!inspector or !pkt_thread_odp_ctxt)
         return false;
 
     AppIdSession* asd = appid_api.get_appid_session(*flow);
index bb41d4a9114c0c543e901410dd3ac9ac1fbf18ac..15639ded3ae2940a8e84418a05c067c893c88fd6 100644 (file)
@@ -168,7 +168,8 @@ public:
     void set_tls_host(const char* new_tls_host, uint32_t len, AppidChangeBits& change_bits)
     {
         set_tls_host(new_tls_host, len, true);
-        change_bits.set(APPID_TLSHOST_BIT);
+        if (new_tls_host and *new_tls_host != '\0')
+            change_bits.set(APPID_TLSHOST_BIT);
     }
 
     void set_tls_first_alt_name(const char* new_tls_first_alt_name, uint32_t len, AppidChangeBits& change_bits)
index c55bb33c98751f38012ce6b7eb9d31a9f1541358..678b5c031e88ed7f55790432fe0ff841ea828a3c 100644 (file)
@@ -545,7 +545,7 @@ static int service_add_service(lua_State* L)
     /*Phase2 - discuss AppIdServiceSubtype will be maintained on lua side therefore the last
       parameter on the following call is nullptr. Subtype is not displayed on DC at present. */
     unsigned int retValue = ud->sd->add_service(*lsd->ldp.change_bits, *lsd->ldp.asd, lsd->ldp.pkt,
-        lsd->ldp.dir, ud->get_odp_ctxt().get_app_info_mgr().get_appid_by_service_id(service_id),
+        lsd->ldp.dir, lsd->ldp.asd->get_odp_ctxt().get_app_info_mgr().get_appid_by_service_id(service_id),
         vendor, version, nullptr);
 
     lua_pushnumber(L, retValue);
@@ -944,9 +944,10 @@ static int client_add_application(lua_State* L)
     unsigned int service_id = lua_tonumber(L, 2);
     unsigned int productId = lua_tonumber(L, 4);
     const char* version = lua_tostring(L, 5);
+    OdpContext& odp_ctxt = lsd->ldp.asd->get_odp_ctxt();
     ud->cd->add_app(*lsd->ldp.pkt, *lsd->ldp.asd, lsd->ldp.dir,
-        ud->get_odp_ctxt().get_app_info_mgr().get_appid_by_service_id(service_id),
-        ud->get_odp_ctxt().get_app_info_mgr().get_appid_by_client_id(productId), version,
+        odp_ctxt.get_app_info_mgr().get_appid_by_service_id(service_id),
+        odp_ctxt.get_app_info_mgr().get_appid_by_client_id(productId), version,
         *lsd->ldp.change_bits);
 
     lua_pushnumber(L, 0);
@@ -968,7 +969,7 @@ static int client_add_user(lua_State* L)
     const char* userName = lua_tostring(L, 2);
     unsigned int service_id = lua_tonumber(L, 3);
     ud->cd->add_user(*lsd->ldp.asd, userName,
-        ud->get_odp_ctxt().get_app_info_mgr().get_appid_by_service_id(service_id), true,
+        lsd->ldp.asd->get_odp_ctxt().get_app_info_mgr().get_appid_by_service_id(service_id), true,
         *lsd->ldp.change_bits);
     lua_pushnumber(L, 0);
     return 1;
@@ -982,7 +983,7 @@ static int client_add_payload(lua_State* L)
 
     unsigned int payloadId = lua_tonumber(L, 2);
     ud->cd->add_payload(*lsd->ldp.asd,
-        ud->get_odp_ctxt().get_app_info_mgr().get_appid_by_payload_id(payloadId));
+        lsd->ldp.asd->get_odp_ctxt().get_app_info_mgr().get_appid_by_payload_id(payloadId));
 
     lua_pushnumber(L, 0);
     return 1;
@@ -1302,9 +1303,9 @@ static int detector_add_host_port_dynamic(lua_State* L)
 {
     auto& ud = *UserData<LuaClientObject>::check(L, DETECTOR, 1);
     // Verify detector user data and that we are in packet context
-    ud->validate_lua_state(true);
+    LuaStateDescriptor* lsd = ud->validate_lua_state(true);
 
-    if (!ud->get_odp_ctxt().is_host_port_app_cache_runtime)
+    if (!lsd->ldp.asd->get_odp_ctxt().is_host_port_app_cache_runtime)
         return 0;
 
     SfIp ip_address;
@@ -2635,16 +2636,16 @@ static int create_future_flow(lua_State* L)
     AppId client_id  = lua_tointeger(L, 8);
     AppId payload_id = lua_tointeger(L, 9);
     AppId app_id_to_snort = lua_tointeger(L, 10);
+    OdpContext& odp_ctxt = lsd->ldp.asd->get_odp_ctxt();
     if (app_id_to_snort > APP_ID_NONE)
     {
-        AppInfoTableEntry* entry = ud->get_odp_ctxt().get_app_info_mgr().get_app_info_entry(
+        AppInfoTableEntry* entry = odp_ctxt.get_app_info_mgr().get_app_info_entry(
             app_id_to_snort);
         if (!entry)
             return 0;
         snort_protocol_id = entry->snort_protocol_id;
     }
 
-    OdpContext& odp_ctxt = lsd->ldp.asd->get_odp_ctxt();
     AppIdSession* fp = AppIdSession::create_future_session(lsd->ldp.pkt,  &client_addr,
         client_port, &server_addr, server_port, proto, snort_protocol_id, odp_ctxt);
     if (fp)
index 635f95f0a1d4105b30b30bcb42fc9b6531206001..55a196e616c30ad96f47ad59bd62d69e140b9be6 100644 (file)
@@ -49,6 +49,7 @@
 using namespace snort;
 
 static SnortProtocolId dummy_http2_protocol_id = 1;
+char const* APPID_UT_ORG_UNIT = "Google";
 
 namespace snort
 {
@@ -111,9 +112,14 @@ bool SslPatternMatchers::scan_cname(const uint8_t* cname, size_t, AppId& client_
 {
     if (((const char*)cname) == APPID_UT_TLS_HOST)
     {
-        client_id = APPID_UT_ID + 2;;
+        client_id = APPID_UT_ID + 2;
         payload_id = APPID_UT_ID + 2;
     }
+    else if (((const char*)cname) == APPID_UT_ORG_UNIT)
+    {
+        client_id = APPID_UT_ID + 3;
+        payload_id = APPID_UT_ID + 3;
+    }
     else
     {
         client_id = 0;
@@ -242,7 +248,7 @@ TEST(appid_api, get_application_id)
 
 TEST(appid_api, ssl_app_group_id_lookup)
 {
-    mock().expectNCalls(5, "publish");
+    mock().expectNCalls(6, "publish");
     AppId service, client, payload = APP_ID_NONE;
     bool val = false;
 
@@ -257,6 +263,7 @@ TEST(appid_api, ssl_app_group_id_lookup)
     CHECK_EQUAL(payload, APPID_UT_ID);
     STRCMP_EQUAL("Published change_bits == 00000000000000000000", test_log);
 
+    // Server name based detection
     service = APP_ID_NONE;
     client = APP_ID_NONE;
     payload = APP_ID_NONE;
@@ -270,6 +277,7 @@ TEST(appid_api, ssl_app_group_id_lookup)
     STRCMP_EQUAL(mock_session->tsession->get_tls_cname(), APPID_UT_TLS_HOST);
     STRCMP_EQUAL("Published change_bits == 00000000000100000000", test_log);
 
+    // Common name based detection
     mock_session->tsession->set_tls_host("www.cisco.com", 13, change_bits);
     mock_session->tsession->set_tls_cname("www.cisco.com", 13, change_bits);
     mock_session->tsession->set_tls_org_unit("Cisco", 5);
@@ -286,22 +294,37 @@ TEST(appid_api, ssl_app_group_id_lookup)
     STRCMP_EQUAL(mock_session->tsession->get_tls_org_unit(), "Cisco");
     STRCMP_EQUAL("Published change_bits == 00000000000100000000", test_log);
 
-    string host = "";
-    val = appid_api.ssl_app_group_id_lookup(flow, (const char*)(host.c_str()), nullptr,
-        (const char*)APPID_UT_TLS_HOST, (const char*)"Google", false, service, client, payload);
+    // First alt name based detection
+    change_bits.reset();
+    mock_session->tsession->set_tls_host("", 0, change_bits);
+    val = appid_api.ssl_app_group_id_lookup(flow, nullptr, (const char*)APPID_UT_TLS_HOST,
+        nullptr, nullptr, false, service, client, payload);
     CHECK_TRUE(val);
-    CHECK_EQUAL(client, APPID_UT_ID + 2);
-    CHECK_EQUAL(payload, APPID_UT_ID + 2);
+    CHECK_EQUAL(client, APPID_UT_ID + 1);
+    CHECK_EQUAL(payload, APPID_UT_ID + 1);
     STRCMP_EQUAL(mock_session->tsession->get_tls_host(), APPID_UT_TLS_HOST);
-    STRCMP_EQUAL(mock_session->tsession->get_tls_cname(), APPID_UT_TLS_HOST);
-    STRCMP_EQUAL(mock_session->tsession->get_tls_org_unit(), "Google");
+    STRCMP_EQUAL(mock_session->tsession->get_tls_first_alt_name(), APPID_UT_TLS_HOST);
     STRCMP_EQUAL("Published change_bits == 00000000000100000000", test_log);
 
+    // Org unit based detection
+    string host = "";
+    change_bits.reset();
+    mock_session->tsession->set_tls_host("", 0, change_bits);
+    val = appid_api.ssl_app_group_id_lookup(flow, (const char*)(host.c_str()), nullptr,
+        nullptr, (const char*)APPID_UT_ORG_UNIT, false, service, client, payload);
+    CHECK_TRUE(val);
+    CHECK_EQUAL(client, APPID_UT_ID + 3);
+    CHECK_EQUAL(payload, APPID_UT_ID + 3);
+    STRCMP_EQUAL(mock_session->tsession->get_tls_org_unit(), APPID_UT_ORG_UNIT);
+    STRCMP_EQUAL("Published change_bits == 00000000000000000000", test_log);
+
     // Override client id found by SSL pattern matcher with the client id provided by
     // Encrypted Visibility Engine if available
     service = APP_ID_NONE;
     client = APP_ID_NONE;
     payload = APP_ID_NONE;
+    change_bits.reset();
+    mock_session->tsession->set_tls_host("", 0, change_bits);
     mock_session->set_client_id(APP_ID_NONE);
     mock_session->set_eve_client_app_id(APPID_UT_ID + 100);
     val = appid_api.ssl_app_group_id_lookup(flow, (const char*)APPID_UT_TLS_HOST, (const char*)APPID_UT_TLS_HOST,
@@ -315,6 +338,48 @@ TEST(appid_api, ssl_app_group_id_lookup)
     STRCMP_EQUAL("Published change_bits == 00000000000100000000", test_log);
 
     mock().checkExpectations();
+
+    // When appid session is not existing
+    // 1. Match based on server name
+    Flow* f = new Flow;
+    flow->set_flow_data(nullptr);
+    service = APP_ID_NONE;
+    client = APP_ID_NONE;
+    payload = APP_ID_NONE;
+    val = appid_api.ssl_app_group_id_lookup(f, (const char*)APPID_UT_TLS_HOST, (const char*)APPID_UT_TLS_HOST,
+        (const char*)APPID_UT_TLS_HOST, (const char*)APPID_UT_TLS_HOST, false, service, client, payload);
+    CHECK_TRUE(val);
+    CHECK_EQUAL(client, APPID_UT_ID + 1);
+    CHECK_EQUAL(payload, APPID_UT_ID + 1);
+
+    // 2. First alt name match
+    client = APP_ID_NONE;
+    payload = APP_ID_NONE;
+    val = appid_api.ssl_app_group_id_lookup(f, nullptr, (const char*)APPID_UT_TLS_HOST,
+        (const char*)APPID_UT_TLS_HOST, (const char*)APPID_UT_TLS_HOST, false, service, client, payload);
+    CHECK_TRUE(val);
+    CHECK_EQUAL(client, APPID_UT_ID + 1);
+    CHECK_EQUAL(payload, APPID_UT_ID + 1);
+
+    // 3. CN match
+    client = APP_ID_NONE;
+    payload = APP_ID_NONE;
+    val = appid_api.ssl_app_group_id_lookup(f, nullptr, nullptr, (const char*)APPID_UT_TLS_HOST,
+        (const char*)APPID_UT_TLS_HOST, false, service, client, payload);
+    CHECK_TRUE(val);
+    CHECK_EQUAL(client, APPID_UT_ID + 2);
+    CHECK_EQUAL(payload, APPID_UT_ID + 2);
+
+    // 4. Org unit match
+    client = APP_ID_NONE;
+    payload = APP_ID_NONE;
+    val = appid_api.ssl_app_group_id_lookup(f, nullptr, nullptr, nullptr, (const char*)APPID_UT_TLS_HOST,
+        false, service, client, payload);
+    CHECK_TRUE(val);
+    CHECK_EQUAL(client, APPID_UT_ID + 2);
+    CHECK_EQUAL(payload, APPID_UT_ID + 2);
+
+    delete f;
 }
 
 TEST(appid_api, is_inspection_needed)