using namespace snort;
-static void create_matcher(SearchTool& matcher, SslPatternList* list)
+static void create_matcher(SearchTool& matcher, SslPatternList* list, CnameCache& set)
{
size_t* pattern_index;
size_t size = 0;
for (element = list; element; element = element->next)
{
+ if (!element->dpattern->is_cname and set.count(*(element->dpattern)))
+ continue;
+
matcher.add(element->dpattern->pattern,
element->dpattern->pattern_size, element->dpattern, true);
(*pattern_index)++;
cm->match_start_pos = match_end_pos - target->pattern_size;
cm->next = *matches;
*matches = cm;
+
+ return 0;
+}
+
+static int cname_pattern_match(void* id, void*, int match_end_pos, void* data, void*)
+{
+ MatchedSslPatterns* cm;
+ MatchedSslPatterns** matches = (MatchedSslPatterns**)data;
+ SslPattern* target = (SslPattern*)id;
+ /* Only collect the match if it is a cname pattern. */
+ if (target->is_cname)
+ {
+ cm = (MatchedSslPatterns*)snort_alloc(sizeof(MatchedSslPatterns));
+ cm->mpattern = target;
+ cm->match_start_pos = match_end_pos - target->pattern_size;
+ cm->next = *matches;
+ *matches = cm;
+ }
return 0;
}
static bool scan_patterns(SearchTool& matcher, const uint8_t* data, size_t size,
- AppId& client_id, AppId& payload_id)
+ AppId& client_id, AppId& payload_id, bool is_cname_search)
{
MatchedSslPatterns* mp = nullptr;
SslPattern* best_match;
- matcher.find_all((const char*)data, size, cert_pattern_match, false, &mp);
+ if (is_cname_search)
+ matcher.find_all((const char*)data, size, cname_pattern_match, false, &mp);
+ else
+ matcher.find_all((const char*)data, size, cert_pattern_match, false, &mp);
if (!mp)
return false;
}
static void add_pattern(SslPatternList*& list, uint8_t* pattern_str, size_t
- pattern_size, uint8_t type, AppId app_id)
+ pattern_size, uint8_t type, AppId app_id, bool is_cname, CnameCache& set)
{
SslPatternList* new_ssl_pattern;
new_ssl_pattern->dpattern->app_id = app_id;
new_ssl_pattern->dpattern->pattern = pattern_str;
new_ssl_pattern->dpattern->pattern_size = pattern_size;
+ new_ssl_pattern->dpattern->is_cname = is_cname;
new_ssl_pattern->next = list;
list = new_ssl_pattern;
+
+ if (is_cname)
+ set.emplace(*(new_ssl_pattern->dpattern));
}
SslPatternMatchers::~SslPatternMatchers()
{
free_patterns(cert_pattern_list);
- free_patterns(cname_pattern_list);
-}
-
-void SslPatternMatchers::add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id)
-{
- add_pattern(cert_pattern_list, pattern_str, pattern_size, type, app_id);
}
-void SslPatternMatchers::add_cname_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id)
+void SslPatternMatchers::add_cert_pattern(uint8_t* pattern_str, size_t pattern_size, uint8_t type, AppId app_id, bool is_cname)
{
- add_pattern(cname_pattern_list, pattern_str, pattern_size, type, app_id);
+ add_pattern(cert_pattern_list, pattern_str, pattern_size, type, app_id, is_cname, cert_pattern_set);
}
void SslPatternMatchers::finalize_patterns()
{
- create_matcher(ssl_host_matcher, cert_pattern_list);
- create_matcher(ssl_cname_matcher, cname_pattern_list);
+ create_matcher(ssl_host_matcher, cert_pattern_list, cert_pattern_set);
+ cert_pattern_set.clear();
}
void SslPatternMatchers::reload_patterns()
{
ssl_host_matcher.reload();
- ssl_cname_matcher.reload();
}
bool SslPatternMatchers::scan_hostname(const uint8_t* hostname, size_t size, AppId& client_id, AppId& payload_id)
{
- return scan_patterns(ssl_host_matcher, hostname, size, client_id, payload_id);
+ return scan_patterns(ssl_host_matcher, hostname, size, client_id, payload_id, false);
}
bool SslPatternMatchers::scan_cname(const uint8_t* common_name, size_t size, AppId& client_id, AppId& payload_id)
{
- return scan_patterns(ssl_cname_matcher, common_name, size, client_id, payload_id);
+ return scan_patterns(ssl_host_matcher, common_name, size, client_id, payload_id, true);
}
#ifndef SSL_PATTERNS_H
#define SSL_PATTERNS_H
+#include <cstring>
+#include <unordered_set>
#include "search_engines/search_tool.h"
#include "application_ids.h"
AppId app_id;
uint8_t* pattern;
int pattern_size;
+ bool is_cname;
+
+ bool operator==(const SslPattern& v) const
+ {
+ return this->type == v.type and pattern_size == v.pattern_size
+ and (memcmp(pattern, v.pattern, (size_t)pattern_size) == 0);
+ }
};
+struct SslCacheKeyHasher
+{
+ size_t operator()(const SslPattern& key) const
+ {
+ return std::hash<std::string>{}(std::string((char*)key.pattern, key.pattern_size));
+ }
+};
+
+typedef std::unordered_set<SslPattern, SslCacheKeyHasher> CnameCache;
+
struct MatchedSslPatterns
{
SslPattern* mpattern;
{
public:
~SslPatternMatchers();
- void add_cert_pattern(uint8_t*, size_t, uint8_t, AppId);
- void add_cname_pattern(uint8_t*, size_t, uint8_t, AppId);
+ void add_cert_pattern(uint8_t*, size_t, uint8_t, AppId, bool);
void finalize_patterns();
void reload_patterns();
bool scan_hostname(const uint8_t*, size_t, AppId&, AppId&);
private:
SslPatternList* cert_pattern_list = nullptr;
- SslPatternList* cname_pattern_list = nullptr;
+ CnameCache cert_pattern_set;
snort::SearchTool ssl_host_matcher = snort::SearchTool();
- snort::SearchTool ssl_cname_matcher= snort::SearchTool();
};
#endif
int index = 1;
uint8_t type = lua_tointeger(L, ++index);
- AppId app_id = (AppId)lua_tointeger(L, ++index);
+ AppId app_id = (AppId)lua_tointeger(L, ++index);
size_t pattern_size = 0;
const char* tmp_string = lua_tolstring(L, ++index, &pattern_size);
if (!tmp_string or !pattern_size)
{
- ErrorMessage("appid: Invalid SSL Host pattern string in %s.\n", ud->get_detector()->get_name().c_str());
+ ErrorMessage("appid: Invalid SSL Host pattern string in %s.\n",
+ ud->get_detector()->get_name().c_str());
return 0;
}
uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string);
- ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id);
+ ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id,
+ false);
ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id);
return 0;
}
-// for Lua this looks something like: addDNSHostPattern(<appId>, '<pattern string>')
-static int detector_add_dns_host_pattern(lua_State* L)
+static int detector_add_ssl_cname_pattern(lua_State* L)
{
auto& ud = *UserData<LuaObject>::check(L, DETECTOR, 1);
// Verify detector user data and that we are NOT in packet context
const char* tmp_string = lua_tolstring(L, ++index, &pattern_size);
if (!tmp_string or !pattern_size)
{
- ErrorMessage("appid: Invalid DNS Host pattern string.\n");
+ ErrorMessage("appid: Invalid SSL CN pattern string in %s.\n",
+ ud->get_detector()->get_name().c_str());
return 0;
}
uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string);
- ud->get_odp_ctxt().get_dns_matchers().add_host_pattern(pattern_str, pattern_size, type, app_id);
+ ud->get_odp_ctxt().get_ssl_matchers().add_cert_pattern(pattern_str, pattern_size, type, app_id,
+ true);
+ ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id);
return 0;
}
-static int detector_add_ssl_cname_pattern(lua_State* L)
+// for Lua this looks something like: addDNSHostPattern(<appId>, '<pattern string>')
+static int detector_add_dns_host_pattern(lua_State* L)
{
auto& ud = *UserData<LuaObject>::check(L, DETECTOR, 1);
// Verify detector user data and that we are NOT in packet context
int index = 1;
uint8_t type = lua_tointeger(L, ++index);
- AppId app_id = (AppId)lua_tointeger(L, ++index);
+ AppId app_id = (AppId)lua_tointeger(L, ++index);
size_t pattern_size = 0;
const char* tmp_string = lua_tolstring(L, ++index, &pattern_size);
if (!tmp_string or !pattern_size)
{
- ErrorMessage("appid: Invalid SSL CN pattern string in %s.\n", ud->get_detector()->get_name().c_str());
+ ErrorMessage("appid: Invalid DNS Host pattern string.\n");
return 0;
}
uint8_t* pattern_str = (uint8_t*)snort_strdup(tmp_string);
- ud->get_odp_ctxt().get_ssl_matchers().add_cname_pattern(pattern_str, pattern_size, type, app_id);
- ud->get_odp_ctxt().get_app_info_mgr().set_app_info_active(app_id);
+ ud->get_odp_ctxt().get_dns_matchers().add_host_pattern(pattern_str, pattern_size, type, app_id);
return 0;
}
{ "addRTMPUrl", detector_add_rtmp_url },
{ "addContentTypePattern", detector_add_content_type_pattern },
{ "addSSLCertPattern", detector_add_ssl_cert_pattern },
+ { "addSSLCnamePattern", detector_add_ssl_cname_pattern },
{ "addSipUserAgent", detector_add_sip_user_agent },
{ "addSipServer", detector_add_sip_server },
- { "addSSLCnamePattern", detector_add_ssl_cname_pattern },
{ "addSSHPattern", detector_add_ssh_client_pattern},
{ "addHostFirstPktApp", detector_add_host_first_pkt_application },
{ "addHostPortApp", detector_add_host_port_application },