From: Russ Combs (rucombs) Date: Mon, 25 Oct 2021 22:48:35 +0000 (+0000) Subject: Merge pull request #3122 in SNORT/snort3 from ~RUCOMBS/snort3:hyper_serial to master X-Git-Tag: 3.1.16.0~18 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=4e09ce8286d5897f5d57db8e1edec1f082030e8b;p=thirdparty%2Fsnort3.git Merge pull request #3122 in SNORT/snort3 from ~RUCOMBS/snort3:hyper_serial to master Squashed commit of the following: commit 9daf5f9c73643d751835d24790aab34c9382f338 Author: russ Date: Wed Oct 13 14:19:08 2021 -0400 detection: refactor mpse serialization commit 5b0ab03288a64707313c5f3f4f1214df235556c1 Author: russ Date: Wed Oct 13 10:19:58 2021 -0400 detection: rename PortGroup to the more apt RuleGroup (and related) PortGroup is a legacy name that predates service. RuleGroups are a collection of rules based on port (port, src|dst|any, #) or service (service, c2s|s2c). commit 47fa569f433c9c0ae034693c0caf76cfec65a89c Author: russ Date: Wed Oct 13 10:12:01 2021 -0400 detection: replace PortGroup::alloc/free with ctor/dtor commit 412073be22c8d8da0f7b532351bb377465186aad Author: russ Date: Mon Oct 11 15:33:47 2021 -0400 search_engine: support port group serialization commit 181e18b47f0a49a5a39dda02a44dc4f9702a3f97 Author: russ Date: Mon Oct 11 09:43:20 2021 -0400 ips: correct fast pattern port group counts commit edbeadd92064f02a0f7690f14805cb037ecbd980 Author: russ Date: Sun Oct 10 12:57:52 2021 -0400 mpse: add md5 check to deserialization commit 2dc6cde03deddcf2af26626fee5075e957d06fa9 Author: russ Date: Thu Oct 7 10:24:09 2021 -0400 hyperscan: sort patterns for dump / load stability commit 8fcc0ac4b79fe51e8d2a76484dc05238069b331b Author: russ Date: Thu Oct 7 07:53:37 2021 -0400 search_engine: support hyperscan serialization Dump hyperscan databases for service rule groups to the given directory with --dump-rule-databases. They can be reloaded with search_engine.rule_db_dir. This does not serialize port group databases. --- diff --git a/src/detection/fp_config.h b/src/detection/fp_config.h index 9b5b169a4..245223a9b 100644 --- a/src/detection/fp_config.h +++ b/src/detection/fp_config.h @@ -25,6 +25,8 @@ #ifndef FP_CONFIG_H #define FP_CONFIG_H +#include + namespace snort { struct MpseApi; @@ -116,6 +118,12 @@ public: void set_debug_print_rule_groups_uncompiled() { portlists_flags |= PL_DEBUG_PRINT_RULEGROUPS_UNCOMPILED; } + void set_rule_db_dir(const char* s) + { rule_db_dir = s; } + + const std::string& get_rule_db_dir() const + { return rule_db_dir; } + void set_search_opt(bool flag) { search_opt = flag; } @@ -161,6 +169,8 @@ private: int portlists_flags = 0; int num_patterns_truncated = 0; // due to max_pattern_len + + std::string rule_db_dir; }; #endif diff --git a/src/detection/fp_create.cc b/src/detection/fp_create.cc index 78e61dfbd..647458d83 100644 --- a/src/detection/fp_create.cc +++ b/src/detection/fp_create.cc @@ -427,7 +427,7 @@ static int pmx_create_tree_offload(SnortConfig* sc, void* id, void** existing_tr return pmx_create_tree(sc, id, existing_tree, Mpse::MPSE_TYPE_OFFLOAD); } -static int fpFinishPortGroupRule( +static int fpFinishRuleGroupRule( Mpse* mpse, OptTreeNode* otn, PatternMatchData* pmd, FastPatternConfig* fp, bool get_final_pat) { const char* pattern; @@ -459,20 +459,20 @@ static int fpFinishPortGroupRule( return 0; } -static int fpFinishPortGroup(SnortConfig* sc, PortGroup* pg, FastPatternConfig* fp) +static int fpFinishRuleGroup(SnortConfig* sc, RuleGroup* pg, FastPatternConfig* fp) { - int i; int rules = 0; if (pg == nullptr) return -1; + if (fp == nullptr) { - snort_free(pg); + delete pg; return -1; } - for (i = PM_TYPE_PKT; i < PM_TYPE_MAX; i++) + for (int i = PM_TYPE_PKT; i < PM_TYPE_MAX; i++) { if (pg->mpsegrp[i] != nullptr) { @@ -539,7 +539,7 @@ static int fpFinishPortGroup(SnortConfig* sc, PortGroup* pg, FastPatternConfig* if (!rules) { /* Nothing in the port group so we can just free it */ - snort_free(pg); + delete pg; return -1; } @@ -549,11 +549,11 @@ static int fpFinishPortGroup(SnortConfig* sc, PortGroup* pg, FastPatternConfig* static void fpAddAlternatePatterns( Mpse* mpse, OptTreeNode* otn, PatternMatchData* pmd, FastPatternConfig* fp) { - fpFinishPortGroupRule(mpse, otn, pmd, fp, false); + fpFinishRuleGroupRule(mpse, otn, pmd, fp, false); } -static int fpAddPortGroupRule( - SnortConfig* sc, PortGroup* pg, OptTreeNode* otn, FastPatternConfig* fp, bool srvc) +static int fpAddRuleGroupRule( + SnortConfig* sc, RuleGroup* pg, OptTreeNode* otn, FastPatternConfig* fp, bool srvc) { const MpseApi* search_api = nullptr; const MpseApi* offload_search_api = nullptr; @@ -664,7 +664,7 @@ static int fpAddPortGroupRule( add_nfp_rule = true; // Now add patterns - if (fpFinishPortGroupRule( + if (fpFinishRuleGroupRule( pg->mpsegrp[main_pmd->pm_type]->normal_mpse, otn, main_pmd, fp, true) == 0) { if (main_pmd->pattern_size > otn->longestPatternLen) @@ -687,7 +687,7 @@ static int fpAddPortGroupRule( add_nfp_rule = true; // Now add patterns - if (fpFinishPortGroupRule( + if (fpFinishRuleGroupRule( pg->mpsegrp[main_pmd->pm_type]->offload_mpse, otn, ol_pmd, fp, true) == 0) { if (ol_pmd->pattern_size > otn->longestPatternLen) @@ -730,8 +730,8 @@ static int fpAddPortGroupRule( /* * Original PortRuleMaps for each protocol requires creating the following structures. * - * PORT_RULE_MAP -> srcPortGroup,dstPortGroup,genericPortGroup - * PortGroup -> pgPatData, pgPatDataUri (acsm objects), (also rule_node lists 1/rule, + * PORT_RULE_MAP -> srcRuleGroup,dstRuleGroup,genericRuleGroup + * RuleGroup -> pgPatData, pgPatDataUri (acsm objects), (also rule_node lists 1/rule, * not needed). each rule content added to an acsm object has a PMX data ptr * associated with it. * RULE_NODE -> iRuleNodeID (used for bitmap object index) @@ -739,19 +739,19 @@ static int fpAddPortGroupRule( * * PortList model supports the same structures except: * - * PortGroup -> no rule_node lists needed, PortObjects maintain a list of rules used + * RuleGroup -> no rule_node lists needed, PortObjects maintain a list of rules used * * Generation of PortRuleMaps and data is done differently. * - * 1) Build tcp/udp/icmp/ip src and dst PortGroup objects based on the PortList Objects rules. + * 1) Build tcp/udp/icmp/ip src and dst RuleGroup objects based on the PortList Objects rules. * * 2) For each protocols PortList objects walk it's ports and assign the PORT_RULE_MAP src and - * dst PortGroup[port] array pointers to that PortList objects PortGroup. + * dst RuleGroup[port] array pointers to that PortList objects RuleGroup. * * Implementation: * - * Each PortList Object will be translated into a PortGroup, then pointed to by the - * PortGroup array in the PORT_RULE_MAP for the protocol + * Each PortList Object will be translated into a RuleGroup, then pointed to by the + * RuleGroup array in the PORT_RULE_MAP for the protocol * * protocol = tcp, udp, ip, icmp - one port_rule_map for each of these protocols * { create a port_rule_map @@ -790,7 +790,7 @@ static int fpAddPortGroupRule( */ struct PortIteratorData { - PortIteratorData(PortGroup** a, PortGroup* g) + PortIteratorData(RuleGroup** a, RuleGroup* g) { array = a; group = g; @@ -802,8 +802,8 @@ struct PortIteratorData pid->array[port] = pid->group; } - PortGroup** array; - PortGroup* group; + RuleGroup** array; + RuleGroup* group; }; static void fpCreateInitRuleMap( @@ -960,14 +960,14 @@ static int fpGetFinalPattern( return 0; } -static void fpPortGroupPrintRuleCount(PortGroup* pg, const char* what) +static void fpRuleGroupPrintRuleCount(RuleGroup* pg, const char* what) { int type; if (pg == nullptr) return; - LogMessage("PortGroup rule summary (%s):\n", what); + LogMessage("RuleGroup rule summary (%s):\n", what); for (type = PM_TYPE_PKT; type < PM_TYPE_MAX; type++) { @@ -997,13 +997,13 @@ static void fpDeletePMX(void* pv) } /* - * Create the PortGroup for these PortObject2 entities + * Create the RuleGroup for these PortObject2 entities * * This builds the 1st pass multi-pattern state machines for * content and uricontent based on the rules in the PortObjects * hash table. */ -static void fpCreatePortObject2PortGroup(SnortConfig* sc, PortObject2* po, PortObject2* poaa) +static void fpCreatePortObject2RuleGroup(SnortConfig* sc, PortObject2* po, PortObject2* poaa) { assert( po ); @@ -1017,12 +1017,12 @@ static void fpCreatePortObject2PortGroup(SnortConfig* sc, PortObject2* po, PortO return; /* create a port_group */ - PortGroup* pg = PortGroup::alloc(); + RuleGroup* pg = new RuleGroup; s_group = "port"; /* * Walk the rules in the PortObject and add to - * the PortGroup pattern state machine + * the RuleGroup pattern state machine * and to the port group RULE_NODE lists. * (The lists are still used in some cases * during detection to walk the rules in a group @@ -1057,11 +1057,11 @@ static void fpCreatePortObject2PortGroup(SnortConfig* sc, PortObject2* po, PortO assert(otn); if ( is_network_protocol(otn->snort_protocol_id) ) - fpAddPortGroupRule(sc, pg, otn, fp, false); + fpAddRuleGroupRule(sc, pg, otn, fp, false); } if (fp->get_debug_print_rule_group_build_details()) - fpPortGroupPrintRuleCount(pg, pox == po ? "ports" : "any"); + fpRuleGroupPrintRuleCount(pg, pox == po ? "ports" : "any"); if (pox == poaa) break; @@ -1070,7 +1070,7 @@ static void fpCreatePortObject2PortGroup(SnortConfig* sc, PortObject2* po, PortO } // This might happen if there was ip proto only rules...Don't return failure - if (fpFinishPortGroup(sc, pg, fp) != 0) + if (fpFinishRuleGroup(sc, pg, fp) != 0) return; po->group = pg; @@ -1080,7 +1080,7 @@ static void fpCreatePortObject2PortGroup(SnortConfig* sc, PortObject2* po, PortO /* * Create the port groups for this port table */ -static void fpCreatePortTablePortGroups(SnortConfig* sc, PortTable* p, PortObject2* poaa) +static void fpCreatePortTableRuleGroups(SnortConfig* sc, PortTable* p, PortObject2* poaa) { int cnt = 1; FastPatternConfig* fp = sc->fast_pattern_config; @@ -1098,12 +1098,12 @@ static void fpCreatePortTablePortGroups(SnortConfig* sc, PortTable* p, PortObjec if (fp->get_debug_print_rule_group_build_details()) LogMessage("Creating Port Group Object %d of %d\n", cnt++, p->pt_mpo_hash->get_count()); - /* if the object is not referenced, don't add it to the PortGroups + /* if the object is not referenced, don't add it to the RuleGroups * as it may overwrite other objects that are more inclusive. */ if ( !po->port_cnt ) continue; - fpCreatePortObject2PortGroup(sc, po, poaa); + fpCreatePortObject2RuleGroup(sc, po, poaa); } } @@ -1113,7 +1113,7 @@ static void fpCreatePortTablePortGroups(SnortConfig* sc, PortTable* p, PortObjec * note: any ports are standard PortObjects not PortObject2s so we have to * upgrade them for the create port group function */ -static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) +static int fpCreateRuleGroups(SnortConfig* sc, RulePortTables* p) { if (!get_rule_count()) return 0; @@ -1128,17 +1128,17 @@ static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) if ( log_rule_group_details ) LogMessage("\nIP-SRC "); - fpCreatePortTablePortGroups(sc, p->ip.src, add_any_any); + fpCreatePortTableRuleGroups(sc, p->ip.src, add_any_any); if ( log_rule_group_details ) LogMessage("\nIP-DST "); - fpCreatePortTablePortGroups(sc, p->ip.dst, add_any_any); + fpCreatePortTableRuleGroups(sc, p->ip.dst, add_any_any); if ( log_rule_group_details ) LogMessage("\nIP-ANY "); - fpCreatePortObject2PortGroup(sc, po2, nullptr); + fpCreatePortObject2RuleGroup(sc, po2, nullptr); p->ip.any->group = po2->group; po2->group = nullptr; PortObject2Free(po2); @@ -1150,17 +1150,17 @@ static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) if ( log_rule_group_details ) LogMessage("\nICMP-SRC "); - fpCreatePortTablePortGroups(sc, p->icmp.src, add_any_any); + fpCreatePortTableRuleGroups(sc, p->icmp.src, add_any_any); if ( log_rule_group_details ) LogMessage("\nICMP-DST "); - fpCreatePortTablePortGroups(sc, p->icmp.dst, add_any_any); + fpCreatePortTableRuleGroups(sc, p->icmp.dst, add_any_any); if ( log_rule_group_details ) LogMessage("\nICMP-ANY "); - fpCreatePortObject2PortGroup(sc, po2, nullptr); + fpCreatePortObject2RuleGroup(sc, po2, nullptr); p->icmp.any->group = po2->group; po2->group = nullptr; PortObject2Free(po2); @@ -1171,17 +1171,17 @@ static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) if ( log_rule_group_details ) LogMessage("\nTCP-SRC "); - fpCreatePortTablePortGroups(sc, p->tcp.src, add_any_any); + fpCreatePortTableRuleGroups(sc, p->tcp.src, add_any_any); if ( log_rule_group_details ) LogMessage("\nTCP-DST "); - fpCreatePortTablePortGroups(sc, p->tcp.dst, add_any_any); + fpCreatePortTableRuleGroups(sc, p->tcp.dst, add_any_any); if ( log_rule_group_details ) LogMessage("\nTCP-ANY "); - fpCreatePortObject2PortGroup(sc, po2, nullptr); + fpCreatePortObject2RuleGroup(sc, po2, nullptr); p->tcp.any->group = po2->group; po2->group = nullptr; PortObject2Free(po2); @@ -1193,17 +1193,17 @@ static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) if ( log_rule_group_details ) LogMessage("\nUDP-SRC "); - fpCreatePortTablePortGroups(sc, p->udp.src, add_any_any); + fpCreatePortTableRuleGroups(sc, p->udp.src, add_any_any); if ( log_rule_group_details ) LogMessage("\nUDP-DST "); - fpCreatePortTablePortGroups(sc, p->udp.dst, add_any_any); + fpCreatePortTableRuleGroups(sc, p->udp.dst, add_any_any); if ( log_rule_group_details ) LogMessage("\nUDP-ANY "); - fpCreatePortObject2PortGroup(sc, po2, nullptr); + fpCreatePortObject2RuleGroup(sc, po2, nullptr); p->udp.any->group = po2->group; po2->group = nullptr; PortObject2Free(po2); @@ -1214,7 +1214,7 @@ static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) if ( log_rule_group_details ) LogMessage("\nSVC-ANY "); - fpCreatePortObject2PortGroup(sc, po2, nullptr); + fpCreatePortObject2RuleGroup(sc, po2, nullptr); p->svc_any->group = po2->group; po2->group = nullptr; PortObject2Free(po2); @@ -1231,10 +1231,10 @@ static int fpCreatePortGroups(SnortConfig* sc, RulePortTables* p) * ...could use a service id instead (bytes, fixed length,etc...) * list- list of otns for this service */ -static void fpBuildServicePortGroupByServiceOtnList( +static void fpBuildServiceRuleGroupByServiceOtnList( SnortConfig* sc, GHash* p, const char* srvc, SF_LIST* list, FastPatternConfig* fp) { - PortGroup* pg = PortGroup::alloc(); + RuleGroup* pg = new RuleGroup; s_group = srvc; /* @@ -1247,10 +1247,10 @@ static void fpBuildServicePortGroupByServiceOtnList( otn; otn = (OptTreeNode*)sflist_next(&cursor) ) { - fpAddPortGroupRule(sc, pg, otn, fp, true); + fpAddRuleGroupRule(sc, pg, otn, fp, true); } - if (fpFinishPortGroup(sc, pg, fp) != 0) + if (fpFinishRuleGroup(sc, pg, fp) != 0) return; /* Add the port_group using it's service name */ @@ -1258,11 +1258,11 @@ static void fpBuildServicePortGroupByServiceOtnList( } /* - * For each service we create a PortGroup based on the otn's defined to + * For each service we create a RuleGroup based on the otn's defined to * be applicable to that service by the metadata option. * * Then we lookup the protocol/srvc ordinal in the target-based area - * and assign the PortGroup for the srvc to it. + * and assign the RuleGroup for the srvc to it. * * spg - service port group (lookup should be by service id/tag) * - this table maintains a port_group ptr for each service @@ -1270,8 +1270,8 @@ static void fpBuildServicePortGroupByServiceOtnList( * - this table maintains a SF_LIST ptr (list of rule otns) for each service * */ -static void fpBuildServicePortGroups( - SnortConfig* sc, GHash* spg, PortGroupVector& sopg, GHash* srm, FastPatternConfig* fp) +static void fpBuildServiceRuleGroups( + SnortConfig* sc, GHash* spg, RuleGroupVector& sopg, GHash* srm, FastPatternConfig* fp) { for (GHashNode* n = srm->find_first(); n; n = srm->find_next()) { @@ -1280,10 +1280,10 @@ static void fpBuildServicePortGroups( assert(list and srvc); - fpBuildServicePortGroupByServiceOtnList(sc, spg, srvc, list, fp); + fpBuildServiceRuleGroupByServiceOtnList(sc, spg, srvc, list, fp); - /* Add this PortGroup to the protocol-ordinal -> port_group table */ - PortGroup* pg = (PortGroup*)spg->find(srvc); + /* Add this RuleGroup to the protocol-ordinal -> port_group table */ + RuleGroup* pg = (RuleGroup*)spg->find(srvc); if ( !pg ) { ParseError("*** failed to create and find a port group for '%s'",srvc); @@ -1298,19 +1298,19 @@ static void fpBuildServicePortGroups( } /* - * For each proto+dir+service build a PortGroup + * For each proto+dir+service build a RuleGroup */ -static void fpCreateServiceMapPortGroups(SnortConfig* sc) +static void fpCreateServiceMapRuleGroups(SnortConfig* sc) { FastPatternConfig* fp = sc->fast_pattern_config; - sc->spgmmTable = ServicePortGroupMapNew(); + sc->spgmmTable = ServiceRuleGroupMapNew(); sc->sopgTable = new sopg_table_t(sc->proto_ref->get_count()); - fpBuildServicePortGroups(sc, sc->spgmmTable->to_srv, + fpBuildServiceRuleGroups(sc, sc->spgmmTable->to_srv, sc->sopgTable->to_srv, sc->srmmTable->to_srv, fp); - fpBuildServicePortGroups(sc, sc->spgmmTable->to_cli, + fpBuildServiceRuleGroups(sc, sc->spgmmTable->to_cli, sc->sopgTable->to_cli, sc->srmmTable->to_cli, fp); } @@ -1399,14 +1399,14 @@ static void fp_print_service_rules_by_proto(SnortConfig* sc) fp_print_service_rules(sc, sc->srmmTable->to_srv, sc->srmmTable->to_cli); } -static void fp_sum_port_groups(PortGroup* pg, unsigned c[PM_TYPE_MAX]) +static void fp_sum_port_groups(RuleGroup* pg, unsigned c[PM_TYPE_MAX]) { if ( !pg ) return; for ( int i = PM_TYPE_PKT; i < PM_TYPE_MAX; ++i ) if ( pg->mpsegrp[i] and pg->mpsegrp[i]->normal_mpse and - pg->mpsegrp[i]->normal_mpse->get_pattern_count() ) + pg->mpsegrp[i]->normal_mpse->get_pattern_count() ) c[i]++; } @@ -1416,7 +1416,7 @@ static void fp_sum_service_groups(GHash* h, unsigned c[PM_TYPE_MAX]) node; node = h->find_next()) { - PortGroup* pg = (PortGroup*)node->data; + RuleGroup* pg = (RuleGroup*)node->data; fp_sum_port_groups(pg, c); } } @@ -1447,9 +1447,9 @@ static void fp_print_service_groups(srmm_table_t* srmm) static void fp_sum_port_groups(PortTable* tab, unsigned c[PM_TYPE_MAX]) { - for (GHashNode* node = tab->pt_mpxo_hash->find_first(); + for (GHashNode* node = tab->pt_mpo_hash->find_first(); node; - node = tab->pt_mpxo_hash->find_next()) + node = tab->pt_mpo_hash->find_next()) { PortObject2* po = (PortObject2*)node->data; fp_sum_port_groups(po->group, c); @@ -1466,28 +1466,28 @@ static void fp_print_port_groups(RulePortTables* port_tables) fp_sum_port_groups(port_tables->ip.src, src); fp_sum_port_groups(port_tables->ip.dst, dst); - fp_sum_port_groups((PortGroup*)port_tables->ip.any->group, any); + fp_sum_port_groups((RuleGroup*)port_tables->ip.any->group, any); PortObjectFinalize(port_tables->ip.any); PortObjectFinalize(port_tables->ip.nfp); fp_sum_port_groups(port_tables->icmp.src, src); fp_sum_port_groups(port_tables->icmp.dst, dst); - fp_sum_port_groups((PortGroup*)port_tables->icmp.any->group, any); + fp_sum_port_groups((RuleGroup*)port_tables->icmp.any->group, any); PortObjectFinalize(port_tables->icmp.any); PortObjectFinalize(port_tables->icmp.nfp); fp_sum_port_groups(port_tables->tcp.src, src); fp_sum_port_groups(port_tables->tcp.dst, dst); - fp_sum_port_groups((PortGroup*)port_tables->tcp.any->group, any); + fp_sum_port_groups((RuleGroup*)port_tables->tcp.any->group, any); PortObjectFinalize(port_tables->tcp.any); PortObjectFinalize(port_tables->tcp.nfp); fp_sum_port_groups(port_tables->udp.src, src); fp_sum_port_groups(port_tables->udp.dst, dst); - fp_sum_port_groups((PortGroup*)port_tables->udp.any->group, any); + fp_sum_port_groups((RuleGroup*)port_tables->udp.any->group, any); PortObjectFinalize(port_tables->udp.any); PortObjectFinalize(port_tables->udp.nfp); @@ -1509,10 +1509,10 @@ static void fp_print_port_groups(RulePortTables* port_tables) } /* - * Build Service based PortGroups using the rules + * Build Service based RuleGroups using the rules * metadata option service parameter. */ -static void fpCreateServicePortGroups(SnortConfig* sc) +static void fpCreateServiceRuleGroups(SnortConfig* sc) { FastPatternConfig* fp = sc->fast_pattern_config; @@ -1524,10 +1524,10 @@ static void fpCreateServicePortGroups(SnortConfig* sc) if ( fp->get_debug_print_rule_group_build_details() ) fpPrintServiceRuleMaps(sc); - fpCreateServiceMapPortGroups(sc); + fpCreateServiceMapRuleGroups(sc); if (fp->get_debug_print_rule_group_build_details()) - fpPrintServicePortGroupSummary(sc); + fpPrintServiceRuleGroupSummary(sc); ServiceMapFree(sc->srmmTable); sc->srmmTable = nullptr; @@ -1553,10 +1553,7 @@ static unsigned can_build_mt(FastPatternConfig* fp) } /* -* Port list version -* * 7/2007 - man -* * Build Pattern Groups for 1st pass of content searching using * multi-pattern search method. */ @@ -1582,11 +1579,10 @@ int fpCreateFastPacketDetection(SnortConfig* sc) MpseManager::start_search_engine(fp->get_search_api()); - /* Use PortObjects to create PortGroups */ if ( log_rule_group_details ) LogMessage("Creating Port Groups....\n"); - fpCreatePortGroups(sc, port_tables); + fpCreateRuleGroups(sc, port_tables); if ( log_rule_group_details ) { @@ -1594,7 +1590,6 @@ int fpCreateFastPacketDetection(SnortConfig* sc) LogMessage("Creating Rule Maps....\n"); } - /* Create rule_maps */ fpCreateRuleMaps(sc, port_tables); if ( log_rule_group_details ) @@ -1603,18 +1598,19 @@ int fpCreateFastPacketDetection(SnortConfig* sc) LogMessage("Creating Service Based Rule Maps....\n"); } - /* Build Service based port groups - rules require service metdata - * i.e. 'metatdata: service [=] service-name, ... ;' - * - * Also requires a service attribute for lookup ... - */ - fpCreateServicePortGroups(sc); + fpCreateServiceRuleGroups(sc); if ( log_rule_group_details ) LogMessage("Service Based Rule Maps Done....\n"); + unsigned mpse_loaded = 0; + unsigned mpse_dumped = 0; + if ( !sc->test_mode() or sc->mem_check() ) { + if ( !fp->get_rule_db_dir().empty() ) + mpse_loaded = fp_deserialize(sc, fp->get_rule_db_dir()); + unsigned c = compile_mpses(sc, can_build_mt(fp)); unsigned expected = mpse_count + offload_mpse_count; @@ -1627,6 +1623,9 @@ int fpCreateFastPacketDetection(SnortConfig* sc) fp_print_port_groups(port_tables); fp_print_service_groups(sc->spgmmTable); + if ( !sc->rule_db_dir.empty() ) + mpse_dumped = fp_serialize(sc, sc->rule_db_dir); + if ( mpse_count ) { LogLabel("search engine"); @@ -1639,8 +1638,9 @@ int fpCreateFastPacketDetection(SnortConfig* sc) MpseManager::print_mpse_summary(fp->get_offload_search_api()); } - if ( fp->get_num_patterns_truncated() ) - LogMessage("%25.25s: %-12u\n", "truncated patterns", fp->get_num_patterns_truncated()); + LogCount("truncated patterns", fp->get_num_patterns_truncated()); + LogCount("mpse_loaded", mpse_loaded); + LogCount("mpse_dumped", mpse_dumped); MpseManager::setup_search_engine(fp->get_search_api(), sc); @@ -1657,7 +1657,7 @@ void fpDeleteFastPacketDetection(SnortConfig* sc) delete sc->detection_option_tree_hash_table; fpFreeRuleMaps(sc); - ServicePortGroupMapFree(sc->spgmmTable); + ServiceRuleGroupMapFree(sc->spgmmTable); if ( sc->sopgTable ) delete sc->sopgTable; diff --git a/src/detection/fp_detect.cc b/src/detection/fp_detect.cc index b49ead2cc..4640d94ab 100644 --- a/src/detection/fp_detect.cc +++ b/src/detection/fp_detect.cc @@ -879,7 +879,7 @@ static inline int batch_search( static inline void search_buffer( Inspector* gadget, InspectionBuffer& buf, InspectionBuffer::Type ibt, - Packet* p, PortGroup* pg, PmType pmt, PegCount& cnt) + Packet* p, RuleGroup* pg, PmType pmt, PegCount& cnt) { if ( MpseGroup* so = pg->mpsegrp[pmt] ) { @@ -894,7 +894,7 @@ static inline void search_buffer( } } -static int fp_search(PortGroup* port_group, Packet* p, bool srvc) +static int fp_search(RuleGroup* port_group, Packet* p, bool srvc) { Inspector* gadget = p->flow ? p->flow->gadget : nullptr; InspectionBuffer buf; @@ -984,7 +984,7 @@ static int fp_search(PortGroup* port_group, Packet* p, bool srvc) } static inline void eval_fp( - PortGroup* port_group, Packet* p, char ip_rule, bool srvc) + RuleGroup* port_group, Packet* p, char ip_rule, bool srvc) { const uint8_t* tmp_payload = nullptr; uint16_t tmp_dsize = 0; @@ -1020,7 +1020,7 @@ static inline void eval_fp( } static inline void eval_nfp( - PortGroup* port_group, Packet* p, char ip_rule) + RuleGroup* port_group, Packet* p, char ip_rule) { bool repeat = false; int8_t curr_ip_layer = 0; @@ -1102,7 +1102,7 @@ static inline void eval_nfp( // for performance purposes. static inline void fpEvalHeaderSW( - PortGroup* port_group, Packet* p, char ip_rule, FPTask task, bool srvc = false) + RuleGroup* port_group, Packet* p, char ip_rule, FPTask task, bool srvc = false) { if ( !p->is_detection_enabled(p->packet_flags & PKT_FROM_CLIENT) ) return; @@ -1116,7 +1116,7 @@ static inline void fpEvalHeaderSW( static inline void fpEvalHeaderIp(Packet* p, FPTask task) { - PortGroup* any = nullptr, * ip_group = nullptr; + RuleGroup* any = nullptr, * ip_group = nullptr; if ( !prmFindRuleGroupIp(p->context->conf->prmIpRTNX, ANYPORT, &ip_group, &any) ) return; @@ -1133,7 +1133,7 @@ static inline void fpEvalHeaderIp(Packet* p, FPTask task) static inline void fpEvalHeaderIcmp(Packet* p, FPTask task) { - PortGroup* any = nullptr, * type = nullptr; + RuleGroup* any = nullptr, * type = nullptr; if ( !prmFindRuleGroupIcmp(p->context->conf->prmIcmpRTNX, p->ptrs.icmph->type, &type, &any) ) return; @@ -1148,7 +1148,7 @@ static inline void fpEvalHeaderIcmp(Packet* p, FPTask task) static inline void fpEvalHeaderTcp(Packet* p, FPTask task) { - PortGroup* src = nullptr, * dst = nullptr, * any = nullptr; + RuleGroup* src = nullptr, * dst = nullptr, * any = nullptr; if ( !prmFindRuleGroupTcp(p->context->conf->prmTcpRTNX, p->ptrs.dp, p->ptrs.sp, &src, &dst, &any) ) return; @@ -1165,7 +1165,7 @@ static inline void fpEvalHeaderTcp(Packet* p, FPTask task) static inline void fpEvalHeaderUdp(Packet* p, FPTask task) { - PortGroup* src = nullptr, * dst = nullptr, * any = nullptr; + RuleGroup* src = nullptr, * dst = nullptr, * any = nullptr; if ( !prmFindRuleGroupUdp(p->context->conf->prmUdpRTNX, p->ptrs.dp, p->ptrs.sp, &src, &dst, &any) ) return; @@ -1187,7 +1187,7 @@ static inline void fpEvalHeaderSvc(Packet* p, FPTask task) if (snort_protocol_id == UNKNOWN_PROTOCOL_ID or snort_protocol_id == INVALID_PROTOCOL_ID) return; - PortGroup* svc = nullptr; + RuleGroup* svc = nullptr; if (p->is_from_application_server()) svc = p->context->conf->sopgTable->get_port_group(false, snort_protocol_id); diff --git a/src/detection/fp_detect.h b/src/detection/fp_detect.h index 70c565ad4..d54b9393e 100644 --- a/src/detection/fp_detect.h +++ b/src/detection/fp_detect.h @@ -42,7 +42,7 @@ struct ProfileStats; } class Cursor; -struct PortGroup; +struct RuleGroup; struct OptTreeNode; extern THREAD_LOCAL snort::ProfileStats mpsePerfStats; diff --git a/src/detection/fp_utils.cc b/src/detection/fp_utils.cc index 99af3f0e9..c4c745b56 100644 --- a/src/detection/fp_utils.cc +++ b/src/detection/fp_utils.cc @@ -24,19 +24,29 @@ #include #include +#include +#include #include #include +#include #include +#include "framework/mpse.h" +#include "framework/mpse_batch.h" +#include "hash/ghash.h" #include "log/messages.h" #include "main/snort_config.h" #include "parser/parse_conf.h" #include "pattern_match_data.h" #include "ports/port_group.h" +#include "ports/port_table.h" +#include "ports/rule_port_tables.h" #include "target_based/snort_protocols.h" #include "treenodes.h" #include "utils/util.h" +#include "service_map.h" + #ifdef UNIT_TEST #include "catch/snort_catch.h" #endif @@ -230,10 +240,194 @@ bool FpSelector::is_better_than( return false; } +//-------------------------------------------------------------------------- +// mpse database serialization +//-------------------------------------------------------------------------- + +static unsigned mpse_loaded, mpse_dumped; + +static bool store(const std::string& s, const uint8_t* data, size_t len) +{ + std::ofstream out(s.c_str(), std::ofstream::binary); + out.write((const char*)data, len); + return true; +} + +static bool fetch(const std::string& s, uint8_t*& data, size_t& len) +{ + std::ifstream in(s.c_str(), std::ifstream::binary); + + if ( !in.is_open() ) + return false; + + in.seekg (0, in.end); + len = in.tellg(); + in.seekg (0); + + data = new uint8_t[len]; + in.read((char*)data, len); + + return true; +} + +static std::string make_db_name( + const std::string& path, const char* proto, const char* dir, const char* buf, const std::string& id) +{ + std::stringstream ss; + + ss << path << "/"; + ss << proto << "_"; + ss << dir << "_"; + ss << buf << "_"; + + ss << std::hex << std::setfill('0') << std::setw(2); + + for ( auto c : id ) + ss << (unsigned)(uint8_t)c; + + ss << ".hsdb"; + + return ss.str(); +} + +static bool db_dump(const std::string& path, const char* proto, const char* dir, RuleGroup* g) +{ + for ( auto i = 0; i < PM_TYPE_MAX; ++i ) + { + if ( !g->mpsegrp[i] ) + continue; + + std::string id; + g->mpsegrp[i]->normal_mpse->get_hash(id); + + std::string file = make_db_name(path, proto, dir, pm_type_strings[i], id); + + uint8_t* db = nullptr; + size_t len = 0; + + if ( g->mpsegrp[i]->normal_mpse->serialize(db, len) and db and len > 0 ) + { + store(file, db, len); + free(db); + ++mpse_dumped; + } + else + { + ParseWarning(WARN_RULES, "Failed to serialize %s", file.c_str()); + return false; + } + } + return true; +} + +static bool db_load(const std::string& path, const char* proto, const char* dir, RuleGroup* g) +{ + for ( auto i = 0; i < PM_TYPE_MAX; ++i ) + { + if ( !g->mpsegrp[i] ) + continue; + + std::string id; + g->mpsegrp[i]->normal_mpse->get_hash(id); + + std::string file = make_db_name(path, proto, dir, pm_type_strings[i], id); + + uint8_t* db = nullptr; + size_t len = 0; + + if ( !fetch(file, db, len) ) + { + ParseWarning(WARN_RULES, "Failed to read %s", file.c_str()); + return false; + } + else if ( !g->mpsegrp[i]->normal_mpse->deserialize(db, len) ) + { + ParseWarning(WARN_RULES, "Failed to deserialize %s", file.c_str()); + return false; + } + delete[] db; + ++mpse_loaded; + } + return true; +} + +typedef bool (*db_io)(const std::string&, const char*, const char*, RuleGroup*); + +static void port_io( + const std::string& path, const char* proto, const char* end, PortTable* pt, db_io func) +{ + for (GHashNode* node = pt->pt_mpo_hash->find_first(); + node; + node = pt->pt_mpo_hash->find_next()) + { + PortObject2* po = (PortObject2*)node->data; + + if ( !po or !po->group ) + continue; + + func(path, proto, end, po->group); + } +} + +static void port_io( + const std::string& path, const char* proto, const char* end, PortObject* po, db_io func) +{ + if ( po->group ) + func(path, proto, end, po->group); +} + +static void svc_io(const std::string& path, const char* dir, GHash* h, db_io func) +{ + for ( GHashNode* n = h->find_first(); n; n = h->find_next()) + { + func(path, (const char*)n->key, dir, (RuleGroup*)n->data); + } +} + +static void fp_io(const SnortConfig* sc, const std::string& path, db_io func) +{ + auto* pt = sc->port_tables; + + port_io(path, "ip", "src", pt->ip.src, func); + port_io(path, "ip", "dst", pt->ip.dst, func); + port_io(path, "ip", "any", pt->ip.any, func); + + port_io(path, "icmp", "src", pt->icmp.src, func); + port_io(path, "icmp", "dst", pt->icmp.dst, func); + port_io(path, "icmp", "any", pt->icmp.any, func); + + port_io(path, "tcp", "src", pt->tcp.src, func); + port_io(path, "tcp", "dst", pt->tcp.dst, func); + port_io(path, "tcp", "any", pt->tcp.any, func); + + port_io(path, "udp", "src", pt->udp.src, func); + port_io(path, "udp", "dst", pt->udp.dst, func); + port_io(path, "udp", "any", pt->udp.any, func); + + auto* sp = sc->spgmmTable; + + svc_io(path, "s2c", sp->to_cli, func); + svc_io(path, "c2s", sp->to_srv, func); +} + //-------------------------------------------------------------------------- // public methods //-------------------------------------------------------------------------- +unsigned fp_serialize(const SnortConfig* sc, const std::string& dir) +{ + mpse_dumped = 0; + fp_io(sc, dir, db_dump); + return mpse_dumped; +} + +unsigned fp_deserialize(const SnortConfig* sc, const std::string& dir) +{ + mpse_loaded = 0; + fp_io(sc, dir, db_load); + return mpse_loaded; +} + void validate_services(SnortConfig* sc, OptTreeNode* otn) { std::string svc; diff --git a/src/detection/fp_utils.h b/src/detection/fp_utils.h index 95a438cf3..84a564470 100644 --- a/src/detection/fp_utils.h +++ b/src/detection/fp_utils.h @@ -22,7 +22,10 @@ #define FP_UTILS_H // fast pattern utilities + +#include #include + #include "framework/ips_option.h" #include "framework/mpse.h" #include "ports/port_group.h" @@ -47,5 +50,8 @@ unsigned compile_mpses(struct snort::SnortConfig*, bool parallel = false); void validate_services(struct snort::SnortConfig*, OptTreeNode*); +unsigned fp_serialize(const struct snort::SnortConfig*, const std::string& dir); +unsigned fp_deserialize(const struct snort::SnortConfig*, const std::string& dir); + #endif diff --git a/src/detection/pcrm.cc b/src/detection/pcrm.cc index 45fd160f5..07b9f4650 100644 --- a/src/detection/pcrm.cc +++ b/src/detection/pcrm.cc @@ -51,13 +51,13 @@ PORT_RULE_MAP* prmNewMap() /* ** DESCRIPTION -** Given a PORT_RULE_MAP, this function selects the PortGroup or -** PortGroups necessary to fully match a given dport, sport pair. +** Given a PORT_RULE_MAP, this function selects the RuleGroup or +** RuleGroups necessary to fully match a given dport, sport pair. ** The selection logic looks at both the dport and sport and ** determines if one or both are unique. If one is unique, then -** the appropriate PortGroup ptr is set. If both are unique, then -** both th src and dst PortGroup ptrs are set. If neither of the -** ports are unique, then the gen PortGroup ptr is set. +** the appropriate RuleGroup ptr is set. If both are unique, then +** both th src and dst RuleGroup ptrs are set. If neither of the +** ports are unique, then the gen RuleGroup ptr is set. ** ** FORMAL OUTPUT ** int - 0: Don't evaluate @@ -65,7 +65,7 @@ PORT_RULE_MAP* prmNewMap() ** ** NOTES ** Currently, if there is a "unique conflict", we return both the src -** and dst PortGroups. This conflict forces us to do two searches, one +** and dst RuleGroups. This conflict forces us to do two searches, one ** for the src and one for the dst. So we are taking twice the time to ** inspect a packet then usual. Obviously, this is not good. There ** are several options that we have to deal with unique conflicts, but @@ -78,9 +78,9 @@ static int prmFindRuleGroup( PORT_RULE_MAP* p, int dport, int sport, - PortGroup** src, - PortGroup** dst, - PortGroup** gen + RuleGroup** src, + RuleGroup** dst, + RuleGroup** gen ) { if ( !p ) @@ -117,26 +117,26 @@ static int prmFindRuleGroup( ** are also used in the file fpdetect.c, where we do lookups ** on the initialized variables. */ -int prmFindRuleGroupIp(PORT_RULE_MAP* prm, int ip_proto, PortGroup** ip_group, PortGroup** gen) +int prmFindRuleGroupIp(PORT_RULE_MAP* prm, int ip_proto, RuleGroup** ip_group, RuleGroup** gen) { - PortGroup* src; + RuleGroup* src; return prmFindRuleGroup(prm, ip_proto, ANYPORT, &src, ip_group, gen); } -int prmFindRuleGroupIcmp(PORT_RULE_MAP* prm, int type, PortGroup** type_group, PortGroup** gen) +int prmFindRuleGroupIcmp(PORT_RULE_MAP* prm, int type, RuleGroup** type_group, RuleGroup** gen) { - PortGroup* src; + RuleGroup* src; return prmFindRuleGroup(prm, type, ANYPORT, &src, type_group, gen); } -int prmFindRuleGroupTcp(PORT_RULE_MAP* prm, int dport, int sport, PortGroup** src, - PortGroup** dst, PortGroup** gen) +int prmFindRuleGroupTcp(PORT_RULE_MAP* prm, int dport, int sport, RuleGroup** src, + RuleGroup** dst, RuleGroup** gen) { return prmFindRuleGroup(prm, dport, sport, src, dst, gen); } -int prmFindRuleGroupUdp(PORT_RULE_MAP* prm, int dport, int sport, PortGroup** src, - PortGroup** dst, PortGroup** gen) +int prmFindRuleGroupUdp(PORT_RULE_MAP* prm, int dport, int sport, RuleGroup** src, + RuleGroup** dst, RuleGroup** gen) { return prmFindRuleGroup(prm, dport, sport, src, dst, gen); } diff --git a/src/detection/pcrm.h b/src/detection/pcrm.h index 805da70d6..17b43a503 100644 --- a/src/detection/pcrm.h +++ b/src/detection/pcrm.h @@ -43,17 +43,17 @@ struct PORT_RULE_MAP int prmNumDstGroups; int prmNumSrcGroups; - PortGroup* prmSrcPort[snort::MAX_PORTS]; - PortGroup* prmDstPort[snort::MAX_PORTS]; - PortGroup* prmGeneric; + RuleGroup* prmSrcPort[snort::MAX_PORTS]; + RuleGroup* prmDstPort[snort::MAX_PORTS]; + RuleGroup* prmGeneric; }; PORT_RULE_MAP* prmNewMap(); -int prmFindRuleGroupTcp(PORT_RULE_MAP*, int, int, PortGroup**, PortGroup**, PortGroup**); -int prmFindRuleGroupUdp(PORT_RULE_MAP*, int, int, PortGroup**, PortGroup**, PortGroup**); -int prmFindRuleGroupIp(PORT_RULE_MAP*, int, PortGroup**, PortGroup**); -int prmFindRuleGroupIcmp(PORT_RULE_MAP*, int, PortGroup**, PortGroup**); +int prmFindRuleGroupTcp(PORT_RULE_MAP*, int, int, RuleGroup**, RuleGroup**, RuleGroup**); +int prmFindRuleGroupUdp(PORT_RULE_MAP*, int, int, RuleGroup**, RuleGroup**, RuleGroup**); +int prmFindRuleGroupIp(PORT_RULE_MAP*, int, RuleGroup**, RuleGroup**); +int prmFindRuleGroupIcmp(PORT_RULE_MAP*, int, RuleGroup**, RuleGroup**); #endif diff --git a/src/detection/service_map.cc b/src/detection/service_map.cc index 695d54121..08f647e7c 100644 --- a/src/detection/service_map.cc +++ b/src/detection/service_map.cc @@ -91,7 +91,7 @@ void ServiceMapFree(srmm_table_t* table) //------------------------------------------------------------------------- static void delete_pg(void* pv) -{ PortGroup::free((PortGroup*)pv); } +{ delete (RuleGroup*)pv; } static GHash* alloc_spgmm() { @@ -107,7 +107,7 @@ static void free_spgmm(GHash* table) delete table; } -srmm_table_t* ServicePortGroupMapNew() +srmm_table_t* ServiceRuleGroupMapNew() { srmm_table_t* table = (srmm_table_t*)snort_calloc(sizeof(srmm_table_t)); @@ -117,7 +117,7 @@ srmm_table_t* ServicePortGroupMapNew() return table; } -void ServicePortGroupMapFree(srmm_table_t* table) +void ServiceRuleGroupMapFree(srmm_table_t* table) { if ( !table ) return; @@ -170,10 +170,10 @@ static void ServiceMapAddOtn( ServiceMapAddOtnRaw(srmm->to_srv, servicename, otn); } -void fpPrintServicePortGroupSummary(SnortConfig* sc) +void fpPrintServiceRuleGroupSummary(SnortConfig* sc) { LogMessage("+--------------------------------\n"); - LogMessage("| Service-PortGroup Table Summary \n"); + LogMessage("| Service-RuleGroup Table Summary \n"); LogMessage("---------------------------------\n"); if ( unsigned n = sc->spgmmTable->to_srv->get_count() ) @@ -226,9 +226,9 @@ sopg_table_t::sopg_table_t(unsigned n) to_cli.resize(n, nullptr); } -PortGroup* sopg_table_t::get_port_group(bool c2s, SnortProtocolId snort_protocol_id) +RuleGroup* sopg_table_t::get_port_group(bool c2s, SnortProtocolId snort_protocol_id) { - PortGroupVector& v = c2s ? to_srv : to_cli; + RuleGroupVector& v = c2s ? to_srv : to_cli; if ( snort_protocol_id >= v.size() ) return nullptr; diff --git a/src/detection/service_map.h b/src/detection/service_map.h index 83b014a39..d7023130e 100644 --- a/src/detection/service_map.h +++ b/src/detection/service_map.h @@ -37,7 +37,7 @@ namespace snort struct SnortConfig; class GHash; } -struct PortGroup; +struct RuleGroup; // Service Rule Map Master Table struct srmm_table_t @@ -49,22 +49,22 @@ struct srmm_table_t srmm_table_t* ServiceMapNew(); void ServiceMapFree(srmm_table_t*); -srmm_table_t* ServicePortGroupMapNew(); -void ServicePortGroupMapFree(srmm_table_t*); +srmm_table_t* ServiceRuleGroupMapNew(); +void ServiceRuleGroupMapFree(srmm_table_t*); -void fpPrintServicePortGroupSummary(snort::SnortConfig*); +void fpPrintServiceRuleGroupSummary(snort::SnortConfig*); void fpCreateServiceMaps(snort::SnortConfig*); -// Service/Protocol Ordinal To PortGroup table -typedef std::vector PortGroupVector; +// Service/Protocol Ordinal To RuleGroup table +typedef std::vector RuleGroupVector; struct sopg_table_t { sopg_table_t(unsigned size); - PortGroup* get_port_group(bool c2s, SnortProtocolId svc); + RuleGroup* get_port_group(bool c2s, SnortProtocolId svc); - PortGroupVector to_srv; - PortGroupVector to_cli; + RuleGroupVector to_srv; + RuleGroupVector to_cli; }; diff --git a/src/framework/mpse.h b/src/framework/mpse.h index 9cb08ada3..b110fc1f0 100644 --- a/src/framework/mpse.h +++ b/src/framework/mpse.h @@ -98,6 +98,10 @@ public: virtual int print_info() { return 0; } virtual int get_pattern_count() const { return 0; } + virtual bool serialize(uint8_t*&, size_t&) const { return false; } + virtual bool deserialize(const uint8_t*, size_t) { return false; } + virtual void get_hash(std::string&) { } + const char* get_method() { return method.c_str(); } void set_verbose(bool b = true) { verbose = b; } diff --git a/src/main/modules.cc b/src/main/modules.cc index 75edfb4f6..f0b1b9a6b 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -182,6 +182,9 @@ static const Parameter search_engine_params[] = { "offload_search_method", Parameter::PT_DYNAMIC, (void*)&get_search_methods, nullptr, "set fast pattern offload algorithm - choose available search engine" }, + { "rule_db_dir", Parameter::PT_STRING, nullptr, nullptr, + "deserialize rule databases from given directory" }, + { "search_optimize", Parameter::PT_BOOL, nullptr, "true", "tweak state machine construction for better performance" }, @@ -285,6 +288,9 @@ bool SearchEngineModule::set(const char*, Value& v, SnortConfig* sc) else if ( v.is("detect_raw_tcp") ) fp->set_stream_insert(v.get_bool()); + else if ( v.is("rule_db_dir") ) + fp->set_rule_db_dir(v.get_string()); + else if ( v.is("search_method") ) { if ( !fp->set_search_method(v.get_string()) ) diff --git a/src/main/snort_config.cc b/src/main/snort_config.cc index b4a0bbf93..63dc1a2cf 100644 --- a/src/main/snort_config.cc +++ b/src/main/snort_config.cc @@ -408,6 +408,10 @@ void SnortConfig::merge(const SnortConfig* cmd_line_conf) if (cmd_line_conf->dirty_pig) dirty_pig = cmd_line_conf->dirty_pig; + // --dump-rule-databases + if (!cmd_line_conf->rule_db_dir.empty()) + rule_db_dir = cmd_line_conf->rule_db_dir; + // --id-offset id_offset = cmd_line_conf->id_offset; // --id-subdir @@ -647,6 +651,12 @@ void SnortConfig::set_obfuscation_mask(const char* mask) obfuscation_net.set(mask); } +void SnortConfig::set_rule_db_dir(const char* directory) +{ + assert(directory); + rule_db_dir = directory; +} + void SnortConfig::set_gid(const char* args) { struct group* gr; diff --git a/src/main/snort_config.h b/src/main/snort_config.h index da791bc49..eaf01791b 100644 --- a/src/main/snort_config.h +++ b/src/main/snort_config.h @@ -275,6 +275,7 @@ public: std::string chroot_dir; /* -t or config chroot */ std::string include_path; std::string plugin_path; + std::string rule_db_dir; std::vector script_paths; mode_t file_mask = 0; @@ -475,6 +476,7 @@ public: void set_obfuscation_mask(const char*); void set_include_path(const char*); void set_process_all_events(bool); + void set_rule_db_dir(const char*); void set_show_year(bool); void set_tunnel_verdicts(const char*); void set_tweaks(const char*); diff --git a/src/main/snort_module.cc b/src/main/snort_module.cc index 724c8a58d..9e0a35f12 100644 --- a/src/main/snort_module.cc +++ b/src/main/snort_module.cc @@ -368,6 +368,9 @@ static const Parameter s_params[] = { "--dump-defaults", Parameter::PT_STRING, "(optional)", nullptr, "[] output module defaults in Lua format" }, + { "--dump-rule-databases", Parameter::PT_STRING, nullptr, nullptr, + "dump rule databases to given directory (hyperscan only)" }, + { "--dump-rule-deps", Parameter::PT_IMPLIED, nullptr, nullptr, "dump rule dependencies in json format for use by other tools" }, @@ -910,6 +913,11 @@ bool SnortModule::set(const char*, Value& v, SnortConfig* sc) else if ( v.is("--dump-defaults") ) dump_defaults(sc, v.get_string()); + else if ( v.is("--dump-rule-databases") ) + { + sc->set_rule_db_dir(v.get_string()); + sc->run_flags |= (RUN_FLAG__TEST | RUN_FLAG__MEM_CHECK); + } else if ( v.is("--dump-rule-deps") ) { sc->run_flags |= (RUN_FLAG__DUMP_RULE_DEPS | RUN_FLAG__TEST); diff --git a/src/ports/port_group.cc b/src/ports/port_group.cc index 967bdae8f..2902d5b2c 100644 --- a/src/ports/port_group.cc +++ b/src/ports/port_group.cc @@ -28,32 +28,22 @@ #include "framework/mpse_batch.h" #include "utils/util.h" -void PortGroup::add_rule() +void RuleGroup::add_rule() { rule_count++; } -PortGroup* PortGroup::alloc() -{ return (PortGroup*)snort_calloc(sizeof(PortGroup)); } - -void PortGroup::free(PortGroup* pg) +RuleGroup::~RuleGroup() { - pg->delete_nfp_rules(); + delete_nfp_rules(); for (int i = PM_TYPE_PKT; i < PM_TYPE_MAX; i++) - { - if (pg->mpsegrp[i]) - { - delete pg->mpsegrp[i]; - pg->mpsegrp[i] = nullptr; - } - } + delete mpsegrp[i]; - free_detection_option_root(&pg->nfp_tree); - snort_free(pg); + free_detection_option_root(&nfp_tree); } -bool PortGroup::add_nfp_rule(void* rd) +bool RuleGroup::add_nfp_rule(void* rd) { if ( !nfp_head ) { @@ -77,7 +67,7 @@ bool PortGroup::add_nfp_rule(void* rd) return true; } -void PortGroup::delete_nfp_rules() +void RuleGroup::delete_nfp_rules() { RULE_NODE* rn = nfp_head; diff --git a/src/ports/port_group.h b/src/ports/port_group.h index 240620447..5b23df5d7 100644 --- a/src/ports/port_group.h +++ b/src/ports/port_group.h @@ -30,8 +30,8 @@ namespace snort class MpseGroup; } -// PortGroup contains a set of fast patterns in the form of an MPSE and a -// set of non-fast-pattern (nfp) rules. when a PortGroup is selected, the +// RuleGroup contains a set of fast patterns in the form of an MPSE and a +// set of non-fast-pattern (nfp) rules. when a RuleGroup is selected, the // MPSE will run fp rules if there is a match on the associated fast // patterns. it will always run nfp rules since there is no way to filter // them out. @@ -68,26 +68,27 @@ struct RULE_NODE int iRuleNodeID; }; -struct PortGroup +struct RuleGroup { + RuleGroup() = default; + ~RuleGroup(); + // non-fast-pattern list - RULE_NODE* nfp_head, * nfp_tail; + RULE_NODE* nfp_head = nullptr; + RULE_NODE* nfp_tail = nullptr; // pattern matchers - snort::MpseGroup* mpsegrp[PM_TYPE_MAX]; + snort::MpseGroup* mpsegrp[PM_TYPE_MAX] = { }; // detection option tree - void* nfp_tree; + void* nfp_tree = nullptr; - unsigned rule_count; - unsigned nfp_rule_count; + unsigned rule_count = 0; + unsigned nfp_rule_count = 0; void add_rule(); bool add_nfp_rule(void*); void delete_nfp_rules(); - - static PortGroup* alloc(); - static void free(PortGroup*); }; #endif diff --git a/src/ports/port_object.cc b/src/ports/port_object.cc index 630ff93c6..960f6325a 100644 --- a/src/ports/port_object.cc +++ b/src/ports/port_object.cc @@ -63,7 +63,7 @@ void PortObjectFree(void* pv) sflist_free_all(po->rule_list, snort_free); if (po->group ) - PortGroup::free(po->group); + delete po->group; snort_free(po); } diff --git a/src/ports/port_object.h b/src/ports/port_object.h index e4ff88d74..3a043ab7d 100644 --- a/src/ports/port_object.h +++ b/src/ports/port_object.h @@ -26,10 +26,10 @@ //------------------------------------------------------------------------- // PortObject supports a set of PortObjectItems -// associates rules with a PortGroup. +// associates rules with a RuleGroup. //------------------------------------------------------------------------- -struct PortGroup; +struct RuleGroup; struct PortObjectItem; struct PortObject @@ -41,7 +41,7 @@ struct PortObject SF_LIST* item_list; /* list of port and port-range items */ SF_LIST* rule_list; /* list of rules */ - PortGroup* group; // based on rule_list - only used by any-any ports + RuleGroup* group; // based on rule_list - only used by any-any ports }; PortObject* PortObjectNew(); diff --git a/src/ports/port_object2.cc b/src/ports/port_object2.cc index 21f55b171..fb62f064d 100644 --- a/src/ports/port_object2.cc +++ b/src/ports/port_object2.cc @@ -134,7 +134,7 @@ void PortObject2Free(PortObject2* po) delete po->port_list; if (po->group ) - PortGroup::free(po->group); + delete po->group; snort_free(po); } diff --git a/src/ports/port_object2.h b/src/ports/port_object2.h index 026da558a..fc06b349a 100644 --- a/src/ports/port_object2.h +++ b/src/ports/port_object2.h @@ -47,7 +47,7 @@ struct PortObject2 snort::GHash* rule_hash; /* hash of rule (rule-indexes) in use */ PortBitSet* port_list; /* for collecting ports that use this object */ - struct PortGroup* group; /* PortGroup based on rule_hash */ + struct RuleGroup* group; /* RuleGroup based on rule_hash */ int port_cnt; /* count of ports using this object */ }; diff --git a/src/ports/port_table.cc b/src/ports/port_table.cc index 95cb06e66..406fc5fa5 100644 --- a/src/ports/port_table.cc +++ b/src/ports/port_table.cc @@ -805,7 +805,7 @@ void PortTablePrintInput(PortTable* p) } /* - Prints the original (normalized) PortGroups and + Prints the original (normalized) RuleGroups and as specified by the user */ void PortTablePrintUserRules(PortTable* p) @@ -826,7 +826,7 @@ void PortTablePrintUserRules(PortTable* p) /* Prints the Unique Port Groups and rules that reference them */ -void PortTablePrintPortGroups(PortTable* p) +void PortTablePrintRuleGroups(PortTable* p) { /* normalized user PortObjects and rule ids */ LogMessage(">>>PortTable - Compiled Port Groups\n"); diff --git a/src/ports/port_table.h b/src/ports/port_table.h index 38107e441..e808de8ba 100644 --- a/src/ports/port_table.h +++ b/src/ports/port_table.h @@ -87,7 +87,7 @@ int PortTablePrintCompiledEx(PortTable*, rim_print_f); void PortTablePrintInput(PortTable*); void PortTablePrintUserRules(PortTable*); -void PortTablePrintPortGroups(PortTable*); +void PortTablePrintRuleGroups(PortTable*); void RuleListSortUniq(SF_LIST*); void PortTableSortUniqRules(PortTable*); diff --git a/src/search_engines/hyperscan.cc b/src/search_engines/hyperscan.cc index 1108aadc4..6267b94be 100644 --- a/src/search_engines/hyperscan.cc +++ b/src/search_engines/hyperscan.cc @@ -22,14 +22,18 @@ #include "config.h" #endif -#include -#include - +#include #include #include +#include +#include + +#include +#include #include "framework/module.h" #include "framework/mpse.h" +#include "hash/hashes.h" #include "helpers/scratch_allocator.h" #include "log/messages.h" #include "main/snort_config.h" @@ -97,6 +101,14 @@ void Pattern::escape(const uint8_t* s, unsigned n, bool literal) } } +static bool compare(const Pattern& a, const Pattern& b) +{ + if ( a.pat != b.pat ) + return a.pat < b.pat; + + return a.flags < b.flags; +} + typedef std::vector PatternVector; // we need to update scratch in each compiler thread as each pattern is processed @@ -165,6 +177,14 @@ public: unsigned id, unsigned long long from, unsigned long long to, unsigned flags, void*); + bool serialize(uint8_t*& buf, size_t& sz) const override + { return hs_db and (hs_serialize_database(hs_db, (char**)&buf, &sz) == HS_SUCCESS) and buf; } + + bool deserialize(const uint8_t* buf, size_t sz) override + { return (hs_deserialize_database((const char*)buf, sz, &hs_db) == HS_SUCCESS) and hs_db; } + + void get_hash(std::string&) override; + private: void user_ctor(SnortConfig*); void user_dtor(); @@ -220,6 +240,14 @@ void HyperscanMpse::user_dtor() int HyperscanMpse::prep_patterns(SnortConfig* sc) { + if ( hs_db ) + { + if ( agent ) + user_ctor(sc); + + return 0; + } + if ( pvector.empty() ) return -1; @@ -229,6 +257,9 @@ int HyperscanMpse::prep_patterns(SnortConfig* sc) return -1; } + // sort for consistent serialization + std::sort(pvector.begin(), pvector.end(), compare); + hs_compile_error_t* errptr = nullptr; std::vector pats; std::vector flags; @@ -265,6 +296,23 @@ int HyperscanMpse::prep_patterns(SnortConfig* sc) return 0; } +void HyperscanMpse::get_hash(std::string& hash) +{ + if ( !hs_db ) + std::sort(pvector.begin(), pvector.end(), compare); + + std::stringstream ss; + + for ( auto& p : pvector ) + ss << p.pat << p.flags; + + std::string str = ss.str(); + uint8_t buf[MD5_HASH_SIZE]; + + md5((const uint8_t*)str.c_str(), str.size(), buf); + hash.assign((const char*)buf, sizeof(buf)); +} + void HyperscanMpse::reuse_search() { if ( pvector.empty() ) diff --git a/src/search_engines/test/hyperscan_test.cc b/src/search_engines/test/hyperscan_test.cc index dd9c73bac..cd2582964 100644 --- a/src/search_engines/test/hyperscan_test.cc +++ b/src/search_engines/test/hyperscan_test.cc @@ -137,6 +137,7 @@ void LogCount(char const*, uint64_t, FILE*) unsigned get_instance_id() { return 0; } +void md5(const unsigned char*, size_t, unsigned char*) { } } void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }