]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3763: appid: merge cname pattern matchers with ssl pattern matchers
authorSreeja Athirkandathil Narayanan (sathirka) <sathirka@cisco.com>
Thu, 16 Feb 2023 17:10:51 +0000 (17:10 +0000)
committerSreeja Athirkandathil Narayanan (sathirka) <sathirka@cisco.com>
Thu, 16 Feb 2023 17:10:51 +0000 (17:10 +0000)
Merge in SNORT/snort3 from ~OSTEPANO/snort3:cname_and_cert_merge to master

Squashed commit of the following:

commit 9be16131179eeff287720a474b410885b19cff7a
Author: Oleksandr Stepanov <ostepano@cisco.com>
Date:   Thu Feb 9 10:41:51 2023 -0500

    appid: merge cname pattern matchers with ssl pattern matchers

src/network_inspectors/appid/detector_plugins/ssl_patterns.cc
src/network_inspectors/appid/detector_plugins/ssl_patterns.h
src/network_inspectors/appid/lua_detector_api.cc

index 425d83a3e50a639d922f51ddc13b67eececb8fcf..27f3a042896f07c4c1cd40536bc9b8a6a69431af 100644 (file)
@@ -28,7 +28,7 @@
 
 using namespace snort;
 
-static void create_matcher(SearchTool& matcher, SslPatternList* list)
+static void create_matcher(SearchTool& matcher, SslPatternList* list, CnameCache& set)
 {
     size_t* pattern_index;
     size_t size = 0;
@@ -38,6 +38,9 @@ static void create_matcher(SearchTool& matcher, SslPatternList* list)
 
     for (element = list; element; element = element->next)
     {
+        if (!element->dpattern->is_cname and set.count(*(element->dpattern)))
+            continue;
+
         matcher.add(element->dpattern->pattern,
             element->dpattern->pattern_size, element->dpattern, true);
         (*pattern_index)++;
@@ -57,17 +60,38 @@ static int cert_pattern_match(void* id, void*, int match_end_pos, void* data, vo
     cm->match_start_pos = match_end_pos - target->pattern_size;
     cm->next = *matches;
     *matches = cm;
+    
+    return 0;
+}
+
+static int cname_pattern_match(void* id, void*, int match_end_pos, void* data, void*)
+{
+    MatchedSslPatterns* cm;
+    MatchedSslPatterns** matches = (MatchedSslPatterns**)data;
+    SslPattern* target = (SslPattern*)id;
 
+    /* Only collect the match if it is a cname pattern. */
+    if (target->is_cname)
+    {
+        cm = (MatchedSslPatterns*)snort_alloc(sizeof(MatchedSslPatterns));
+        cm->mpattern = target;
+        cm->match_start_pos = match_end_pos - target->pattern_size;
+        cm->next = *matches;
+        *matches = cm;
+    }
     return 0;
 }
 
 static bool scan_patterns(SearchTool& matcher, const uint8_t* data, size_t size,
-    AppId& client_id, AppId& payload_id)
+    AppId& client_id, AppId& payload_id, bool is_cname_search)
 {
     MatchedSslPatterns* mp = nullptr;
     SslPattern* best_match;
 
-    matcher.find_all((const char*)data, size, cert_pattern_match, false, &mp);
+    if (is_cname_search)
+        matcher.find_all((const char*)data, size, cname_pattern_match, false, &mp);    
+    else
+        matcher.find_all((const char*)data, size, cert_pattern_match, false, &mp);
 
     if (!mp)
         return false;
@@ -132,7 +156,7 @@ static void free_patterns(SslPatternList*& list)
 }
 
 static void add_pattern(SslPatternList*& list, uint8_t* pattern_str, size_t
-    pattern_size, uint8_t type, AppId app_id)
+    pattern_size, uint8_t type, AppId app_id, bool is_cname, CnameCache& set)
 {
     SslPatternList* new_ssl_pattern;
 
@@ -142,45 +166,42 @@ static void add_pattern(SslPatternList*& list, uint8_t* pattern_str, size_t
     new_ssl_pattern->dpattern->app_id = app_id;
     new_ssl_pattern->dpattern->pattern = pattern_str;
     new_ssl_pattern->dpattern->pattern_size = pattern_size;
+    new_ssl_pattern->dpattern->is_cname = is_cname;
 
     new_ssl_pattern->next = list;
     list = new_ssl_pattern;
+
+    if (is_cname)
+        set.emplace(*(new_ssl_pattern->dpattern));
 }
 
 SslPatternMatchers::~SslPatternMatchers()
 {
     free_patterns(cert_pattern_list);
-    free_patterns(cname_pattern_list);
-}
-
-void SslPatternMatchers::add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id)
-{
-    add_pattern(cert_pattern_list, pattern_str, pattern_size, type, app_id);
 }
 
-void SslPatternMatchers::add_cname_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id)
+void SslPatternMatchers::add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id, bool is_cname)
 {
-    add_pattern(cname_pattern_list, pattern_str, pattern_size, type, app_id);
+    add_pattern(cert_pattern_list, pattern_str, pattern_size, type, app_id, is_cname, cert_pattern_set);
 }
 
 void SslPatternMatchers::finalize_patterns()
 {
-    create_matcher(ssl_host_matcher, cert_pattern_list);
-    create_matcher(ssl_cname_matcher, cname_pattern_list);
+    create_matcher(ssl_host_matcher, cert_pattern_list, cert_pattern_set);
+    cert_pattern_set.clear();
 }
 
 void SslPatternMatchers::reload_patterns()
 {
     ssl_host_matcher.reload();
-    ssl_cname_matcher.reload();
 }
 
 bool SslPatternMatchers::scan_hostname(const uint8_t* hostname, size_t size, AppId& client_id, AppId& payload_id)
 {
-    return scan_patterns(ssl_host_matcher, hostname, size, client_id, payload_id);
+    return scan_patterns(ssl_host_matcher, hostname, size, client_id, payload_id, false);
 }
 
 bool SslPatternMatchers::scan_cname(const uint8_t* common_name, size_t size, AppId& client_id, AppId& payload_id)
 {
-    return scan_patterns(ssl_cname_matcher, common_name, size, client_id, payload_id);
+    return scan_patterns(ssl_host_matcher, common_name, size, client_id, payload_id, true);
 }
index 044726a28f4d3f91d1277c723de5415eda1b997f..aa2a27d8ba843df194f585a284887706bf1b0ca3 100644 (file)
@@ -21,6 +21,8 @@
 #ifndef SSL_PATTERNS_H
 #define SSL_PATTERNS_H
 
+#include <cstring>
+#include <unordered_set>
 #include "search_engines/search_tool.h"
 #include "application_ids.h"
 
@@ -30,8 +32,25 @@ struct SslPattern
     AppId app_id;
     uint8_t* pattern;
     int pattern_size;
+    bool is_cname;
+
+    bool operator==(const SslPattern& v) const
+    {
+        return this->type == v.type and pattern_size == v.pattern_size
+            and (memcmp(pattern, v.pattern, (size_t)pattern_size) == 0); 
+    }
 };
 
+struct SslCacheKeyHasher
+{
+    size_t operator()(const SslPattern& key) const
+    {
+        return std::hash<std::string>{}(std::string((char*)key.pattern, key.pattern_size));
+    }
+};
+
+typedef std::unordered_set<SslPattern, SslCacheKeyHasher> CnameCache;
+
 struct MatchedSslPatterns
 {
     SslPattern* mpattern;
@@ -49,8 +68,7 @@ class SslPatternMatchers
 {
 public:
     ~SslPatternMatchers();
-    void add_cert_pattern(uint8_t*, size_t, uint8_t, AppId);
-    void add_cname_pattern(uint8_t*, size_t, uint8_t, AppId);
+    void add_cert_pattern(uint8_t*, size_t, uint8_t, AppId, bool);
     void finalize_patterns();
     void reload_patterns();
     bool scan_hostname(const uint8_t*, size_t, AppId&, AppId&);
@@ -58,9 +76,8 @@ public:
 
 private:
     SslPatternList* cert_pattern_list = nullptr;
-    SslPatternList* cname_pattern_list = nullptr;
+    CnameCache cert_pattern_set;
     snort::SearchTool ssl_host_matcher = snort::SearchTool();
-    snort::SearchTool ssl_cname_matcher= snort::SearchTool();
 };
 
 #endif
index a9bee0de4d92574ccfa9ef5d849561d3b25583c8..5f9ab3fc987274c37005850189bcfe9e100373b4 100644 (file)
@@ -1228,24 +1228,25 @@ static int detector_add_ssl_cert_pattern(lua_State* L)
     int index = 1;
 
     uint8_t type = lua_tointeger(L, ++index);
-    AppId app_id  = (AppId)lua_tointeger(L, ++index);
+    AppId app_id = (AppId)lua_tointeger(L, ++index);
     size_t pattern_size = 0;
     const char* tmp_string = lua_tolstring(L, ++index, &pattern_size);
     if (!tmp_string or !pattern_size)
     {
-        ErrorMessage("appid: Invalid SSL Host pattern string in %s.\n", ud->get_detector()->get_name().c_str());
+        ErrorMessage("appid: Invalid SSL Host pattern string in %s.\n", 
+            ud->get_detector()->get_name().c_str());
         return 0;
     }
 
     uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string);
-    ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id);
+    ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id,
+        false);
     ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id);
 
     return 0;
 }
 
-// for Lua this looks something like: addDNSHostPattern(<appId>, '<pattern string>')
-static int detector_add_dns_host_pattern(lua_State* L)
+static int detector_add_ssl_cname_pattern(lua_State* L)
 {
     auto& ud = *UserData<LuaObject>::check(L, DETECTOR, 1);
     // Verify detector user data and that we are NOT in packet context
@@ -1262,17 +1263,21 @@ static int detector_add_dns_host_pattern(lua_State* L)
     const char* tmp_string = lua_tolstring(L, ++index, &pattern_size);
     if (!tmp_string or !pattern_size)
     {
-        ErrorMessage("appid: Invalid DNS Host pattern string.\n");
+        ErrorMessage("appid: Invalid SSL CN pattern string in %s.\n",
+            ud->get_detector()->get_name().c_str());
         return 0;
     }
 
     uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string);
-    ud->get_odp_ctxt().get_dns_matchers().add_host_pattern(pattern_str, pattern_size, type, app_id);
+    ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id,
+        true);
+    ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id);
 
     return 0;
 }
 
-static int detector_add_ssl_cname_pattern(lua_State* L)
+// for Lua this looks something like: addDNSHostPattern(<appId>, '<pattern string>')
+static int detector_add_dns_host_pattern(lua_State* L)
 {
     auto& ud = *UserData<LuaObject>::check(L, DETECTOR, 1);
     // Verify detector user data and that we are NOT in packet context
@@ -1283,19 +1288,18 @@ static int detector_add_ssl_cname_pattern(lua_State* L)
     int index = 1;
 
     uint8_t type = lua_tointeger(L, ++index);
-    AppId app_id  = (AppId)lua_tointeger(L, ++index);
+    AppId app_id = (AppId)lua_tointeger(L, ++index);
 
     size_t pattern_size = 0;
     const char* tmp_string = lua_tolstring(L, ++index, &pattern_size);
     if (!tmp_string or !pattern_size)
     {
-        ErrorMessage("appid: Invalid SSL CN pattern string in %s.\n", ud->get_detector()->get_name().c_str());
+        ErrorMessage("appid: Invalid DNS Host pattern string.\n");
         return 0;
     }
 
     uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string);
-    ud->get_odp_ctxt().get_ssl_matchers().add_cname_pattern(pattern_str, pattern_size, type, app_id);
-    ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id);
+    ud->get_odp_ctxt().get_dns_matchers().add_host_pattern(pattern_str, pattern_size, type, app_id);
 
     return 0;
 }
@@ -3074,9 +3078,9 @@ static const luaL_Reg detector_methods[] =
     { "addRTMPUrl",               detector_add_rtmp_url },
     { "addContentTypePattern",    detector_add_content_type_pattern },
     { "addSSLCertPattern",        detector_add_ssl_cert_pattern },
+    { "addSSLCnamePattern",       detector_add_ssl_cname_pattern },
     { "addSipUserAgent",          detector_add_sip_user_agent },
     { "addSipServer",             detector_add_sip_server },
-    { "addSSLCnamePattern",       detector_add_ssl_cname_pattern },
     { "addSSHPattern",            detector_add_ssh_client_pattern},
     { "addHostFirstPktApp",       detector_add_host_first_pkt_application },
     { "addHostPortApp",           detector_add_host_port_application },