From: Sreeja Athirkandathil Narayanan (sathirka) Date: Thu, 16 Feb 2023 17:10:51 +0000 (+0000) Subject: Pull request #3763: appid: merge cname pattern matchers with ssl pattern matchers X-Git-Tag: 3.1.56.0~1 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=5964efedac646452510a07eb62a02d32708222ec;p=thirdparty%2Fsnort3.git Pull request #3763: appid: merge cname pattern matchers with ssl pattern matchers Merge in SNORT/snort3 from ~OSTEPANO/snort3:cname_and_cert_merge to master Squashed commit of the following: commit 9be16131179eeff287720a474b410885b19cff7a Author: Oleksandr Stepanov Date: Thu Feb 9 10:41:51 2023 -0500 appid: merge cname pattern matchers with ssl pattern matchers --- diff --git a/src/network_inspectors/appid/detector_plugins/ssl_patterns.cc b/src/network_inspectors/appid/detector_plugins/ssl_patterns.cc index 425d83a3e..27f3a0428 100644 --- a/src/network_inspectors/appid/detector_plugins/ssl_patterns.cc +++ b/src/network_inspectors/appid/detector_plugins/ssl_patterns.cc @@ -28,7 +28,7 @@ using namespace snort; -static void create_matcher(SearchTool& matcher, SslPatternList* list) +static void create_matcher(SearchTool& matcher, SslPatternList* list, CnameCache& set) { size_t* pattern_index; size_t size = 0; @@ -38,6 +38,9 @@ static void create_matcher(SearchTool& matcher, SslPatternList* list) for (element = list; element; element = element->next) { + if (!element->dpattern->is_cname and set.count(*(element->dpattern))) + continue; + matcher.add(element->dpattern->pattern, element->dpattern->pattern_size, element->dpattern, true); (*pattern_index)++; @@ -57,17 +60,38 @@ static int cert_pattern_match(void* id, void*, int match_end_pos, void* data, vo cm->match_start_pos = match_end_pos - target->pattern_size; cm->next = *matches; *matches = cm; + + return 0; +} + +static int cname_pattern_match(void* id, void*, int match_end_pos, void* data, void*) +{ + MatchedSslPatterns* cm; + MatchedSslPatterns** matches = (MatchedSslPatterns**)data; + SslPattern* target = (SslPattern*)id; + /* Only collect the match if it is a cname pattern. */ + if (target->is_cname) + { + cm = (MatchedSslPatterns*)snort_alloc(sizeof(MatchedSslPatterns)); + cm->mpattern = target; + cm->match_start_pos = match_end_pos - target->pattern_size; + cm->next = *matches; + *matches = cm; + } return 0; } static bool scan_patterns(SearchTool& matcher, const uint8_t* data, size_t size, - AppId& client_id, AppId& payload_id) + AppId& client_id, AppId& payload_id, bool is_cname_search) { MatchedSslPatterns* mp = nullptr; SslPattern* best_match; - matcher.find_all((const char*)data, size, cert_pattern_match, false, &mp); + if (is_cname_search) + matcher.find_all((const char*)data, size, cname_pattern_match, false, &mp); + else + matcher.find_all((const char*)data, size, cert_pattern_match, false, &mp); if (!mp) return false; @@ -132,7 +156,7 @@ static void free_patterns(SslPatternList*& list) } static void add_pattern(SslPatternList*& list, uint8_t* pattern_str, size_t - pattern_size, uint8_t type, AppId app_id) + pattern_size, uint8_t type, AppId app_id, bool is_cname, CnameCache& set) { SslPatternList* new_ssl_pattern; @@ -142,45 +166,42 @@ static void add_pattern(SslPatternList*& list, uint8_t* pattern_str, size_t new_ssl_pattern->dpattern->app_id = app_id; new_ssl_pattern->dpattern->pattern = pattern_str; new_ssl_pattern->dpattern->pattern_size = pattern_size; + new_ssl_pattern->dpattern->is_cname = is_cname; new_ssl_pattern->next = list; list = new_ssl_pattern; + + if (is_cname) + set.emplace(*(new_ssl_pattern->dpattern)); } SslPatternMatchers::~SslPatternMatchers() { free_patterns(cert_pattern_list); - free_patterns(cname_pattern_list); -} - -void SslPatternMatchers::add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id) -{ - add_pattern(cert_pattern_list, pattern_str, pattern_size, type, app_id); } -void SslPatternMatchers::add_cname_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id) +void SslPatternMatchers::add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id, bool is_cname) { - add_pattern(cname_pattern_list, pattern_str, pattern_size, type, app_id); + add_pattern(cert_pattern_list, pattern_str, pattern_size, type, app_id, is_cname, cert_pattern_set); } void SslPatternMatchers::finalize_patterns() { - create_matcher(ssl_host_matcher, cert_pattern_list); - create_matcher(ssl_cname_matcher, cname_pattern_list); + create_matcher(ssl_host_matcher, cert_pattern_list, cert_pattern_set); + cert_pattern_set.clear(); } void SslPatternMatchers::reload_patterns() { ssl_host_matcher.reload(); - ssl_cname_matcher.reload(); } bool SslPatternMatchers::scan_hostname(const uint8_t* hostname, size_t size, AppId& client_id, AppId& payload_id) { - return scan_patterns(ssl_host_matcher, hostname, size, client_id, payload_id); + return scan_patterns(ssl_host_matcher, hostname, size, client_id, payload_id, false); } bool SslPatternMatchers::scan_cname(const uint8_t* common_name, size_t size, AppId& client_id, AppId& payload_id) { - return scan_patterns(ssl_cname_matcher, common_name, size, client_id, payload_id); + return scan_patterns(ssl_host_matcher, common_name, size, client_id, payload_id, true); } diff --git a/src/network_inspectors/appid/detector_plugins/ssl_patterns.h b/src/network_inspectors/appid/detector_plugins/ssl_patterns.h index 044726a28..aa2a27d8b 100644 --- a/src/network_inspectors/appid/detector_plugins/ssl_patterns.h +++ b/src/network_inspectors/appid/detector_plugins/ssl_patterns.h @@ -21,6 +21,8 @@ #ifndef SSL_PATTERNS_H #define SSL_PATTERNS_H +#include +#include #include "search_engines/search_tool.h" #include "application_ids.h" @@ -30,8 +32,25 @@ struct SslPattern AppId app_id; uint8_t* pattern; int pattern_size; + bool is_cname; + + bool operator==(const SslPattern& v) const + { + return this->type == v.type and pattern_size == v.pattern_size + and (memcmp(pattern, v.pattern, (size_t)pattern_size) == 0); + } }; +struct SslCacheKeyHasher +{ + size_t operator()(const SslPattern& key) const + { + return std::hash{}(std::string((char*)key.pattern, key.pattern_size)); + } +}; + +typedef std::unordered_set CnameCache; + struct MatchedSslPatterns { SslPattern* mpattern; @@ -49,8 +68,7 @@ class SslPatternMatchers { public: ~SslPatternMatchers(); - void add_cert_pattern(uint8_t*, size_t, uint8_t, AppId); - void add_cname_pattern(uint8_t*, size_t, uint8_t, AppId); + void add_cert_pattern(uint8_t*, size_t, uint8_t, AppId, bool); void finalize_patterns(); void reload_patterns(); bool scan_hostname(const uint8_t*, size_t, AppId&, AppId&); @@ -58,9 +76,8 @@ public: private: SslPatternList* cert_pattern_list = nullptr; - SslPatternList* cname_pattern_list = nullptr; + CnameCache cert_pattern_set; snort::SearchTool ssl_host_matcher = snort::SearchTool(); - snort::SearchTool ssl_cname_matcher= snort::SearchTool(); }; #endif diff --git a/src/network_inspectors/appid/lua_detector_api.cc b/src/network_inspectors/appid/lua_detector_api.cc index a9bee0de4..5f9ab3fc9 100644 --- a/src/network_inspectors/appid/lua_detector_api.cc +++ b/src/network_inspectors/appid/lua_detector_api.cc @@ -1228,24 +1228,25 @@ static int detector_add_ssl_cert_pattern(lua_State* L) int index = 1; uint8_t type = lua_tointeger(L, ++index); - AppId app_id = (AppId)lua_tointeger(L, ++index); + AppId app_id = (AppId)lua_tointeger(L, ++index); size_t pattern_size = 0; const char* tmp_string = lua_tolstring(L, ++index, &pattern_size); if (!tmp_string or !pattern_size) { - ErrorMessage("appid: Invalid SSL Host pattern string in %s.\n", ud->get_detector()->get_name().c_str()); + ErrorMessage("appid: Invalid SSL Host pattern string in %s.\n", + ud->get_detector()->get_name().c_str()); return 0; } uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string); - ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id); + ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id, + false); ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id); return 0; } -// for Lua this looks something like: addDNSHostPattern(, '') -static int detector_add_dns_host_pattern(lua_State* L) +static int detector_add_ssl_cname_pattern(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); // Verify detector user data and that we are NOT in packet context @@ -1262,17 +1263,21 @@ static int detector_add_dns_host_pattern(lua_State* L) const char* tmp_string = lua_tolstring(L, ++index, &pattern_size); if (!tmp_string or !pattern_size) { - ErrorMessage("appid: Invalid DNS Host pattern string.\n"); + ErrorMessage("appid: Invalid SSL CN pattern string in %s.\n", + ud->get_detector()->get_name().c_str()); return 0; } uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string); - ud->get_odp_ctxt().get_dns_matchers().add_host_pattern(pattern_str, pattern_size, type, app_id); + ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id, + true); + ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id); return 0; } -static int detector_add_ssl_cname_pattern(lua_State* L) +// for Lua this looks something like: addDNSHostPattern(, '') +static int detector_add_dns_host_pattern(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); // Verify detector user data and that we are NOT in packet context @@ -1283,19 +1288,18 @@ static int detector_add_ssl_cname_pattern(lua_State* L) int index = 1; uint8_t type = lua_tointeger(L, ++index); - AppId app_id = (AppId)lua_tointeger(L, ++index); + AppId app_id = (AppId)lua_tointeger(L, ++index); size_t pattern_size = 0; const char* tmp_string = lua_tolstring(L, ++index, &pattern_size); if (!tmp_string or !pattern_size) { - ErrorMessage("appid: Invalid SSL CN pattern string in %s.\n", ud->get_detector()->get_name().c_str()); + ErrorMessage("appid: Invalid DNS Host pattern string.\n"); return 0; } uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string); - ud->get_odp_ctxt().get_ssl_matchers().add_cname_pattern(pattern_str, pattern_size, type, app_id); - ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id); + ud->get_odp_ctxt().get_dns_matchers().add_host_pattern(pattern_str, pattern_size, type, app_id); return 0; } @@ -3074,9 +3078,9 @@ static const luaL_Reg detector_methods[] = { "addRTMPUrl", detector_add_rtmp_url }, { "addContentTypePattern", detector_add_content_type_pattern }, { "addSSLCertPattern", detector_add_ssl_cert_pattern }, + { "addSSLCnamePattern", detector_add_ssl_cname_pattern }, { "addSipUserAgent", detector_add_sip_user_agent }, { "addSipServer", detector_add_sip_server }, - { "addSSLCnamePattern", detector_add_ssl_cname_pattern }, { "addSSHPattern", detector_add_ssh_client_pattern}, { "addHostFirstPktApp", detector_add_host_first_pkt_application }, { "addHostPortApp", detector_add_host_port_application },