From 69ecbc1c88d2efc503bf100b91fc89faa8e63c01 Mon Sep 17 00:00:00 2001 From: "Hui Cao (huica)" Date: Fri, 16 Mar 2018 16:11:12 -0400 Subject: [PATCH] Merge pull request #1140 in SNORT/snort3 from proto_ref2 to master Squashed commit of the following: commit eac8b70f9a764d9834c66603f0ea818284c531eb Author: Steve Chew Date: Thu Mar 15 14:17:33 2018 -0400 ProtoRef: Converge on single name for SnortProtocolId. Fix threading problems. --- src/detection/detection_options.cc | 10 +-- src/detection/fp_create.cc | 29 +++++---- src/detection/fp_detect.cc | 26 ++++---- src/detection/fp_utils.cc | 14 ++-- src/detection/fp_utils.h | 2 +- src/detection/service_map.cc | 28 ++++---- src/detection/service_map.h | 4 +- src/detection/signature.h | 4 +- src/detection/treenodes.h | 7 +- src/flow/expect_cache.cc | 36 +++++----- src/flow/expect_cache.h | 3 +- src/flow/flow.h | 4 +- src/flow/flow_control.cc | 4 +- src/flow/flow_control.h | 2 +- src/framework/inspector.h | 16 +++-- src/framework/ips_option.h | 7 +- src/hash/test/ghash_test.cc | 2 +- src/host_tracker/host_cache.cc | 10 ++- src/host_tracker/host_tracker.h | 13 ++-- src/host_tracker/host_tracker_module.cc | 2 +- .../test/host_cache_module_test.cc | 15 ++--- src/host_tracker/test/host_cache_test.cc | 20 +++++- .../test/host_tracker_module_test.cc | 8 ++- src/host_tracker/test/host_tracker_test.cc | 6 +- src/ips_options/ips_content.cc | 2 +- src/ips_options/ips_flow.cc | 2 +- src/ips_options/ips_regex.cc | 2 +- src/ips_options/ips_sd_pattern.cc | 2 +- src/ips_options/test/ips_regex_test.cc | 3 +- src/loggers/alert_sf_socket.cc | 2 +- src/main/modules.cc | 10 +-- src/main/snort_config.cc | 39 +++++++---- src/main/snort_config.h | 8 ++- src/managers/ips_manager.cc | 8 +-- src/managers/ips_manager.h | 6 +- .../appid/app_info_table.cc | 17 ++--- src/network_inspectors/appid/app_info_table.h | 8 +-- src/network_inspectors/appid/appid_config.cc | 29 ++++++--- src/network_inspectors/appid/appid_config.h | 10 +-- .../appid/appid_discovery.cc | 2 +- .../appid/appid_inspector.cc | 4 +- src/network_inspectors/appid/appid_session.cc | 26 ++++---- src/network_inspectors/appid/appid_session.h | 6 +- .../appid/client_plugins/client_discovery.cc | 2 +- .../appid/detector_plugins/detector_sip.cc | 14 ++-- .../appid/lua_detector_api.cc | 6 +- .../service_plugins/service_discovery.cc | 2 +- .../appid/service_plugins/service_ftp.cc | 6 +- .../appid/service_plugins/service_ftp.h | 6 +- .../appid/service_plugins/service_rexec.cc | 7 +- .../appid/service_plugins/service_rexec.h | 2 +- .../appid/service_plugins/service_rpc.cc | 7 +- .../appid/service_plugins/service_rpc.h | 2 +- .../appid/service_plugins/service_rshell.cc | 6 +- .../appid/service_plugins/service_rshell.h | 2 +- .../appid/service_plugins/service_snmp.cc | 7 +- .../appid/service_plugins/service_snmp.h | 2 +- .../appid/service_plugins/service_ssl.cc | 8 ++- .../appid/service_plugins/service_tftp.cc | 7 +- .../appid/service_plugins/service_tftp.h | 2 +- .../appid/thirdparty_appid_utils.cc | 4 +- src/network_inspectors/binder/binder.cc | 14 ++-- src/parser/parse_conf.cc | 8 +-- src/parser/parse_rule.cc | 39 ++++++----- src/parser/parser.cc | 6 +- src/profiler/rule_profiler.cc | 2 +- src/protocols/packet.h | 10 +-- src/search_engines/test/hyperscan_test.cc | 2 +- src/search_engines/test/search_tool_test.cc | 2 +- .../dce_rpc/ips_dce_iface.cc | 8 +-- src/service_inspectors/ftp_telnet/ft_main.h | 5 +- src/service_inspectors/ftp_telnet/ftp.cc | 4 +- src/service_inspectors/ftp_telnet/pp_ftp.cc | 8 +-- src/stream/base/stream_ha.cc | 2 +- src/stream/file/file_session.cc | 2 +- src/stream/stream.cc | 56 ++++++++-------- src/stream/stream.h | 10 +-- src/stream/tcp/tcp_reassembler.cc | 8 +-- src/target_based/sftarget_data.h | 3 +- src/target_based/sftarget_hostentry.cc | 4 +- src/target_based/sftarget_hostentry.h | 2 +- src/target_based/sftarget_reader.cc | 16 ++--- src/target_based/snort_protocols.cc | 65 +++++++++++++------ src/target_based/snort_protocols.h | 42 +++++++----- 84 files changed, 487 insertions(+), 381 deletions(-) diff --git a/src/detection/detection_options.cc b/src/detection/detection_options.cc index cab467d4f..4d8a00bd8 100644 --- a/src/detection/detection_options.cc +++ b/src/detection/detection_options.cc @@ -419,16 +419,16 @@ int detection_option_node_evaluate( // Add the match for this otn to the queue. { OptTreeNode* otn = (OptTreeNode*)node->option_data; - int16_t app_proto = p->get_application_protocol(); + SnortProtocolId snort_protocol_id = p->get_snort_protocol_id(); int check_ports = 1; - if ( app_proto and ((OtnxMatchData*)(pomd))->check_ports != 2 ) + if ( snort_protocol_id != UNKNOWN_PROTOCOL_ID and ((OtnxMatchData*)(pomd))->check_ports != 2 ) { auto sig_info = otn->sigInfo; for ( unsigned svc_idx = 0; svc_idx < sig_info.num_services; ++svc_idx ) { - if ( app_proto == sig_info.services[svc_idx].service_ordinal ) + if ( snort_protocol_id == sig_info.services[svc_idx].snort_protocol_id ) { check_ports = 0; break; // out of for @@ -440,10 +440,10 @@ int detection_option_node_evaluate( // none of the services match DebugFormat(DEBUG_DETECT, "[**] SID %u not matched because of service mismatch (%d!=%d [**]\n", - sig_info.sid, app_proto, sig_info.services[0].service_ordinal); + sig_info.sid, snort_protocol_id, sig_info.services[0].snort_protocol_id); trace_logf(detection, TRACE_RULE_EVAL, "SID %u not matched because of service mismatch %d!=%d \n", - sig_info.sid, app_proto, sig_info.services[0].service_ordinal); + sig_info.sid, snort_protocol_id, sig_info.services[0].snort_protocol_id); break; // out of case } } diff --git a/src/detection/fp_create.cc b/src/detection/fp_create.cc index ff7c9f50e..796d2d858 100644 --- a/src/detection/fp_create.cc +++ b/src/detection/fp_create.cc @@ -930,7 +930,7 @@ static int fpCreatePortObject2PortGroup( otn = OtnLookup(sc->otn_map, gid, sid); assert(otn); - if ( is_network_protocol(otn->proto) ) + if ( is_network_protocol(otn->snort_protocol_id) ) fpAddPortGroupRule(sc, pg, otn, fp, false); } @@ -1242,11 +1242,14 @@ static void fpBuildServicePortGroups( ParseError("*** failed to create and find a port group for '%s'",srvc); continue; } - int16_t id = sc->proto_ref->find(srvc); - assert(id != SFTARGET_UNKNOWN_PROTOCOL); + SnortProtocolId snort_protocol_id = sc->proto_ref->find(srvc); + assert(snort_protocol_id != UNKNOWN_PROTOCOL_ID); + assert((unsigned)snort_protocol_id < sopg.size()); - assert((unsigned)id < sopg.size()); - sopg[ id ] = pg; + if(snort_protocol_id == UNKNOWN_PROTOCOL_ID) + continue; + + sopg[ snort_protocol_id ] = pg; } } @@ -1319,13 +1322,13 @@ static void fpPrintServiceRuleMapTable(GHash* p, const char* proto, const char* } } -static void fpPrintServiceRuleMaps(SnortConfig* sc, srmm_table_t* service_map) +static void fpPrintServiceRuleMaps(SnortConfig* sc) { for ( int i = SNORT_PROTO_IP; i < SNORT_PROTO_MAX; ++i ) { const char* s = sc->proto_ref->get_name(i); - fpPrintServiceRuleMapTable(service_map->to_srv[i], s, "to server"); - fpPrintServiceRuleMapTable(service_map->to_cli[i], s, "to client"); + fpPrintServiceRuleMapTable(sc->srmmTable->to_srv[i], s, "to server"); + fpPrintServiceRuleMapTable(sc->srmmTable->to_cli[i], s, "to client"); } } @@ -1362,10 +1365,10 @@ static void fp_print_service_rules(SnortConfig* sc, GHash* cli, GHash* srv, cons LogMessage("%25.25s: %8u%8u\n", "total", ctot, stot); } -static void fp_print_service_rules_by_proto(SnortConfig* sc, srmm_table_t* srmm) +static void fp_print_service_rules_by_proto(SnortConfig* sc) { for ( int i = SNORT_PROTO_IP; i < SNORT_PROTO_MAX; ++i ) - fp_print_service_rules(sc, srmm->to_srv[i], srmm->to_cli[i], sc->proto_ref->get_name(i)); + fp_print_service_rules(sc, sc->srmmTable->to_srv[i], sc->srmmTable->to_cli[i], sc->proto_ref->get_name(i)); } static void fp_sum_port_groups(PortGroup* pg, unsigned c[PM_TYPE_MAX]) @@ -1490,15 +1493,15 @@ static int fpCreateServicePortGroups(SnortConfig* sc) if (fpCreateServiceMaps(sc)) return -1; - fp_print_service_rules_by_proto(sc, sc->srmmTable); + fp_print_service_rules_by_proto(sc); if ( fp->get_debug_print_rule_group_build_details() ) - fpPrintServiceRuleMaps(sc, sc->srmmTable); + fpPrintServiceRuleMaps(sc); fpCreateServiceMapPortGroups(sc); if (fp->get_debug_print_rule_group_build_details()) - fpPrintServicePortGroupSummary(sc, sc->spgmmTable); + fpPrintServicePortGroupSummary(sc); ServiceMapFree(sc->srmmTable); sc->srmmTable = nullptr; diff --git a/src/detection/fp_detect.cc b/src/detection/fp_detect.cc index e43318db4..a80003ccf 100644 --- a/src/detection/fp_detect.cc +++ b/src/detection/fp_detect.cc @@ -1145,40 +1145,40 @@ static inline void fpEvalHeaderUdp(Packet* p, OtnxMatchData* omd) fpEvalHeaderSW(any, p, 1, 0, 0, omd); } -static inline bool fpEvalHeaderSvc(Packet* p, OtnxMatchData* omd, int proto) +static inline bool fpEvalHeaderSvc(Packet* p, OtnxMatchData* omd, SnortProtocolId proto_id) { PortGroup* svc = nullptr, * file = nullptr; - int16_t proto_ordinal = p->get_application_protocol(); + SnortProtocolId snort_protocol_id = p->get_snort_protocol_id(); - DebugFormat(DEBUG_ATTRIBUTE, "proto_ordinal=%d\n", proto_ordinal); + DebugFormat(DEBUG_ATTRIBUTE, "snort_protocol_id=%hu\n", snort_protocol_id); - if (proto_ordinal > 0) + if (snort_protocol_id != UNKNOWN_PROTOCOL_ID and snort_protocol_id != INVALID_PROTOCOL_ID) { if (p->is_from_server()) /* to cli */ { DebugMessage(DEBUG_ATTRIBUTE, "pkt_from_server\n"); - svc = SnortConfig::get_conf()->sopgTable->get_port_group(proto, false, proto_ordinal); - file = SnortConfig::get_conf()->sopgTable->get_port_group(proto, false, SNORT_PROTO_FILE); + svc = SnortConfig::get_conf()->sopgTable->get_port_group(proto_id, false, snort_protocol_id); + file = SnortConfig::get_conf()->sopgTable->get_port_group(proto_id, false, SNORT_PROTO_FILE); } if (p->is_from_client()) /* to srv */ { DebugMessage(DEBUG_ATTRIBUTE, "pkt_from_client\n"); - svc = SnortConfig::get_conf()->sopgTable->get_port_group(proto, true, proto_ordinal); - file = SnortConfig::get_conf()->sopgTable->get_port_group(proto, true, SNORT_PROTO_FILE); + svc = SnortConfig::get_conf()->sopgTable->get_port_group(proto_id, true, snort_protocol_id); + file = SnortConfig::get_conf()->sopgTable->get_port_group(proto_id, true, SNORT_PROTO_FILE); } DebugFormat(DEBUG_ATTRIBUTE, "fpEvalHeaderSvc:targetbased-ordinal-lookup: " - "sport=%d, dport=%d, proto_ordinal=%d, proto=%d, src:%p, " - "file:%p\n",p->ptrs.sp,p->ptrs.dp,proto_ordinal,proto,(void*)svc,(void*)file); + "sport=%d, dport=%d, snort_protocol_id=%hu, proto_id=%d, src:%p, " + "file:%p\n",p->ptrs.sp,p->ptrs.dp,snort_protocol_id,proto_id,(void*)svc,(void*)file); } // FIXIT-P put alert service rules with file data fp in alert file group and // verify ports and service during rule eval to avoid searching file data 2x. - int check_ports = (proto == SNORT_PROTO_USER) ? 2 : 1; + int check_ports = (proto_id == SNORT_PROTO_USER) ? 2 : 1; if ( file ) fpEvalHeaderSW(file, p, check_ports, 0, 2, omd); @@ -1277,12 +1277,12 @@ static int fpEvalPacket(Packet* p) // use ports if we don't know service or don't have rules else if ( p->proto_bits & PROTO_BIT__TCP ) { - if ( !p->get_application_protocol() or !fpEvalHeaderSvc(p, omd, SNORT_PROTO_TCP) ) + if ( p->get_snort_protocol_id() == UNKNOWN_PROTOCOL_ID or !fpEvalHeaderSvc(p, omd, SNORT_PROTO_TCP) ) fpEvalHeaderTcp(p, omd); } else if ( p->proto_bits & PROTO_BIT__UDP ) { - if ( !p->get_application_protocol() or !fpEvalHeaderSvc(p, omd, SNORT_PROTO_UDP) ) + if ( p->get_snort_protocol_id() == UNKNOWN_PROTOCOL_ID or !fpEvalHeaderSvc(p, omd, SNORT_PROTO_UDP) ) fpEvalHeaderUdp(p, omd); } break; diff --git a/src/detection/fp_utils.cc b/src/detection/fp_utils.cc index 035d02f15..9e18408b1 100644 --- a/src/detection/fp_utils.cc +++ b/src/detection/fp_utils.cc @@ -47,7 +47,7 @@ using namespace snort; static void finalize_content(OptFpList* ofl) { - PatternMatchData* pmd = get_pmd(ofl, 0, RULE_WO_DIR); + PatternMatchData* pmd = get_pmd(ofl, UNKNOWN_PROTOCOL_ID, RULE_WO_DIR); if ( !pmd ) return; @@ -59,7 +59,7 @@ static void finalize_content(OptFpList* ofl) static void clear_fast_pattern_only(OptFpList* ofl) { - PatternMatchData* pmd = get_pmd(ofl, 0, RULE_WO_DIR); + PatternMatchData* pmd = get_pmd(ofl, UNKNOWN_PROTOCOL_ID, RULE_WO_DIR); if ( pmd && pmd->fp_only > 0 ) pmd->fp_only = 0; @@ -119,17 +119,17 @@ static RuleDirection get_dir(OptTreeNode* otn) // public utilities //-------------------------------------------------------------------------- -PatternMatchData* get_pmd(OptFpList* ofl, int proto, RuleDirection direction) +PatternMatchData* get_pmd(OptFpList* ofl, SnortProtocolId snort_protocol_id, RuleDirection direction) { if ( !ofl->ips_opt ) return nullptr; - return ofl->ips_opt->get_pattern(proto, direction); + return ofl->ips_opt->get_pattern(snort_protocol_id, direction); } bool is_fast_pattern_only(OptFpList* ofl) { - PatternMatchData* pmd = get_pmd(ofl, 0, RULE_WO_DIR); + PatternMatchData* pmd = get_pmd(ofl, UNKNOWN_PROTOCOL_ID, RULE_WO_DIR); if ( !pmd ) return false; @@ -307,7 +307,7 @@ PatternMatchVector get_fp_content( } RuleDirection dir = get_dir(otn); - PatternMatchData* tmp = get_pmd(ofl, otn->proto, dir); + PatternMatchData* tmp = get_pmd(ofl, otn->snort_protocol_id, dir); if ( !tmp ) continue; @@ -343,7 +343,7 @@ PatternMatchVector get_fp_content( else exclude = false; - if ( best.pmd and otn->proto == SNORT_PROTO_FILE and best.cat != CAT_SET_FILE ) + if ( best.pmd and otn->snort_protocol_id == SNORT_PROTO_FILE and best.cat != CAT_SET_FILE ) { ParseWarning(WARN_RULES, "file rule %u:%u does not have file_data fast pattern", otn->sigInfo.gid, otn->sigInfo.sid); diff --git a/src/detection/fp_utils.h b/src/detection/fp_utils.h index 0cb2f5a91..3d4f9d6a3 100644 --- a/src/detection/fp_utils.h +++ b/src/detection/fp_utils.h @@ -28,7 +28,7 @@ struct OptFpList; struct OptTreeNode; -struct PatternMatchData* get_pmd(OptFpList*, int proto, snort::RuleDirection); +struct PatternMatchData* get_pmd(OptFpList*, SnortProtocolId, snort::RuleDirection); bool is_fast_pattern_only(OptFpList*); void validate_fast_pattern(OptTreeNode*); diff --git a/src/detection/service_map.cc b/src/detection/service_map.cc index 1ab74ec75..b447ffeca 100644 --- a/src/detection/service_map.cc +++ b/src/detection/service_map.cc @@ -191,15 +191,15 @@ static void ServiceMapAddOtnRaw(GHash* table, const char* servicename, OptTreeNo * service name. */ static int ServiceMapAddOtn( - srmm_table_t* srmm, int proto, const char* servicename, OptTreeNode* otn) + srmm_table_t* srmm, SnortProtocolId proto_id, const char* servicename, OptTreeNode* otn) { assert(servicename and otn); - if ( proto > SNORT_PROTO_USER ) - proto = SNORT_PROTO_USER; + if ( proto_id > SNORT_PROTO_USER ) + proto_id = SNORT_PROTO_USER; - GHash* to_srv = srmm->to_srv[proto]; - GHash* to_cli = srmm->to_cli[proto]; + GHash* to_srv = srmm->to_srv[proto_id]; + GHash* to_cli = srmm->to_cli[proto_id]; if ( !OtnFlowFromClient(otn) ) ServiceMapAddOtnRaw(to_cli, servicename, otn); @@ -210,7 +210,7 @@ static int ServiceMapAddOtn( return 0; } -void fpPrintServicePortGroupSummary(SnortConfig* sc, srmm_table_t* srvc_pg_map) +void fpPrintServicePortGroupSummary(SnortConfig* sc) { LogMessage("+--------------------------------\n"); LogMessage("| Service-PortGroup Table Summary \n"); @@ -218,10 +218,10 @@ void fpPrintServicePortGroupSummary(SnortConfig* sc, srmm_table_t* srvc_pg_map) for ( int i = SNORT_PROTO_IP; i < SNORT_PROTO_MAX; i++ ) { - if ( unsigned n = srvc_pg_map->to_srv[i]->count ) + if ( unsigned n = sc->spgmmTable->to_srv[i]->count ) LogMessage("| %s to server : %d services\n", sc->proto_ref->get_name(i), n); - if ( unsigned n = srvc_pg_map->to_cli[i]->count ) + if ( unsigned n = sc->spgmmTable->to_cli[i]->count ) LogMessage("| %s to client : %d services\n", sc->proto_ref->get_name(i), n); } @@ -265,7 +265,7 @@ int fpCreateServiceMaps(SnortConfig* sc) { const char* svc = otn->sigInfo.services[svc_idx].service; - if ( ServiceMapAddOtn(sc->srmmTable, rtn->proto, svc, otn) ) + if ( ServiceMapAddOtn(sc->srmmTable, rtn->snort_protocol_id, svc, otn) ) return -1; } } @@ -293,16 +293,16 @@ sopg_table_t::sopg_table_t(unsigned n) } PortGroup* sopg_table_t::get_port_group( - int proto, bool c2s, int16_t proto_ordinal) + SnortProtocolId proto_id, bool c2s, SnortProtocolId snort_protocol_id) { - assert(proto < SNORT_PROTO_MAX); + assert(proto_id < SNORT_PROTO_MAX); - PortGroupVector& v = c2s ? to_srv[proto] : to_cli[proto]; + PortGroupVector& v = c2s ? to_srv[proto_id] : to_cli[proto_id]; - if ( (unsigned)proto_ordinal >= v.size() ) + if ( snort_protocol_id >= v.size() ) return nullptr; - return v[proto_ordinal]; + return v[snort_protocol_id]; } bool sopg_table_t::set_user_mode() diff --git a/src/detection/service_map.h b/src/detection/service_map.h index e36dbd5af..60394982f 100644 --- a/src/detection/service_map.h +++ b/src/detection/service_map.h @@ -52,7 +52,7 @@ void ServiceMapFree(srmm_table_t*); srmm_table_t* ServicePortGroupMapNew(); void ServicePortGroupMapFree(srmm_table_t*); -void fpPrintServicePortGroupSummary(snort::SnortConfig*, srmm_table_t*); +void fpPrintServicePortGroupSummary(snort::SnortConfig*); int fpCreateServiceMaps(snort::SnortConfig*); // Service/Protocol Ordinal To PortGroup table @@ -62,7 +62,7 @@ struct sopg_table_t { sopg_table_t(unsigned size); bool set_user_mode(); - PortGroup* get_port_group(int proto, bool c2s, int16_t proto_ordinal); + PortGroup* get_port_group(SnortProtocolId proto_id, bool c2s, SnortProtocolId snort_protocol_id); PortGroupVector to_srv[SNORT_PROTO_MAX]; PortGroupVector to_cli[SNORT_PROTO_MAX]; diff --git a/src/detection/signature.h b/src/detection/signature.h index e11d94a92..d6d307080 100644 --- a/src/detection/signature.h +++ b/src/detection/signature.h @@ -27,6 +27,8 @@ #include #include +#include "target_based/snort_protocols.h" + namespace snort { struct SnortConfig; @@ -74,7 +76,7 @@ ClassType* ClassTypeLookupByType(snort::SnortConfig*, const char*); struct SignatureServiceInfo { char* service; - int16_t service_ordinal; + SnortProtocolId snort_protocol_id; }; struct OtnKey diff --git a/src/detection/treenodes.h b/src/detection/treenodes.h index b1948f110..ab9484dc5 100644 --- a/src/detection/treenodes.h +++ b/src/detection/treenodes.h @@ -95,8 +95,9 @@ struct OptTreeNode int chain_node_number; int evalIndex; /* where this rule sits in the evaluation sets */ - int proto; /* protocol, added for integrity checks - during rule parsing */ + + // Added for integrity checks during rule parsing. + SnortProtocolId snort_protocol_id; unsigned ruleIndex; // unique index @@ -147,7 +148,7 @@ struct RuleTreeNode struct ListHead* listhead; - int proto; + SnortProtocolId snort_protocol_id; uint32_t flags; /* control flags */ diff --git a/src/flow/expect_cache.cc b/src/flow/expect_cache.cc index 3e5e06f5e..a4f87cf45 100644 --- a/src/flow/expect_cache.cc +++ b/src/flow/expect_cache.cc @@ -101,7 +101,7 @@ struct ExpectNode bool reversed_key = false; int direction = 0; unsigned count = 0; - int16_t appId = 0; + SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; ExpectFlow* head = nullptr; ExpectFlow* tail = nullptr; @@ -251,10 +251,10 @@ bool ExpectCache::process_expected(ExpectNode* node, FlowKey& key, Packet* p, Fl free_list = head; /* If this is 0, we're ignoring, otherwise setting id of new session */ - if (!node->appId) + if (!node->snort_protocol_id) ignoring = node->direction ? true : false; - else if (lws->ssn_state.application_protocol != node->appId) - lws->ssn_state.application_protocol = node->appId; + else if (lws->ssn_state.snort_protocol_id != node->snort_protocol_id) + lws->ssn_state.snort_protocol_id = node->snort_protocol_id; if (!node->count) hash_table->remove(&key); @@ -310,21 +310,24 @@ ExpectCache::~ExpectCache() * Preprocessors may add sessions to be expected altogether or to be associated * with some data. For example, FTP preprocessor may add data channel that * should be expected. Alternatively, FTP preprocessor may add session with - * appId FTP-DATA. + * snort protocol ID FTP-DATA. * * It is assumed that only one of cliPort or srvPort should be known (!0). This * violation of this assumption will cause hash collision that will cause some * session to be not expected and expected. This will occur only rarely and * therefore acceptable design optimization. * - * Also, appId is assumed to be consistent between different preprocessors. - * Each session can be assigned only one AppId. When new appId mismatches - * existing appId, new appId and associated data is not stored. + * Also, snort_protocol_id is assumed to be consistent between different + * preprocessors. Each session can be assigned only one snort protocol ID. + * When new snort_protocol_id mismatches existing snort_protocol_id, new + * snort_protocol_id and associated data is not stored. * */ -int ExpectCache::add_flow(const Packet *ctrlPkt, PktType type, IpProtocol ip_proto, - const SfIp* cliIP, uint16_t cliPort, const SfIp* srvIP, uint16_t srvPort, - char direction, FlowData* fd, int16_t appId) +int ExpectCache::add_flow(const Packet *ctrlPkt, + PktType type, IpProtocol ip_proto, + const SfIp* cliIP, uint16_t cliPort, + const SfIp* srvIP, uint16_t srvPort, + char direction, FlowData* fd, SnortProtocolId snort_protocol_id) { /* Just pull the VLAN ID, MPLS ID, and Address Space ID from the control packet until we have a use case for not doing so. */ @@ -364,12 +367,13 @@ int ExpectCache::add_flow(const Packet *ctrlPkt, PktType type, IpProtocol ip_pro if (!new_node) { - /* Requests will be rejected if the AppID doesn't match what has already been set. */ - if (node->appId != appId) + // Requests will be rejected if the snort_protocol_id doesn't + // match what has already been set. + if (node->snort_protocol_id != snort_protocol_id) { - if (node->appId && appId) + if (node->snort_protocol_id && snort_protocol_id) return -1; - node->appId = appId; + node->snort_protocol_id = snort_protocol_id; } last = node->tail; @@ -390,7 +394,7 @@ int ExpectCache::add_flow(const Packet *ctrlPkt, PktType type, IpProtocol ip_pro } else { - node->appId = appId; + node->snort_protocol_id = snort_protocol_id; node->reversed_key = reversed_key; node->direction = direction; node->head = node->tail = nullptr; diff --git a/src/flow/expect_cache.h b/src/flow/expect_cache.h index d60e8d862..e41b2f761 100644 --- a/src/flow/expect_cache.h +++ b/src/flow/expect_cache.h @@ -62,6 +62,7 @@ //------------------------------------------------------------------------- #include #include "flow/flow_key.h" +#include "target_based/snort_protocols.h" struct ExpectNode; @@ -97,7 +98,7 @@ public: int add_flow(const snort::Packet *ctrlPkt, PktType, IpProtocol, const snort::SfIp* cliIP, uint16_t cliPort, const snort::SfIp* srvIP, uint16_t srvPort, - char direction, snort::FlowData*, int16_t appId = 0); + char direction, snort::FlowData*, SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID); bool is_expected(snort::Packet*); bool check(snort::Packet*, snort::Flow*); diff --git a/src/flow/flow.h b/src/flow/flow.h index e7872e4b5..4eaef486a 100644 --- a/src/flow/flow.h +++ b/src/flow/flow.h @@ -31,6 +31,7 @@ #include "framework/inspector.h" #include "protocols/layer.h" #include "sfip/sf_ip.h" +#include "target_based/snort_protocols.h" #define SSNFLAG_SEEN_CLIENT 0x00000001 #define SSNFLAG_SEEN_SENDER 0x00000001 @@ -134,7 +135,7 @@ struct LwState uint32_t session_flags; int16_t ipprotocol; - int16_t application_protocol; + SnortProtocolId snort_protocol_id; char direction; char ignore_direction; @@ -379,5 +380,6 @@ inline bool Flow::is_detection_enabled(bool to_server) return !(ssn_state.session_flags & SSNFLAG_NO_DETECT_TO_CLIENT); } } + #endif diff --git a/src/flow/flow_control.cc b/src/flow/flow_control.cc index 091f72217..3ae234884 100644 --- a/src/flow/flow_control.cc +++ b/src/flow/flow_control.cc @@ -797,11 +797,11 @@ int FlowControl::add_expected( const Packet* ctrlPkt, PktType type, IpProtocol ip_proto, const SfIp *srcIP, uint16_t srcPort, const SfIp *dstIP, uint16_t dstPort, - int16_t appId, FlowData* fd) + SnortProtocolId snort_protocol_id, FlowData* fd) { return exp_cache->add_flow( ctrlPkt, type, ip_proto, srcIP, srcPort, dstIP, dstPort, - SSN_DIR_BOTH, fd, appId); + SSN_DIR_BOTH, fd, snort_protocol_id); } bool FlowControl::is_expected(Packet* p) diff --git a/src/flow/flow_control.h b/src/flow/flow_control.h index 632403c8c..bfe6ce513 100644 --- a/src/flow/flow_control.h +++ b/src/flow/flow_control.h @@ -90,7 +90,7 @@ public: const snort::Packet* ctrlPkt, PktType, IpProtocol, const snort::SfIp *srcIP, uint16_t srcPort, const snort::SfIp *dstIP, uint16_t dstPort, - int16_t appId, snort::FlowData*); + SnortProtocolId snort_protocol_id, snort::FlowData*); PegCount get_flows(PktType); PegCount get_total_prunes(PktType) const; diff --git a/src/framework/inspector.h b/src/framework/inspector.h index 9bffc1871..a18ad9898 100644 --- a/src/framework/inspector.h +++ b/src/framework/inspector.h @@ -28,6 +28,7 @@ #include "framework/base_api.h" #include "main/thread.h" +#include "target_based/snort_protocols.h" class Session; @@ -36,8 +37,6 @@ namespace snort struct SnortConfig; struct Packet; -typedef int16_t ServiceId; - // this is the current version of the api #define INSAPI_VERSION ((BASE_API_VERSION << 16) | 0) @@ -52,14 +51,12 @@ struct InspectionBuffer unsigned len; }; - struct InspectApi; //------------------------------------------------------------------------- // api for class //------------------------------------------------------------------------- - class SO_PUBLIC Inspector { public: @@ -102,8 +99,12 @@ public: bool is_inactive(); - void set_service(ServiceId id) { srv_id = id; } - ServiceId get_service() { return srv_id; } + void set_service(SnortProtocolId snort_protocol_id_param) + { + snort_protocol_id = snort_protocol_id_param; + } + + SnortProtocolId get_service() { return snort_protocol_id; } // for well known buffers // well known buffers may be included among generic below, @@ -144,7 +145,7 @@ protected: private: const InspectApi* api; std::atomic_uint* ref_count; - ServiceId srv_id; + SnortProtocolId snort_protocol_id; }; template @@ -206,5 +207,6 @@ struct InspectApi inline const char* Inspector::get_name() { return api->base.name; } } + #endif diff --git a/src/framework/ips_option.h b/src/framework/ips_option.h index defe2f524..281bbfbdb 100644 --- a/src/framework/ips_option.h +++ b/src/framework/ips_option.h @@ -26,6 +26,7 @@ #include "detection/rule_option_types.h" #include "framework/base_api.h" #include "main/snort_types.h" +#include "target_based/snort_protocols.h" //------------------------------------------------------------------------- // api for class @@ -95,10 +96,10 @@ public: { return CAT_NONE; } // for fast-pattern options like content - virtual PatternMatchData* get_pattern(int /*proto*/, RuleDirection = RULE_WO_DIR) + virtual PatternMatchData* get_pattern(SnortProtocolId, RuleDirection = RULE_WO_DIR) { return nullptr; } - virtual struct PatternMatchData* get_alternate_pattern() + virtual PatternMatchData* get_alternate_pattern() { return nullptr; } static int eval(void* v, Cursor& c, Packet* p) @@ -128,7 +129,7 @@ enum RuleOptType typedef void (* IpsOptFunc)(SnortConfig*); -typedef IpsOption* (* IpsNewFunc)(class Module*, struct OptTreeNode*); +typedef IpsOption* (* IpsNewFunc)(Module*, OptTreeNode*); typedef void (* IpsDelFunc)(IpsOption*); struct IpsApi diff --git a/src/hash/test/ghash_test.cc b/src/hash/test/ghash_test.cc index 8090a471e..e4a096eaa 100644 --- a/src/hash/test/ghash_test.cc +++ b/src/hash/test/ghash_test.cc @@ -37,7 +37,7 @@ using namespace snort; static SnortConfig my_config; THREAD_LOCAL SnortConfig *snort_conf = &my_config; -SnortConfig::SnortConfig(SnortConfig*) +SnortConfig::SnortConfig(const SnortConfig* const) { snort_conf->run_flags = 0;} // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash() SnortConfig::~SnortConfig() = default; diff --git a/src/host_tracker/host_cache.cc b/src/host_tracker/host_cache.cc index f476d83c4..6b43376f6 100644 --- a/src/host_tracker/host_cache.cc +++ b/src/host_tracker/host_cache.cc @@ -24,6 +24,9 @@ #include "host_cache.h" +#include "main/snort_config.h" +#include "target_based/snort_protocols.h" + using namespace snort; #define LRU_CACHE_INITIAL_SIZE 65535 @@ -39,11 +42,11 @@ void host_cache_add_host_tracker(HostTracker* ht) namespace snort { -bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, const char* /*service*/) +bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, const char* service) { HostIpKey ipkey((const uint8_t*) ipaddr.get_ip6_ptr()); - uint16_t proto = 0; // FIXIT-M not safe with multithreads SnortConfig::get_conf()->proto_ref->add(service)); - HostApplicationEntry app_entry(ipproto, port, proto); + SnortProtocolId proto_id = SnortConfig::get_conf()->proto_ref->find(service); + HostApplicationEntry app_entry(ipproto, port, proto_id); std::shared_ptr ht; if (!host_cache.find(ipkey, ht)) @@ -62,3 +65,4 @@ bool host_cache_add_service(const SfIp& ipaddr, Protocol ipproto, Port port, con return ht->add_service(app_entry); } } + diff --git a/src/host_tracker/host_tracker.h b/src/host_tracker/host_tracker.h index 5cedccf56..b977c7551 100644 --- a/src/host_tracker/host_tracker.h +++ b/src/host_tracker/host_tracker.h @@ -33,6 +33,7 @@ #include "framework/counts.h" #include "main/thread.h" #include "sfip/sf_ip.h" +#include "target_based/snort_protocols.h" // FIXIT-M For now this emulates the Snort++ attribute table. // Need to add in host_tracker.h data eventually. @@ -54,16 +55,14 @@ struct HostApplicationEntry { Port port = 0; Protocol ipproto = 0; - Protocol protocol = 0; - - static const Protocol UNKNOWN_PROTOCOL = 0; + SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; HostApplicationEntry() = default; - HostApplicationEntry(Protocol ipproto_param, Port port_param, Protocol protocol_param) : + HostApplicationEntry(Protocol ipproto_param, Port port_param, SnortProtocolId protocol_param) : port(port_param), ipproto(ipproto_param), - protocol(protocol_param) + snort_protocol_id(protocol_param) { } @@ -166,7 +165,7 @@ public: // Returns false when not found. bool find_service(Protocol ipproto, Port port, HostApplicationEntry& app_entry) { - HostApplicationEntry tmp_entry(ipproto, port, HostApplicationEntry::UNKNOWN_PROTOCOL); + HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID); host_tracker_stats.service_finds++; std::lock_guard lck(host_tracker_lock); @@ -185,7 +184,7 @@ public: // Returns true if entry existed. False otherwise. bool remove_service(Protocol ipproto, Port port) { - HostApplicationEntry tmp_entry(ipproto, port, HostApplicationEntry::UNKNOWN_PROTOCOL); + HostApplicationEntry tmp_entry(ipproto, port, UNKNOWN_PROTOCOL_ID); host_tracker_stats.service_removes++; std::lock_guard lck(host_tracker_lock); diff --git a/src/host_tracker/host_tracker_module.cc b/src/host_tracker/host_tracker_module.cc index dda6010b4..92488329c 100644 --- a/src/host_tracker/host_tracker_module.cc +++ b/src/host_tracker/host_tracker_module.cc @@ -86,7 +86,7 @@ bool HostTrackerModule::set(const char*, Value& v, SnortConfig* sc) host->set_stream_policy(v.get_long() + 1); else if ( v.is("name") ) - app.protocol = sc->proto_ref->add(v.get_string()); + app.snort_protocol_id = sc->proto_ref->add(v.get_string()); else if ( v.is("proto") ) app.ipproto = sc->proto_ref->add(v.get_string()); diff --git a/src/host_tracker/test/host_cache_module_test.cc b/src/host_tracker/test/host_cache_module_test.cc index 46d94c9fd..e1e2c03bb 100644 --- a/src/host_tracker/test/host_cache_module_test.cc +++ b/src/host_tracker/test/host_cache_module_test.cc @@ -25,6 +25,7 @@ #include "host_tracker/host_cache_module.h" #include "host_tracker/host_cache.h" +#include "main/snort_config.h" #include #include @@ -33,17 +34,11 @@ using namespace snort; -// Fake AddProtocolReference to avoid bringing in a ton of dependencies. -int16_t AddProtocolReference(const char* protocol) -{ - if (!strcmp("servicename", protocol)) - return 3; - if (!strcmp("tcp", protocol)) - return 2; - return 1; -} +// Fakes to avoid bringing in a ton of dependencies. +SnortProtocolId ProtocolReference::add(char const*) { return 0; } +SnortProtocolId ProtocolReference::find(char const*) { return 0; } +SnortConfig* SnortConfig::get_conf() { return nullptr; } -// Fake show_stats to avoid bringing in a ton of dependencies. void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { } diff --git a/src/host_tracker/test/host_cache_test.cc b/src/host_tracker/test/host_cache_test.cc index a45343bfd..31e33f0c7 100644 --- a/src/host_tracker/test/host_cache_test.cc +++ b/src/host_tracker/test/host_cache_test.cc @@ -25,20 +25,34 @@ #include "host_tracker/host_cache.h" +#include "main/snort_config.h" + #include #include using namespace snort; -// Fake AddProtocolReference to avoid bringing in a ton of dependencies. -int16_t AddProtocolReference(const char* protocol) +SnortConfig s_conf; +THREAD_LOCAL SnortConfig* snort_conf = &s_conf; + +SnortConfig::SnortConfig(const SnortConfig* const) { } + +SnortConfig::~SnortConfig() { } + +SnortConfig* SnortConfig::get_conf() +{ return snort_conf; } + +SnortProtocolId ProtocolReference::find(char const*) { return 0; } + +SnortProtocolId ProtocolReference::add(const char* protocol) { + if (!strcmp("servicename", protocol)) + return 3; if (!strcmp("tcp", protocol)) return 2; return 1; } -// Ditto for snort_strdup() char* snort_strdup(const char* str) { return strdup(str); diff --git a/src/host_tracker/test/host_tracker_module_test.cc b/src/host_tracker/test/host_tracker_module_test.cc index 17946cb5d..905520c5b 100644 --- a/src/host_tracker/test/host_tracker_module_test.cc +++ b/src/host_tracker/test/host_tracker_module_test.cc @@ -26,14 +26,18 @@ #include "host_tracker/host_cache.h" #include "host_tracker/host_tracker_module.h" #include "target_based/snort_protocols.h" +#include "main/snort_config.h" #include #include using namespace snort; -// Fake to avoid bringing in a ton of dependencies. -int16_t ProtocolReference::add(const char* protocol) +SnortConfig* SnortConfig::get_conf() { return nullptr; } + +SnortProtocolId ProtocolReference::find(char const*) { return 0; } + +SnortProtocolId ProtocolReference::add(const char* protocol) { if (!strcmp("servicename", protocol)) return 3; diff --git a/src/host_tracker/test/host_tracker_test.cc b/src/host_tracker/test/host_tracker_test.cc index b09352eda..67fb15807 100644 --- a/src/host_tracker/test/host_tracker_test.cc +++ b/src/host_tracker/test/host_tracker_test.cc @@ -103,20 +103,20 @@ TEST(host_tracker, add_find_service_test) CHECK(true == ret); CHECK(actual_entry.port == 2112); CHECK(actual_entry.ipproto == 6); - CHECK(actual_entry.protocol == 3); + CHECK(actual_entry.snort_protocol_id == 3); ht.add_service(app_entry2); ret = ht.find_service(6, 2112, actual_entry); CHECK(true == ret); CHECK(actual_entry.port == 2112); CHECK(actual_entry.ipproto == 6); - CHECK(actual_entry.protocol == 3); + CHECK(actual_entry.snort_protocol_id == 3); ret = ht.find_service(17, 7777, actual_entry); CHECK(true == ret); CHECK(actual_entry.port == 7777); CHECK(actual_entry.ipproto == 17); - CHECK(actual_entry.protocol == 10); + CHECK(actual_entry.snort_protocol_id == 10); // Try adding an entry that exists already. ret = ht.add_service(app_entry1); diff --git a/src/ips_options/ips_content.cc b/src/ips_options/ips_content.cc index 4db8b9723..9f13f533d 100644 --- a/src/ips_options/ips_content.cc +++ b/src/ips_options/ips_content.cc @@ -135,7 +135,7 @@ public: EvalStatus eval(Cursor& c, Packet*) override { return CheckANDPatternMatch(config, c); } - PatternMatchData* get_pattern(int, RuleDirection) override + PatternMatchData* get_pattern(SnortProtocolId, RuleDirection) override { return &config->pmd; } protected: diff --git a/src/ips_options/ips_flow.cc b/src/ips_options/ips_flow.cc index 6ad1996f4..8646e5e1e 100644 --- a/src/ips_options/ips_flow.cc +++ b/src/ips_options/ips_flow.cc @@ -418,7 +418,7 @@ static IpsOption* flow_ctor(Module* p, OptTreeNode* otn) if ( m->data.unestablished ) otn->unestablished = 1; - if (otn->proto == SNORT_PROTO_ICMP) + if (otn->snort_protocol_id == SNORT_PROTO_ICMP) { if ( (m->data.only_reassembled != ONLY_FRAG) && (m->data.ignore_reassembled != IGNORE_FRAG) ) diff --git a/src/ips_options/ips_regex.cc b/src/ips_options/ips_regex.cc index 7ca3870fe..f165fec6d 100644 --- a/src/ips_options/ips_regex.cc +++ b/src/ips_options/ips_regex.cc @@ -98,7 +98,7 @@ public: bool retry(Cursor&) override; - PatternMatchData* get_pattern(int, RuleDirection) override + PatternMatchData* get_pattern(SnortProtocolId, RuleDirection) override { return &config.pmd; } EvalStatus eval(Cursor&, Packet*) override; diff --git a/src/ips_options/ips_sd_pattern.cc b/src/ips_options/ips_sd_pattern.cc index 488a3c7ee..fd9ac7a10 100644 --- a/src/ips_options/ips_sd_pattern.cc +++ b/src/ips_options/ips_sd_pattern.cc @@ -124,7 +124,7 @@ public: uint32_t hash() const override; bool operator==(const IpsOption&) const override; - PatternMatchData* get_pattern(int, RuleDirection) override + PatternMatchData* get_pattern(SnortProtocolId, RuleDirection) override { return &config.pmd; } EvalStatus eval(Cursor&, Packet* p) override; diff --git a/src/ips_options/test/ips_regex_test.cc b/src/ips_options/test/ips_regex_test.cc index d9a602d7c..c25ea8601 100644 --- a/src/ips_options/test/ips_regex_test.cc +++ b/src/ips_options/test/ips_regex_test.cc @@ -57,7 +57,7 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf; static SnortState s_state; -SnortConfig::SnortConfig(SnortConfig*) +SnortConfig::SnortConfig(const SnortConfig* const) { state = &s_state; memset(state, 0, sizeof(*state)); @@ -81,7 +81,6 @@ static unsigned s_parse_errors = 0; void ParseError(const char*, ...) { s_parse_errors++; } - unsigned get_instance_id() { return 0; } diff --git a/src/loggers/alert_sf_socket.cc b/src/loggers/alert_sf_socket.cc index 34f701981..5e2046cbd 100644 --- a/src/loggers/alert_sf_socket.cc +++ b/src/loggers/alert_sf_socket.cc @@ -255,7 +255,7 @@ static OptTreeNode* OptTreeNode_Search(uint32_t, uint32_t sid) OptTreeNode* otn = (OptTreeNode*)hashNode->data; RuleTreeNode* rtn = getRuntimeRtnFromOtn(otn); - if ( rtn and is_network_protocol(rtn->proto) ) + if ( rtn and is_network_protocol(rtn->snort_protocol_id) ) { if (otn->sigInfo.sid == sid) return otn; diff --git a/src/main/modules.cc b/src/main/modules.cc index 98a8d4530..16da1cdc4 100644 --- a/src/main/modules.cc +++ b/src/main/modules.cc @@ -1091,7 +1091,7 @@ public: bool NetworkModule::set(const char*, Value& v, SnortConfig* sc) { - NetworkPolicy* p = snort::get_network_policy(); + NetworkPolicy* p = get_network_policy(); if ( v.is("checksum_drop") ) ConfigChecksumDrop(v.get_string()); @@ -1161,7 +1161,7 @@ public: bool InspectionModule::set(const char*, Value& v, SnortConfig* sc) { - InspectionPolicy* p = snort::get_inspection_policy(); + InspectionPolicy* p = get_inspection_policy(); if ( v.is("id") ) { @@ -1251,7 +1251,7 @@ public: { return ips_module_pegs; } PegCount* get_counts() const override - { return (PegCount*) &snort::ips_module_stats; } + { return (PegCount*) &ips_module_stats; } Usage get_usage() const override { return DETECT; } @@ -1259,7 +1259,7 @@ public: bool IpsModule::set(const char*, Value& v, SnortConfig* sc) { - IpsPolicy* p = snort::get_ips_policy(); + IpsPolicy* p = get_ips_policy(); if ( v.is("enable_builtin_rules") ) p->enable_builtin_rules = v.get_bool(); @@ -1846,7 +1846,7 @@ bool HostsModule::set(const char*, Value& v, SnortConfig* sc) host->hostInfo.streamPolicy = v.get_long() + 1; else if ( app and v.is("name") ) - app->protocol = sc->proto_ref->add(v.get_string()); + app->snort_protocol_id = sc->proto_ref->add(v.get_string()); else if ( app and v.is("proto") ) app->ipproto = sc->proto_ref->add(v.get_string()); diff --git a/src/main/snort_config.cc b/src/main/snort_config.cc index 4f6fe2497..2ea5c2700 100644 --- a/src/main/snort_config.cc +++ b/src/main/snort_config.cc @@ -178,15 +178,7 @@ static void init_policies(SnortConfig* sc) } } -//------------------------------------------------------------------------- -// public methods -//------------------------------------------------------------------------- - -/* A lot of this initialization can be skipped if not running in IDS mode - * but the goal is to minimize config checks at run time when running in - * IDS mode so we keep things simple and enforce that the only difference - * among run_modes is how we handle packets via the log_func. */ -SnortConfig::SnortConfig(SnortConfig* other_conf) +void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* protocol_reference) { homenet.clear(); obfuscation_net.clear(); @@ -214,7 +206,7 @@ SnortConfig::SnortConfig(SnortConfig* other_conf) thread_config = new ThreadConfig(); memset(evalOrder, 0, sizeof(evalOrder)); - proto_ref = new ProtocolReference; + proto_ref = new ProtocolReference(protocol_reference); } else { @@ -227,6 +219,25 @@ SnortConfig::SnortConfig(SnortConfig* other_conf) set_network_policy(policy_map->get_network_policy()); } +//------------------------------------------------------------------------- +// public methods +//------------------------------------------------------------------------- + +/* A lot of this initialization can be skipped if not running in IDS mode + * but the goal is to minimize config checks at run time when running in + * IDS mode so we keep things simple and enforce that the only difference + * among run_modes is how we handle packets via the log_func. */ +SnortConfig::SnortConfig(const SnortConfig* const other_conf) +{ + init(other_conf, nullptr); +} + +// Copy the ProtocolReference data into the new SnortConfig. +SnortConfig::SnortConfig(ProtocolReference* protocol_reference) +{ + init(nullptr, protocol_reference); +} + SnortConfig::~SnortConfig() { if ( cloned ) @@ -335,7 +346,7 @@ void SnortConfig::post_setup() #endif } -void SnortConfig::clone(SnortConfig* conf) +void SnortConfig::clone(const SnortConfig* const conf) { *this = *conf; if (conf->homenet.get_family() != 0) @@ -964,14 +975,14 @@ void SnortConfig::set_alert_mode(const char* val) output = val; output_flags |= OUTPUT_FLAG__ALERTS; - snort::Snort::set_main_hook(DetectionEngine::inspect); + Snort::set_main_hook(DetectionEngine::inspect); } void SnortConfig::set_log_mode(const char* val) { if (strcasecmp(val, LOG_NONE) == 0) { - snort::Snort::set_main_hook(snort_ignore); + Snort::set_main_hook(snort_ignore); EventManager::enable_logs(false); } else @@ -979,7 +990,7 @@ void SnortConfig::set_log_mode(const char* val) if ( !strcmp(val, LOG_DUMP) ) val = LOG_CODECS; output = val; - snort::Snort::set_main_hook(snort_log); + Snort::set_main_hook(snort_log); } } diff --git a/src/main/snort_config.h b/src/main/snort_config.h index 3b6b25cb3..c49f8bf10 100644 --- a/src/main/snort_config.h +++ b/src/main/snort_config.h @@ -159,8 +159,12 @@ struct SnortState struct SnortConfig { +private: + void init(const SnortConfig* const, ProtocolReference*); + public: - SnortConfig(SnortConfig* other_conf = nullptr ); + SnortConfig(const SnortConfig* const other_conf = nullptr); + SnortConfig(ProtocolReference* protocol_reference); ~SnortConfig(); SnortConfig(const SnortConfig&) = delete; @@ -170,7 +174,7 @@ public: bool verify(); void merge(SnortConfig*); - void clone(SnortConfig*); + void clone(const SnortConfig* const); public: //------------------------------------------------------ diff --git a/src/managers/ips_manager.cc b/src/managers/ips_manager.cc index 459d0b3d5..d90a17dd2 100644 --- a/src/managers/ips_manager.cc +++ b/src/managers/ips_manager.cc @@ -166,7 +166,7 @@ const char* IpsManager::get_option_keyword() } bool IpsManager::option_begin( - SnortConfig* sc, const char* key, int /*proto*/) + SnortConfig* sc, const char* key, SnortProtocolId) { Option* opt = get_opt(key); @@ -245,7 +245,7 @@ bool IpsManager::option_set( } bool IpsManager::option_end( - SnortConfig* sc, OptTreeNode* otn, int proto, + SnortConfig* sc, OptTreeNode* otn, SnortProtocolId snort_protocol_id, const char* key, RuleOptType& type) { if ( current_keyword.empty() ) @@ -254,9 +254,9 @@ bool IpsManager::option_end( assert(!strcmp(current_keyword.c_str(), key)); #ifdef NDEBUG - UNUSED(proto); + UNUSED(snort_protocol_id); #else - assert(proto == otn->proto); + assert(snort_protocol_id == otn->snort_protocol_id); #endif Module* mod = current_module; diff --git a/src/managers/ips_manager.h b/src/managers/ips_manager.h index f290d09f9..f7a006491 100644 --- a/src/managers/ips_manager.h +++ b/src/managers/ips_manager.h @@ -63,14 +63,14 @@ public: static void instantiate(const snort::IpsApi*, snort::Module*, snort::SnortConfig*); static bool get_option( - snort::SnortConfig*, struct OptTreeNode*, int proto, + snort::SnortConfig*, struct OptTreeNode*, SnortProtocolId, const char* keyword, char* args, snort::RuleOptType&); - static bool option_begin(snort::SnortConfig*, const char* key, int proto); + static bool option_begin(snort::SnortConfig*, const char* key, SnortProtocolId); static bool option_set( snort::SnortConfig*, const char* key, const char* opt, const char* val); static bool option_end( - snort::SnortConfig*, OptTreeNode*, int proto, const char* key, snort::RuleOptType&); + snort::SnortConfig*, OptTreeNode*, SnortProtocolId, const char* key, snort::RuleOptType&); static void delete_option(snort::IpsOption*); static const char* get_option_keyword(); diff --git a/src/network_inspectors/appid/app_info_table.cc b/src/network_inspectors/appid/app_info_table.cc index 2970c3d81..91d520c0f 100644 --- a/src/network_inspectors/appid/app_info_table.cc +++ b/src/network_inspectors/appid/app_info_table.cc @@ -506,16 +506,15 @@ void AppInfoManager::load_appid_config(AppIdModuleConfig* config, const char* pa fclose(config_file); } -int16_t AppInfoManager::add_appid_protocol_reference(const char* protocol) +SnortProtocolId AppInfoManager::add_appid_protocol_reference(const char* protocol, + snort::SnortConfig* sc) { - static std::mutex apr_mutex; - - std::lock_guard lock(apr_mutex); - int16_t id = snort::SnortConfig::get_conf()->proto_ref->add(protocol); - return id; + SnortProtocolId snort_protocol_id = sc->proto_ref->add(protocol); + return snort_protocol_id; } -void AppInfoManager::init_appid_info_table(AppIdModuleConfig* mod_config) +void AppInfoManager::init_appid_info_table(AppIdModuleConfig* mod_config, + snort::SnortConfig* sc) { if ( !mod_config->app_detector_dir ) { @@ -592,8 +591,10 @@ void AppInfoManager::init_appid_info_table(AppIdModuleConfig* mod_config) /* snort service key, if it exists */ token = strtok_r(nullptr, CONF_SEPARATORS, &context); + + // FIXIT-H: Sometimes the token is "~". Should we ignore those? if (token) - entry->snortId = add_appid_protocol_reference(token); + entry->snort_protocol_id = add_appid_protocol_reference(token, sc); if ((app_id = get_static_app_info_entry(entry->appId))) { diff --git a/src/network_inspectors/appid/app_info_table.h b/src/network_inspectors/appid/app_info_table.h index 07500b677..6667d1040 100644 --- a/src/network_inspectors/appid/app_info_table.h +++ b/src/network_inspectors/appid/app_info_table.h @@ -31,7 +31,7 @@ #include "flow/flow.h" #include "framework/counts.h" #include "main/thread.h" -#include "protocols/packet.h" +#include "target_based/snort_protocols.h" #include "utils/util.h" #define APP_PRIORITY_DEFAULT 2 @@ -75,7 +75,7 @@ public: uint32_t serviceId; uint32_t clientId; uint32_t payloadId; - int16_t snortId = snort::SFTARGET_UNKNOWN_PROTOCOL; + SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; uint32_t flags = 0; uint32_t priority = APP_PRIORITY_DEFAULT; ClientDetector* client_detector = nullptr; @@ -139,10 +139,10 @@ public: return entry ? entry->priority : 0; } - void init_appid_info_table(AppIdModuleConfig*); + void init_appid_info_table(AppIdModuleConfig*, snort::SnortConfig*); void cleanup_appid_info_table(); void dump_app_info_table(); - int16_t add_appid_protocol_reference(const char* protocol); + SnortProtocolId add_appid_protocol_reference(const char* protocol, snort::SnortConfig*); private: AppInfoManager() = default; diff --git a/src/network_inspectors/appid/appid_config.cc b/src/network_inspectors/appid/appid_config.cc index 625b88a1c..1ca926c64 100644 --- a/src/network_inspectors/appid/appid_config.cc +++ b/src/network_inspectors/appid/appid_config.cc @@ -46,6 +46,8 @@ #define MAX_DISPLAY_SIZE 65536 #define MAX_LINE 2048 +using namespace snort; + uint32_t app_id_netmasks[33] = { 0x00000000, 0x80000000, 0xC0000000, 0xE0000000, 0xF0000000, 0xF8000000, 0xFC000000, 0xFE000000, 0xFF000000, 0xFF800000, 0xFFC00000, 0xFFE00000, 0xFFF00000, 0xFFF80000, @@ -59,16 +61,23 @@ struct PortList uint16_t port; }; -int16_t snortId_for_unsynchronized; -int16_t snortId_for_ftp_data; -int16_t snortId_for_http2; +SnortProtocolId snortId_for_unsynchronized; +SnortProtocolId snortId_for_ftp_data; +SnortProtocolId snortId_for_http2; -static void map_app_names_to_snort_ids() +static void map_app_names_to_snort_ids(SnortConfig* sc) { /* init globals for snortId compares */ - snortId_for_unsynchronized = snort::SnortConfig::get_conf()->proto_ref->add("unsynchronized"); - snortId_for_ftp_data = snort::SnortConfig::get_conf()->proto_ref->add("ftp-data"); - snortId_for_http2 = snort::SnortConfig::get_conf()->proto_ref->add("http2"); + snortId_for_unsynchronized = sc->proto_ref->add("unsynchronized"); + snortId_for_ftp_data = sc->proto_ref->add("ftp-data"); + snortId_for_http2 = sc->proto_ref->add("http2"); + + // Have to create SnortProtocolIds during configuration initialization. + sc->proto_ref->add("rexec"); + sc->proto_ref->add("rsh-error"); + sc->proto_ref->add("snmp"); + sc->proto_ref->add("sunrpc"); + sc->proto_ref->add("tftp"); } AppIdModuleConfig::AppIdModuleConfig() @@ -736,16 +745,16 @@ void AppIdConfig::set_safe_search_enforcement(bool enabled) mod_config->safe_search_enabled = enabled; } -bool AppIdConfig::init_appid( ) +bool AppIdConfig::init_appid(SnortConfig* sc) { - app_info_mgr.init_appid_info_table(mod_config); + app_info_mgr.init_appid_info_table(mod_config, sc); #ifdef USE_RNA_CONFIG load_analysis_config(mod_config->conf_file, 0, mod_config->instance_id); #endif read_port_detectors(ODP_PORT_DETECTORS); read_port_detectors(CUSTOM_PORT_DETECTORS); ThirdPartyAppIDInit(mod_config); - map_app_names_to_snort_ids(); + map_app_names_to_snort_ids(sc); return true; } diff --git a/src/network_inspectors/appid/appid_config.h b/src/network_inspectors/appid/appid_config.h index d8f214d3d..818802cbf 100644 --- a/src/network_inspectors/appid/appid_config.h +++ b/src/network_inspectors/appid/appid_config.h @@ -26,8 +26,10 @@ #include "application_ids.h" #include "framework/decode_data.h" +#include "main/snort_config.h" #include "protocols/ipv6.h" #include "sfip/sf_ip.h" +#include "target_based/snort_protocols.h" #include "utils/sflsq.h" #define APP_ID_MAX_DIRS 16 @@ -40,9 +42,9 @@ class AppInfoManager; extern unsigned appIdPolicyId; extern uint32_t app_id_netmasks[]; -extern int16_t snortId_for_unsynchronized; -extern int16_t snortId_for_ftp_data; -extern int16_t snortId_for_http2; +extern SnortProtocolId snortId_for_unsynchronized; +extern SnortProtocolId snortId_for_ftp_data; +extern SnortProtocolId snortId_for_http2; struct PortExclusion { @@ -112,7 +114,7 @@ public: AppIdConfig(AppIdModuleConfig*); ~AppIdConfig(); - bool init_appid(); + bool init_appid(snort::SnortConfig*); void cleanup(); void show(); void set_safe_search_enforcement(bool enabled); diff --git a/src/network_inspectors/appid/appid_discovery.cc b/src/network_inspectors/appid/appid_discovery.cc index 54c3b9a48..bef16a7b8 100644 --- a/src/network_inspectors/appid/appid_discovery.cc +++ b/src/network_inspectors/appid/appid_discovery.cc @@ -585,7 +585,7 @@ static void lookup_appid_by_host_port(AppIdSession& asd, Packet* p, IpProtocol p break; default: asd.service.set_id(hv->appId); - asd.sync_with_snort_id(hv->appId, p); + asd.sync_with_snort_protocol_id(hv->appId, p); asd.service_disco_state = APPID_DISCO_STATE_FINISHED; asd.client_disco_state = APPID_DISCO_STATE_FINISHED; asd.set_session_flags(APPID_SESSION_SERVICE_DETECTED); diff --git a/src/network_inspectors/appid/appid_inspector.cc b/src/network_inspectors/appid/appid_inspector.cc index 8ef38be84..b45691d0e 100644 --- a/src/network_inspectors/appid/appid_inspector.cc +++ b/src/network_inspectors/appid/appid_inspector.cc @@ -98,7 +98,7 @@ AppIdConfig* AppIdInspector::get_appid_config() return active_config; } -bool AppIdInspector::configure(SnortConfig*) +bool AppIdInspector::configure(SnortConfig* sc) { assert(!active_config); @@ -113,7 +113,7 @@ bool AppIdInspector::configure(SnortConfig*) my_seh = SipEventHandler::create(); my_seh->subscribe(); - active_config->init_appid(); + active_config->init_appid(sc); return true; // FIXIT-M some of this stuff may be needed in some fashion... diff --git a/src/network_inspectors/appid/appid_session.cc b/src/network_inspectors/appid/appid_session.cc index 80ecc2336..b66827331 100644 --- a/src/network_inspectors/appid/appid_session.cc +++ b/src/network_inspectors/appid/appid_session.cc @@ -120,7 +120,7 @@ AppIdSession* AppIdSession::allocate_session(const Packet* p, IpProtocol proto, asd->flow = p->flow; asd->stats.first_packet_second = p->pkth->ts.tv_sec; asd->set_session_logging_state(p, direction); - asd->snort_id = snortId_for_unsynchronized; + asd->snort_protocol_id = snortId_for_unsynchronized; p->flow->set_flow_data(asd); return asd; } @@ -208,7 +208,7 @@ static inline PktType get_pkt_type_from_ip_proto(IpProtocol proto) AppIdSession* AppIdSession::create_future_session(const Packet* ctrlPkt, const SfIp* cliIp, uint16_t cliPort, const SfIp* srvIp, uint16_t srvPort, IpProtocol proto, - int16_t app_id, int /*flags*/, AppIdInspector& inspector) + SnortProtocolId snort_protocol_id, int /*flags*/, AppIdInspector& inspector) { char src_ip[INET6_ADDRSTRLEN]; char dst_ip[INET6_ADDRSTRLEN]; @@ -221,8 +221,8 @@ AppIdSession* AppIdSession::create_future_session(const Packet* ctrlPkt, const S AppIdSession* asd = new AppIdSession(proto, cliIp, 0, inspector); asd->common.policyId = asd->config->appIdPolicyId; - if ( Stream::set_application_protocol_id_expected(ctrlPkt, type, proto, cliIp, cliPort, srvIp, - srvPort, app_id, asd) ) + if ( Stream::set_snort_protocol_id_expected(ctrlPkt, type, proto, cliIp, cliPort, srvIp, + srvPort, snort_protocol_id, asd) ) { sfip_ntop(cliIp, src_ip, sizeof(src_ip)); sfip_ntop(srvIp, dst_ip, sizeof(dst_ip)); @@ -292,7 +292,7 @@ void AppIdSession::reinit_session_data() APPID_SESSION_SSL_SESSION|APPID_SESSION_HTTP_SESSION | APPID_SESSION_APP_REINSPECT); } -void AppIdSession::sync_with_snort_id(AppId newAppId, Packet* p) +void AppIdSession::sync_with_snort_protocol_id(AppId newAppId, Packet* p) { if (newAppId > APP_ID_NONE && newAppId < SF_APPID_MAX) { @@ -330,21 +330,21 @@ void AppIdSession::sync_with_snort_id(AppId newAppId, Packet* p) AppInfoTableEntry* entry = app_info_mgr->get_app_info_entry(newAppId); if ( entry ) { - int16_t tempSnortId = entry->snortId; + SnortProtocolId tmp_snort_protocol_id = entry->snort_protocol_id; // A particular APP_ID_xxx may not be assigned a service_snort_key value - // in the rna_app.yaml file entry; so ignore the tempSnortId == 0 case. - if ( tempSnortId == 0 && (newAppId == APP_ID_HTTP2)) - tempSnortId = snortId_for_http2; + // in the rna_app.yaml file entry; so ignore the snort_protocol_id == UNKNOWN_PROTOCOL_ID case. + if ( tmp_snort_protocol_id == UNKNOWN_PROTOCOL_ID && (newAppId == APP_ID_HTTP2)) + tmp_snort_protocol_id = snortId_for_http2; - if ( tempSnortId != snort_id ) + if ( tmp_snort_protocol_id != snort_protocol_id ) { - snort_id = tempSnortId; + snort_protocol_id = tmp_snort_protocol_id; if (session_logging_enabled) - if (tempSnortId == snortId_for_http2) + if (tmp_snort_protocol_id == snortId_for_http2) LogMessage("AppIdDbg %s Telling Snort that it's HTTP/2\n", session_logging_id); - p->flow->ssn_state.application_protocol = tempSnortId; + p->flow->ssn_state.snort_protocol_id = tmp_snort_protocol_id; } } } diff --git a/src/network_inspectors/appid/appid_session.h b/src/network_inspectors/appid/appid_session.h index 8aa09eb79..64b853bad 100644 --- a/src/network_inspectors/appid/appid_session.h +++ b/src/network_inspectors/appid/appid_session.h @@ -148,7 +148,7 @@ public: static AppIdSession* allocate_session(const snort::Packet*, IpProtocol, int, AppIdInspector&); static AppIdSession* create_future_session(const snort::Packet*, const snort::SfIp*, uint16_t, const snort::SfIp*, - uint16_t, IpProtocol, int16_t, int, AppIdInspector&); + uint16_t, IpProtocol, SnortProtocolId, int, AppIdInspector&); AppIdInspector& get_inspector() const { @@ -202,7 +202,7 @@ public: uint16_t init_tpPackets = 0; uint16_t resp_tpPackets = 0; bool tp_reinspect_by_initiator = false; - int16_t snort_id = 0; + SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; /* Length-based detectors. */ LengthKey length_sequence; @@ -279,7 +279,7 @@ public: void check_app_detection_restart(); void update_encrypted_app_id(AppId); void examine_rtmp_metadata(); - void sync_with_snort_id(AppId, snort::Packet*); + void sync_with_snort_protocol_id(AppId, snort::Packet*); void stop_rna_service_inspection(snort::Packet*, int); bool is_payload_appid_set(); diff --git a/src/network_inspectors/appid/client_plugins/client_discovery.cc b/src/network_inspectors/appid/client_plugins/client_discovery.cc index 2bdbaf1cb..11423c8af 100644 --- a/src/network_inspectors/appid/client_plugins/client_discovery.cc +++ b/src/network_inspectors/appid/client_plugins/client_discovery.cc @@ -435,7 +435,7 @@ bool ClientDiscovery::do_client_discovery(AppIdSession& asd, Packet* p, int dire LogMessage("AppIdDbg %s Got a preface for HTTP/2\n", asd.session_logging_id); if ( !was_service && asd.is_service_detected() ) - asd.sync_with_snort_id(asd.service.get_id(), p); + asd.sync_with_snort_protocol_id(asd.service.get_id(), p); return isTpAppidDiscoveryDone; } diff --git a/src/network_inspectors/appid/detector_plugins/detector_sip.cc b/src/network_inspectors/appid/detector_plugins/detector_sip.cc index 9da9e1a43..b72239cbd 100644 --- a/src/network_inspectors/appid/detector_plugins/detector_sip.cc +++ b/src/network_inspectors/appid/detector_plugins/detector_sip.cc @@ -29,6 +29,8 @@ #include "app_info_table.h" #include "protocols/packet.h" +using namespace snort; + static const char SIP_REGISTER_BANNER[] = "REGISTER "; static const char SIP_INVITE_BANNER[] = "INVITE "; static const char SIP_CANCEL_BANNER[] = "CANCEL "; @@ -327,10 +329,10 @@ static int get_sip_client_app(void* patternMatcher, const char* pattern, uint32_ return 1; } -void SipServiceDetector::createRtpFlow(AppIdSession& asd, const snort::Packet* pkt, - const snort::SfIp* cliIp, uint16_t cliPort, const snort::SfIp* srvIp, uint16_t srvPort, - IpProtocol proto, int16_t app_id) +void SipServiceDetector::createRtpFlow(AppIdSession& asd, const Packet* pkt, const SfIp* cliIp, + uint16_t cliPort, const SfIp* srvIp, uint16_t srvPort, IpProtocol proto, int16_t app_id) { + // FIXIT-H: Passing app_id instead of SnortProtocolId to create_future_session is incorrect. We need to look up snort_protocol_id. AppIdSession* fp = AppIdSession::create_future_session(pkt, cliIp, cliPort, srvIp, srvPort, proto, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if ( fp ) @@ -464,17 +466,17 @@ int SipServiceDetector::validate(AppIdDiscoveryArgs& args) THREAD_LOCAL SipUdpClientDetector* SipEventHandler::client = nullptr; THREAD_LOCAL SipServiceDetector* SipEventHandler::service = nullptr; -void SipEventHandler::handle(snort::DataEvent& event, snort::Flow* flow) +void SipEventHandler::handle(DataEvent& event, Flow* flow) { SipEvent& sip_event = (SipEvent&)event; AppIdSession* asd = nullptr; if ( flow ) - asd = snort::appid_api.get_appid_session(*flow); + asd = appid_api.get_appid_session(*flow); if ( !asd ) { - const snort::Packet* p = sip_event.get_packet(); + const Packet* p = sip_event.get_packet(); IpProtocol protocol = p->is_tcp() ? IpProtocol::TCP : IpProtocol::UDP; int direction = p->is_from_client() ? APP_ID_FROM_INITIATOR : APP_ID_FROM_RESPONDER; asd = AppIdSession::allocate_session(p, protocol, direction, diff --git a/src/network_inspectors/appid/lua_detector_api.cc b/src/network_inspectors/appid/lua_detector_api.cc index 49a5c74d8..f3ebe1180 100644 --- a/src/network_inspectors/appid/lua_detector_api.cc +++ b/src/network_inspectors/appid/lua_detector_api.cc @@ -2101,7 +2101,7 @@ static int create_future_flow(lua_State* L) { SfIp client_addr; SfIp server_addr; - int16_t snort_app_id = 0; + SnortProtocolId snort_protocol_id = UNKNOWN_PROTOCOL_ID; AppIdDetector* ud = *UserData::check(L, DETECTOR, 1); LuaStateDescriptor* lsd = ud->validate_lua_state(true); @@ -2127,11 +2127,11 @@ static int create_future_flow(lua_State* L) app_id_to_snort); if (!entry) return 0; - snort_app_id = entry->snortId; + snort_protocol_id = entry->snort_protocol_id; } AppIdSession* fp = AppIdSession::create_future_session(lsd->ldp.pkt, &client_addr, - client_port, &server_addr, server_port, proto, snort_app_id, + client_port, &server_addr, server_port, proto, snort_protocol_id, APPID_EARLY_SESSION_FLAG_FW_RULE, ud->get_handler().get_inspector()); if (fp) { diff --git a/src/network_inspectors/appid/service_plugins/service_discovery.cc b/src/network_inspectors/appid/service_plugins/service_discovery.cc index 6a44b606d..e09cf3fb5 100644 --- a/src/network_inspectors/appid/service_plugins/service_discovery.cc +++ b/src/network_inspectors/appid/service_plugins/service_discovery.cc @@ -679,7 +679,7 @@ bool ServiceDiscovery::do_service_discovery(AppIdSession& asd, Packet* p, int di APPID_SESSION_SERVICE_DETECTED | APPID_SESSION_NOT_A_SERVICE | APPID_SESSION_IGNORE_HOST) == APPID_SESSION_SERVICE_DETECTED) { - asd.sync_with_snort_id(asd.service.get_id(), p); + asd.sync_with_snort_protocol_id(asd.service.get_id(), p); } } diff --git a/src/network_inspectors/appid/service_plugins/service_ftp.cc b/src/network_inspectors/appid/service_plugins/service_ftp.cc index 427261d70..892a2c8ac 100644 --- a/src/network_inspectors/appid/service_plugins/service_ftp.cc +++ b/src/network_inspectors/appid/service_plugins/service_ftp.cc @@ -91,7 +91,6 @@ FtpServiceDetector::FtpServiceDetector(ServiceDiscovery* sd) name = "ftp"; proto = IpProtocol::TCP; detectorType = DETECTOR_TYPE_DECODER; - ftp_data_app_id = AppInfoManager::get_instance().add_appid_protocol_reference("ftp-data"); tcp_patterns = { @@ -799,8 +798,11 @@ void FtpServiceDetector::create_expected_session(AppIdSession& asd, const Packet uint16_t cliPort, const SfIp* srvIp, uint16_t srvPort, IpProtocol proto, int flags, APPID_SESSION_DIRECTION dir) { + if(ftp_data_snort_protocol_id == UNKNOWN_PROTOCOL_ID) + ftp_data_snort_protocol_id = SnortConfig::get_conf()->proto_ref->find("ftp-data"); + AppIdSession* fp = AppIdSession::create_future_session(pkt, cliIp, cliPort, srvIp, srvPort, - proto, ftp_data_app_id, flags, handler->get_inspector()); + proto, ftp_data_snort_protocol_id, flags, handler->get_inspector()); if (fp) // initialize data session { diff --git a/src/network_inspectors/appid/service_plugins/service_ftp.h b/src/network_inspectors/appid/service_plugins/service_ftp.h index 5438a8605..3b867e551 100644 --- a/src/network_inspectors/appid/service_plugins/service_ftp.h +++ b/src/network_inspectors/appid/service_plugins/service_ftp.h @@ -35,10 +35,10 @@ public: private: void create_expected_session(AppIdSession& asd,const snort::Packet* pkt, - const snort::SfIp* cliIp, uint16_t cliPort, const snort::SfIp* srvIp, uint16_t srvPort, - IpProtocol proto, int flags, APPID_SESSION_DIRECTION dir); + const snort::SfIp* cliIp, uint16_t cliPort, const snort::SfIp* srvIp, + uint16_t srvPort, IpProtocol proto, int flags, APPID_SESSION_DIRECTION dir); - int16_t ftp_data_app_id = 0; + SnortProtocolId ftp_data_snort_protocol_id = UNKNOWN_PROTOCOL_ID; }; #endif diff --git a/src/network_inspectors/appid/service_plugins/service_rexec.cc b/src/network_inspectors/appid/service_plugins/service_rexec.cc index a74fa5d95..70eeca6b6 100644 --- a/src/network_inspectors/appid/service_plugins/service_rexec.cc +++ b/src/network_inspectors/appid/service_plugins/service_rexec.cc @@ -63,8 +63,6 @@ RexecServiceDetector::RexecServiceDetector(ServiceDiscovery* sd) proto = IpProtocol::TCP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppInfoManager::get_instance().add_appid_protocol_reference("rexec"); - appid_registry = { { APP_ID_EXEC, APPINFO_FLAG_SERVICE_ADDITIONAL } @@ -119,6 +117,9 @@ int RexecServiceDetector::validate(AppIdDiscoveryArgs& args) switch (rd->state) { case REXEC_STATE_PORT: + if(rexec_snort_protocol_id == UNKNOWN_PROTOCOL_ID) + rexec_snort_protocol_id = snort::SnortConfig::get_conf()->proto_ref->find("rexec"); + if (args.dir != APP_ID_FROM_INITIATOR) goto bail; if (size > REXEC_MAX_PORT_PACKET) @@ -143,7 +144,7 @@ int RexecServiceDetector::validate(AppIdDiscoveryArgs& args) dip = args.pkt->ptrs.ip_api.get_dst(); sip = args.pkt->ptrs.ip_api.get_src(); AppIdSession* pf = AppIdSession::create_future_session(args.pkt, dip, 0, sip, (uint16_t)port, - IpProtocol::TCP, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); + IpProtocol::TCP, rexec_snort_protocol_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if (pf) { ServiceREXECData* tmp_rd = (ServiceREXECData*)snort_calloc( diff --git a/src/network_inspectors/appid/service_plugins/service_rexec.h b/src/network_inspectors/appid/service_plugins/service_rexec.h index b05df552f..8191b2fb9 100644 --- a/src/network_inspectors/appid/service_plugins/service_rexec.h +++ b/src/network_inspectors/appid/service_plugins/service_rexec.h @@ -34,7 +34,7 @@ public: int validate(AppIdDiscoveryArgs&) override; private: - int16_t app_id = 0; + SnortProtocolId rexec_snort_protocol_id = UNKNOWN_PROTOCOL_ID; }; #endif diff --git a/src/network_inspectors/appid/service_plugins/service_rpc.cc b/src/network_inspectors/appid/service_plugins/service_rpc.cc index be63786bd..11fb05c8e 100644 --- a/src/network_inspectors/appid/service_plugins/service_rpc.cc +++ b/src/network_inspectors/appid/service_plugins/service_rpc.cc @@ -185,8 +185,6 @@ RpcServiceDetector::RpcServiceDetector(ServiceDiscovery* sd) struct rpcent* rpc; RPCProgram* prog; - app_id = AppInfoManager::get_instance().add_appid_protocol_reference("sunrpc"); - if (!rpc_programs) { while ((rpc = getrpcent())) @@ -402,13 +400,16 @@ int RpcServiceDetector::validate_packet(const uint8_t* data, uint16_t size, int pmr = (const ServiceRPCPortmapReply*)data; if (pmr->port) { + if(sunrpc_snort_protocol_id == UNKNOWN_PROTOCOL_ID) + sunrpc_snort_protocol_id = SnortConfig::get_conf()->proto_ref->find("sunrpc"); + const SfIp* dip = pkt->ptrs.ip_api.get_dst(); const SfIp* sip = pkt->ptrs.ip_api.get_src(); tmp = ntohl(pmr->port); AppIdSession* pf = AppIdSession::create_future_session( pkt, dip, 0, sip, (uint16_t)tmp, - (IpProtocol)ntohl((uint32_t)rd->proto), app_id, 0, + (IpProtocol)ntohl((uint32_t)rd->proto), sunrpc_snort_protocol_id, 0, handler->get_inspector()); if (pf) { diff --git a/src/network_inspectors/appid/service_plugins/service_rpc.h b/src/network_inspectors/appid/service_plugins/service_rpc.h index aa1374cf0..fdd63b3b0 100644 --- a/src/network_inspectors/appid/service_plugins/service_rpc.h +++ b/src/network_inspectors/appid/service_plugins/service_rpc.h @@ -41,7 +41,7 @@ private: int rpc_tcp_validate(AppIdDiscoveryArgs&); int validate_packet(const uint8_t* data, uint16_t size, int dir, AppIdSession&, snort::Packet*, ServiceRPCData*, const char** pname, uint32_t* program); - int16_t app_id = 0; + SnortProtocolId sunrpc_snort_protocol_id = UNKNOWN_PROTOCOL_ID; }; #endif diff --git a/src/network_inspectors/appid/service_plugins/service_rshell.cc b/src/network_inspectors/appid/service_plugins/service_rshell.cc index 427fc3f40..35de59274 100644 --- a/src/network_inspectors/appid/service_plugins/service_rshell.cc +++ b/src/network_inspectors/appid/service_plugins/service_rshell.cc @@ -58,7 +58,6 @@ RshellServiceDetector::RshellServiceDetector(ServiceDiscovery* sd) name = "rshell"; proto = IpProtocol::TCP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppInfoManager::get_instance().add_appid_protocol_reference("rsh-error"); appid_registry = { @@ -135,6 +134,9 @@ int RshellServiceDetector::validate(AppIdDiscoveryArgs& args) goto bail; if (port) { + if(rsh_error_snort_protocol_id == UNKNOWN_PROTOCOL_ID) + rsh_error_snort_protocol_id = snort::SnortConfig::get_conf()->proto_ref->find("rsh-error"); + ServiceRSHELLData* tmp_rd = (ServiceRSHELLData*)snort_calloc( sizeof(ServiceRSHELLData)); tmp_rd->state = RSHELL_STATE_STDERR_CONNECT_SYN; @@ -142,7 +144,7 @@ int RshellServiceDetector::validate(AppIdDiscoveryArgs& args) const snort::SfIp* dip = args.pkt->ptrs.ip_api.get_dst(); const snort::SfIp* sip = args.pkt->ptrs.ip_api.get_src(); AppIdSession* pf = AppIdSession::create_future_session(args.pkt, dip, 0, sip, - (uint16_t)port, IpProtocol::TCP, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, + (uint16_t)port, IpProtocol::TCP, rsh_error_snort_protocol_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if (pf) { diff --git a/src/network_inspectors/appid/service_plugins/service_rshell.h b/src/network_inspectors/appid/service_plugins/service_rshell.h index 1e22e2188..b872b7f31 100644 --- a/src/network_inspectors/appid/service_plugins/service_rshell.h +++ b/src/network_inspectors/appid/service_plugins/service_rshell.h @@ -34,7 +34,7 @@ public: int validate(AppIdDiscoveryArgs&) override; private: - int16_t app_id = 0; + SnortProtocolId rsh_error_snort_protocol_id = UNKNOWN_PROTOCOL_ID; }; #endif diff --git a/src/network_inspectors/appid/service_plugins/service_snmp.cc b/src/network_inspectors/appid/service_plugins/service_snmp.cc index b5dc88e7f..ee4d499b1 100644 --- a/src/network_inspectors/appid/service_plugins/service_snmp.cc +++ b/src/network_inspectors/appid/service_plugins/service_snmp.cc @@ -95,8 +95,6 @@ SnmpServiceDetector::SnmpServiceDetector(ServiceDiscovery* sd) proto = IpProtocol::UDP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppInfoManager::get_instance().add_appid_protocol_reference("snmp"); - udp_patterns = { { SNMP_PATTERN_2, sizeof(SNMP_PATTERN_2), 2, 0, 0 }, @@ -474,10 +472,13 @@ int SnmpServiceDetector::validate(AppIdDiscoveryArgs& args) sd->state = SNMP_STATE_RESPONSE; /*adding expected connection in case the server doesn't send from 161*/ + if(snmp_snort_protocol_id == UNKNOWN_PROTOCOL_ID) + snmp_snort_protocol_id = snort::SnortConfig::get_conf()->proto_ref->find("snmp"); + const snort::SfIp* dip = args.pkt->ptrs.ip_api.get_dst(); const snort::SfIp* sip = args.pkt->ptrs.ip_api.get_src(); AppIdSession* pf = AppIdSession::create_future_session(args.pkt, dip, 0, sip, - args.pkt->ptrs.sp, args.asd.protocol, app_id, 0, handler->get_inspector()); + args.pkt->ptrs.sp, args.asd.protocol, snmp_snort_protocol_id, 0, handler->get_inspector()); if (pf) { tmp_sd = (ServiceSNMPData*)snort_calloc(sizeof(ServiceSNMPData)); diff --git a/src/network_inspectors/appid/service_plugins/service_snmp.h b/src/network_inspectors/appid/service_plugins/service_snmp.h index 54774b0bf..73499754f 100644 --- a/src/network_inspectors/appid/service_plugins/service_snmp.h +++ b/src/network_inspectors/appid/service_plugins/service_snmp.h @@ -34,7 +34,7 @@ public: int validate(AppIdDiscoveryArgs&) override; private: - int16_t app_id = 0; + SnortProtocolId snmp_snort_protocol_id = UNKNOWN_PROTOCOL_ID; }; #endif diff --git a/src/network_inspectors/appid/service_plugins/service_ssl.cc b/src/network_inspectors/appid/service_plugins/service_ssl.cc index 62603364a..2cc4ce155 100644 --- a/src/network_inspectors/appid/service_plugins/service_ssl.cc +++ b/src/network_inspectors/appid/service_plugins/service_ssl.cc @@ -1127,13 +1127,15 @@ void ssl_detector_free_patterns() ssl_patterns_free(&service_ssl_config.DetectorSSLCnamePatternList); } -bool setSSLSquelch(snort::Packet* p, int type, AppId appId, AppIdInspector& inspector) +bool setSSLSquelch(Packet* p, int type, AppId appId, AppIdInspector& inspector) { if (!AppInfoManager::get_instance().get_app_info_flags(appId, APPINFO_FLAG_SSL_SQUELCH)) return false; - const snort::SfIp* dip = p->ptrs.ip_api.get_dst(); - const snort::SfIp* sip = p->ptrs.ip_api.get_src(); + const SfIp* dip = p->ptrs.ip_api.get_dst(); + const SfIp* sip = p->ptrs.ip_api.get_src(); + + // FIXIT-H: Passing appId to create_future_session() is incorrect. We need to pass the snort_protocol_id associated with appId. AppIdSession* asd = AppIdSession::create_future_session(p, sip, 0, dip, p->ptrs.dp, IpProtocol::TCP, appId, 0, inspector); if ( asd ) diff --git a/src/network_inspectors/appid/service_plugins/service_tftp.cc b/src/network_inspectors/appid/service_plugins/service_tftp.cc index 214c540f9..f6de546e5 100644 --- a/src/network_inspectors/appid/service_plugins/service_tftp.cc +++ b/src/network_inspectors/appid/service_plugins/service_tftp.cc @@ -71,8 +71,6 @@ TftpServiceDetector::TftpServiceDetector(ServiceDiscovery* sd) proto = IpProtocol::UDP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppInfoManager::get_instance().add_appid_protocol_reference("tftp"); - appid_registry = { { APP_ID_TFTP, APPINFO_FLAG_SERVICE_ADDITIONAL } @@ -179,12 +177,15 @@ int TftpServiceDetector::validate(AppIdDiscoveryArgs& args) if (strcasecmp((const char*)data, "netascii") && strcasecmp((const char*)data, "octet")) goto bail; + if(tftp_snort_protocol_id == UNKNOWN_PROTOCOL_ID) + tftp_snort_protocol_id = snort::SnortConfig::get_conf()->proto_ref->find("tftp"); + tmp_td = (ServiceTFTPData*)snort_calloc(sizeof(ServiceTFTPData)); tmp_td->state = TFTP_STATE_TRANSFER; dip = args.pkt->ptrs.ip_api.get_dst(); sip = args.pkt->ptrs.ip_api.get_src(); pf = AppIdSession::create_future_session(args.pkt, dip, 0, sip, - args.pkt->ptrs.sp, args.asd.protocol, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, + args.pkt->ptrs.sp, args.asd.protocol, tftp_snort_protocol_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if (pf) { diff --git a/src/network_inspectors/appid/service_plugins/service_tftp.h b/src/network_inspectors/appid/service_plugins/service_tftp.h index da712e68f..8fc8778fa 100644 --- a/src/network_inspectors/appid/service_plugins/service_tftp.h +++ b/src/network_inspectors/appid/service_plugins/service_tftp.h @@ -34,7 +34,7 @@ public: int validate(AppIdDiscoveryArgs&) override; private: - int16_t app_id = 0; + SnortProtocolId tftp_snort_protocol_id = UNKNOWN_PROTOCOL_ID; }; #endif diff --git a/src/network_inspectors/appid/thirdparty_appid_utils.cc b/src/network_inspectors/appid/thirdparty_appid_utils.cc index 1d97ce512..d3e97d772 100644 --- a/src/network_inspectors/appid/thirdparty_appid_utils.cc +++ b/src/network_inspectors/appid/thirdparty_appid_utils.cc @@ -902,7 +902,7 @@ bool do_third_party_discovery(AppIdSession& asd, IpProtocol protocol, const SfIp } if (asd.tp_app_id == APP_ID_SSL && - (Stream::get_application_protocol_id(p->flow) == snortId_for_ftp_data)) + (Stream::get_snort_protocol_id(p->flow) == snortId_for_ftp_data)) { // If we see SSL on an FTP data channel set tpAppId back // to APP_ID_NONE so the FTP preprocessor picks up the flow. @@ -995,7 +995,7 @@ bool do_third_party_discovery(AppIdSession& asd, IpProtocol protocol, const SfIp snort_app_id = asd.tp_app_id; } - asd.sync_with_snort_id(snort_app_id, p); + asd.sync_with_snort_protocol_id(snort_app_id, p); } else { diff --git a/src/network_inspectors/binder/binder.cc b/src/network_inspectors/binder/binder.cc index f9135b895..4126965ed 100644 --- a/src/network_inspectors/binder/binder.cc +++ b/src/network_inspectors/binder/binder.cc @@ -389,15 +389,15 @@ static void set_session(Flow* flow) static void set_service(Flow* flow, const HostAttributeEntry* host) { - Stream::set_application_protocol_id(flow, host, FROM_SERVER); + Stream::set_snort_protocol_id(flow, host, FROM_SERVER); } static Inspector* get_gadget(Flow* flow) { - if ( !flow->ssn_state.application_protocol ) + if ( !flow->ssn_state.snort_protocol_id ) return nullptr; - const char* s = SnortConfig::get_conf()->proto_ref->get_name(flow->ssn_state.application_protocol); + const char* s = SnortConfig::get_conf()->proto_ref->get_name(flow->ssn_state.snort_protocol_id); return InspectorManager::get_inspector(s); } @@ -551,8 +551,8 @@ void Stuff::apply_service(Flow* flow, const HostAttributeEntry* host) { flow->set_gadget(gadget); - if ( !flow->ssn_state.application_protocol ) - flow->ssn_state.application_protocol = gadget->get_service(); + if ( !flow->ssn_state.snort_protocol_id ) + flow->ssn_state.snort_protocol_id = gadget->get_service(); } else if ( wizard ) @@ -674,10 +674,10 @@ int Binder::exec_handle_gadget( void* pv ) if (flow->gadget != nullptr ) flow->clear_gadget(); flow->set_gadget(ins); - flow->ssn_state.application_protocol = ins->get_service(); + flow->ssn_state.snort_protocol_id = ins->get_service(); } else if ( flow->service ) - flow->ssn_state.application_protocol = SnortConfig::get_conf()->proto_ref->find(flow->service); + flow->ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow->service); if ( !flow->is_stream() ) return 0; diff --git a/src/parser/parse_conf.cc b/src/parser/parse_conf.cc index 6be94f946..ace4b71f3 100644 --- a/src/parser/parse_conf.cc +++ b/src/parser/parse_conf.cc @@ -122,7 +122,7 @@ void parse_include(SnortConfig* sc, const char* arg) void ParseIpVar(SnortConfig* sc, const char* var, const char* val) { int ret; - IpsPolicy* p = snort::get_ips_policy(); // FIXIT-M double check, see below + IpsPolicy* p = get_ips_policy(); // FIXIT-M double check, see below DisallowCrossTableDuplicateVars(sc, var, VAR_TYPE__IPVAR); if ((ret = sfvt_define(p->ip_vartable, var, val)) != SFIP_SUCCESS) @@ -161,10 +161,10 @@ void add_service_to_otn(SnortConfig* sc, OptTreeNode* otn, const char* svc_name) ParseError("too many service's specified for rule, can't add %s", svc_name); return; } - int16_t svc_id = sc->proto_ref->add(svc_name); + SnortProtocolId svc_id = sc->proto_ref->add(svc_name); for ( unsigned i = 0; i < otn->sigInfo.num_services; ++i ) - if ( otn->sigInfo.services[i].service_ordinal == svc_id ) + if ( otn->sigInfo.services[i].snort_protocol_id == svc_id ) return; // already added if ( !otn->sigInfo.services ) @@ -174,7 +174,7 @@ void add_service_to_otn(SnortConfig* sc, OptTreeNode* otn, const char* svc_name) int idx = otn->sigInfo.num_services++; otn->sigInfo.services[idx].service = snort_strdup(svc_name); - otn->sigInfo.services[idx].service_ordinal = svc_id; + otn->sigInfo.services[idx].snort_protocol_id = svc_id; } // only keep drop rules ... diff --git a/src/parser/parse_rule.cc b/src/parser/parse_rule.cc index a96ed0335..54a6b8cbe 100644 --- a/src/parser/parse_rule.cc +++ b/src/parser/parse_rule.cc @@ -99,7 +99,7 @@ static bool s_ignore = false; // for skipping drop rules when not inline, etc. */ static int FinishPortListRule( RulePortTables* port_tables, RuleTreeNode* rtn, OptTreeNode* otn, - int proto, FastPatternConfig* fp) + SnortProtocolId snort_protocol_id, FastPatternConfig* fp) { int large_port_group = 0; PortTable* dstTable; @@ -108,11 +108,11 @@ static int FinishPortListRule( rule_count_t* prc; uint32_t orig_flags = rtn->flags; - assert(otn->proto == proto); + assert(otn->snort_protocol_id == snort_protocol_id); /* Select the Target PortTable for this rule, based on protocol, src/dst * dir, and if there is rule content */ - switch ( proto ) + switch ( snort_protocol_id ) { case SNORT_PROTO_IP: dstTable = port_tables->ip.dst; @@ -207,7 +207,7 @@ static int FinishPortListRule( if (((rtn->flags & (ANY_DST_PORT|ANY_SRC_PORT)) == (ANY_DST_PORT|ANY_SRC_PORT)) || large_port_group || fp->get_single_rule_group()) { - if (proto == SNORT_PROTO_IP) + if (snort_protocol_id == SNORT_PROTO_IP) { /* Add the IP rules to the higher level app protocol groups, if they apply * to those protocols. All IP rules should have any-any port descriptors @@ -217,7 +217,7 @@ static int FinishPortListRule( "Finishing IP any-any rule %u:%u\n", otn->sigInfo.gid, otn->sigInfo.sid); - switch ( otn->proto ) + switch ( otn->snort_protocol_id ) { case SNORT_PROTO_IP: /* Add to all ip proto any port tables */ PortObjectAddRule(port_tables->icmp.any, otn->ruleIndex); @@ -353,7 +353,7 @@ static int ValidateIPList(sfip_var_t* addrset, const char* token) static int ProcessIP(SnortConfig*, const char* addr, RuleTreeNode* rtn, int mode, int) { - vartable_t* ip_vartable = snort::get_ips_policy()->ip_vartable; + vartable_t* ip_vartable = get_ips_policy()->ip_vartable; assert(rtn); /* If a rule has a variable in it, we want to copy that variable's @@ -582,7 +582,6 @@ static PortObject* ParsePortListTcpUdpPort( * * rtn - proto_node * port_str - port list string or port var name - * proto - protocol * dst_flag - dst or src port flag, true = dst, false = src * */ @@ -656,7 +655,7 @@ bool same_headers(RuleTreeNode* rule, RuleTreeNode* rtn) if (rule->type != rtn->type) return false; - if (rule->proto != rtn->proto) + if (rule->snort_protocol_id != rtn->snort_protocol_id) return false; /* For custom rule type declarations */ @@ -704,7 +703,7 @@ static void XferHeader(RuleTreeNode* from, RuleTreeNode* to) to->sip = from->sip; to->dip = from->dip; - to->proto = from->proto; + to->snort_protocol_id = from->snort_protocol_id; to->src_portobject = from->src_portobject; to->dst_portobject = from->dst_portobject; @@ -870,7 +869,7 @@ static RuleTreeNode* ProcessHeadNode( SnortConfig* sc, RuleTreeNode* test_node, ListHead* list) { RuleTreeNode* rtn = findHeadNode( - sc, test_node, snort::get_ips_policy()->policy_id); + sc, test_node, get_ips_policy()->policy_id); /* if it doesn't match any of the existing nodes, make a new node and * stick it at the end of the list */ @@ -909,7 +908,7 @@ static int mergeDuplicateOtn( SnortConfig* sc, OptTreeNode* otn_cur, OptTreeNode* otn_new, RuleTreeNode* rtn_new) { - if (otn_cur->proto != otn_new->proto) + if (otn_cur->snort_protocol_id != otn_new->snort_protocol_id) { ParseError("GID %u SID %u in rule duplicates previous rule, with different protocol.", otn_new->sigInfo.gid, otn_new->sigInfo.sid); @@ -958,7 +957,7 @@ static int mergeDuplicateOtn( { RuleTreeNode* rtnTmp2 = deleteRtnFromOtn(otn_cur, i, sc, (rtn_cur != rtn_new)); - if ( rtnTmp2 and (i != snort::get_ips_policy()->policy_id) ) + if ( rtnTmp2 and (i != get_ips_policy()->policy_id) ) { addRtnToOtn(sc, otn_new, rtnTmp2, i); } @@ -1090,9 +1089,9 @@ void parse_rule_proto(SnortConfig* sc, const char* s, RuleTreeNode& rtn) // this will allow other protocols like http to have ports rule_proto = PROTO_BIT__TCP; - rtn.proto = sc->proto_ref->add(s); + rtn.snort_protocol_id = sc->proto_ref->add(s); - if ( rtn.proto <= 0 ) + if ( rtn.snort_protocol_id == UNKNOWN_PROTOCOL_ID ) { ParseError("bad protocol: %s", s); rule_proto = 0; @@ -1114,7 +1113,7 @@ void parse_rule_ports( if ( s_ignore ) return; - IpsPolicy* p = snort::get_ips_policy(); + IpsPolicy* p = get_ips_policy(); if ( ParsePortList(&rtn, p->portVarTable, p->nonamePortVarTable, s, src ? SRC : DST) ) { @@ -1157,7 +1156,7 @@ void parse_rule_opt_end(SnortConfig* sc, const char* key, OptTreeNode* otn) return; RuleOptType type = OPT_TYPE_MAX; - IpsManager::option_end(sc, otn, otn->proto, key, type); + IpsManager::option_end(sc, otn, otn->snort_protocol_id, key, type); if ( type != OPT_TYPE_META ) otn->num_detection_opts++; @@ -1184,7 +1183,7 @@ OptTreeNode* parse_rule_open(SnortConfig* sc, RuleTreeNode& rtn, bool stub) otn->sigInfo.gid = GENERATOR_SNORT_ENGINE; otn->chain_node_number = otn_count; - otn->proto = rtn.proto; + otn->snort_protocol_id = rtn.snort_protocol_id; otn->enabled = SnortConfig::get_default_rule_state(); IpsManager::reset_options(); @@ -1284,8 +1283,8 @@ const char* parse_rule_close(SnortConfig* sc, RuleTreeNode& rtn, OptTreeNode* ot validate_fast_pattern(otn); OtnLookupAdd(sc->otn_map, otn); - if ( is_service_protocol(otn->proto) ) - add_service_to_otn(sc, otn, sc->proto_ref->get_name(otn->proto)); + if ( is_service_protocol(otn->snort_protocol_id) ) + add_service_to_otn(sc, otn, sc->proto_ref->get_name(otn->snort_protocol_id)); /* * The src/dst port parsing must be done before the Head Nodes are processed, since they must @@ -1294,7 +1293,7 @@ const char* parse_rule_close(SnortConfig* sc, RuleTreeNode& rtn, OptTreeNode* ot * After otn processing we can finalize port object processing for this rule */ if ( FinishPortListRule( - sc->port_tables, new_rtn, otn, rtn.proto, sc->fast_pattern_config) ) + sc->port_tables, new_rtn, otn, rtn.snort_protocol_id, sc->fast_pattern_config) ) ParseError("Failed to finish a port list rule."); // Clear ips_option vars diff --git a/src/parser/parser.cc b/src/parser/parser.cc index 3c59dcb22..e91befb4c 100644 --- a/src/parser/parser.cc +++ b/src/parser/parser.cc @@ -412,7 +412,7 @@ static void parse_file(SnortConfig* sc, Shell* sh) ***************************************************************************/ SnortConfig* ParseSnortConf(const SnortConfig* boot_conf, const char* fname) { - SnortConfig* sc = new SnortConfig; + SnortConfig* sc = new SnortConfig(SnortConfig::get_conf()->proto_ref); sc->logging_flags = boot_conf->logging_flags; VarNode* tmp = boot_conf->var_list; @@ -776,7 +776,7 @@ RuleTreeNode* deleteRtnFromOtn(OptTreeNode* otn, PolicyId policyId, SnortConfig* RuleTreeNode* deleteRtnFromOtn(OptTreeNode* otn, SnortConfig* sc) { - return deleteRtnFromOtn(otn, snort::get_ips_policy()->policy_id, sc); + return deleteRtnFromOtn(otn, get_ips_policy()->policy_id, sc); } static uint32_t rtn_hash_func(HashFnc*, const unsigned char *k, int) @@ -882,7 +882,7 @@ int addRtnToOtn(SnortConfig* sc, OptTreeNode* otn, RuleTreeNode* rtn, PolicyId p int addRtnToOtn(SnortConfig*sc, OptTreeNode* otn, RuleTreeNode* rtn) { - return addRtnToOtn(sc, otn, rtn, snort::get_ips_policy()->policy_id); + return addRtnToOtn(sc, otn, rtn, get_ips_policy()->policy_id); } void rule_index_map_print_index(int index, char* buf, int bufsize) diff --git a/src/profiler/rule_profiler.cc b/src/profiler/rule_profiler.cc index bb2706800..3415dda0e 100644 --- a/src/profiler/rule_profiler.cc +++ b/src/profiler/rule_profiler.cc @@ -321,7 +321,7 @@ void reset_rule_profiler_stats() auto* rtn = getRtnFromOtn(otn); - if ( !rtn || !is_network_protocol(rtn->proto) ) + if ( !rtn || !is_network_protocol(rtn->snort_protocol_id) ) continue; for ( unsigned i = 0; i < ThreadConfig::get_instance_max(); ++i ) diff --git a/src/protocols/packet.h b/src/protocols/packet.h index 130ae61e4..510dcd235 100644 --- a/src/protocols/packet.h +++ b/src/protocols/packet.h @@ -25,6 +25,7 @@ #include "flow/flow.h" #include "framework/decode_data.h" #include "main/snort_types.h" +#include "target_based/snort_protocols.h" class Endianness; class Obfuscator; @@ -96,7 +97,6 @@ enum PseudoPacketType constexpr int32_t MAX_PORTS = 65536; constexpr uint16_t NUM_IP_PROTOS = 256; -constexpr int16_t SFTARGET_UNKNOWN_PROTOCOL = -1; constexpr uint8_t TCP_OPTLENMAX = 40; /* (((2^4) - 1) * 4 - TCP_HEADER_LEN) */ constexpr uint8_t DEFAULT_LAYERMAX = 40; @@ -260,11 +260,11 @@ struct SO_PUBLIC Packet bool is_rebuilt() const { return (packet_flags & (PKT_REBUILT_STREAM|PKT_REBUILT_FRAG)) != 0; } - int16_t get_application_protocol() - { return flow ? flow->ssn_state.application_protocol : 0; } + SnortProtocolId get_snort_protocol_id() + { return flow ? flow->ssn_state.snort_protocol_id : UNKNOWN_PROTOCOL_ID; } - void set_application_protocol(int16_t ap) - { if ( flow ) flow->ssn_state.application_protocol = ap; } + void set_snort_protocol_id(SnortProtocolId proto_id) + { if ( flow ) flow->ssn_state.snort_protocol_id = proto_id; } private: bool allocated; diff --git a/src/search_engines/test/hyperscan_test.cc b/src/search_engines/test/hyperscan_test.cc index 4a552f0f0..e1507d9c8 100644 --- a/src/search_engines/test/hyperscan_test.cc +++ b/src/search_engines/test/hyperscan_test.cc @@ -62,7 +62,7 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf; static SnortState s_state; -SnortConfig::SnortConfig(SnortConfig*) +SnortConfig::SnortConfig(const SnortConfig* const) { state = &s_state; memset(state, 0, sizeof(*state)); diff --git a/src/search_engines/test/search_tool_test.cc b/src/search_engines/test/search_tool_test.cc index 6ea2e9a67..4e8ca96c0 100644 --- a/src/search_engines/test/search_tool_test.cc +++ b/src/search_engines/test/search_tool_test.cc @@ -52,7 +52,7 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf; static SnortState s_state; -SnortConfig::SnortConfig(SnortConfig*) +SnortConfig::SnortConfig(const SnortConfig* const) { state = &s_state; memset(state, 0, sizeof(*state)); diff --git a/src/service_inspectors/dce_rpc/ips_dce_iface.cc b/src/service_inspectors/dce_rpc/ips_dce_iface.cc index dbe54cbe5..7e4964786 100644 --- a/src/service_inspectors/dce_rpc/ips_dce_iface.cc +++ b/src/service_inspectors/dce_rpc/ips_dce_iface.cc @@ -215,7 +215,7 @@ public: uint32_t hash() const override; bool operator==(const IpsOption&) const override; EvalStatus eval(Cursor&, Packet*) override; - PatternMatchData* get_pattern(int proto, RuleDirection direction) override; + PatternMatchData* get_pattern(SnortProtocolId snort_protocol_id, RuleDirection direction) override; PatternMatchData* get_alternate_pattern() override; ~Dce2IfaceOption() override; @@ -267,14 +267,14 @@ static char* make_pattern_buffer( const Uuid &uuid, DceRpcBoFlag type ) return pattern_buf; } -PatternMatchData* Dce2IfaceOption::get_pattern(int proto, RuleDirection direction) +PatternMatchData* Dce2IfaceOption::get_pattern(SnortProtocolId snort_protocol_id, RuleDirection direction) { if (pmd.pattern_buf) { return &pmd; } - if (proto == SNORT_PROTO_TCP) + if (snort_protocol_id == SNORT_PROTO_TCP) { const char client_fp[] = "\x05\x00\x00"; const char server_fp[] = "\x05\x00\x02"; @@ -302,7 +302,7 @@ PatternMatchData* Dce2IfaceOption::get_pattern(int proto, RuleDirection directio } return &pmd; } - else if (proto == SNORT_PROTO_UDP) + else if (snort_protocol_id == SNORT_PROTO_UDP) { pmd.pattern_buf = make_pattern_buffer( uuid, DCERPC_BO_FLAG__LITTLE_ENDIAN ); pmd.pattern_size = sizeof(Uuid); diff --git a/src/service_inspectors/ftp_telnet/ft_main.h b/src/service_inspectors/ftp_telnet/ft_main.h index a915afc6e..c3539b93b 100644 --- a/src/service_inspectors/ftp_telnet/ft_main.h +++ b/src/service_inspectors/ftp_telnet/ft_main.h @@ -32,6 +32,8 @@ #ifndef FT_MAIN_H #define FT_MAIN_H +#include "target_based/snort_protocols.h" + #include "ftpp_ui_config.h" #define BUF_SIZE 1024 @@ -39,11 +41,10 @@ namespace snort { struct Packet; -struct ProfileStats; struct SnortConfig; } -extern int16_t ftp_data_app_id; +extern SnortProtocolId ftp_data_snort_protocol_id; void do_detection(snort::Packet*); diff --git a/src/service_inspectors/ftp_telnet/ftp.cc b/src/service_inspectors/ftp_telnet/ftp.cc index 15722695a..04c26746b 100644 --- a/src/service_inspectors/ftp_telnet/ftp.cc +++ b/src/service_inspectors/ftp_telnet/ftp.cc @@ -43,7 +43,7 @@ using namespace snort; -int16_t ftp_data_app_id = SFTARGET_UNKNOWN_PROTOCOL; +SnortProtocolId ftp_data_snort_protocol_id = UNKNOWN_PROTOCOL_ID; #define client_key "ftp_client" #define server_key "ftp_server" @@ -344,6 +344,7 @@ FtpServer::~FtpServer () bool FtpServer::configure(SnortConfig* sc) { + ftp_data_snort_protocol_id = sc->proto_ref->add("ftp-data"); return !FTPCheckConfigs(sc, ftp_server); } @@ -455,7 +456,6 @@ static Module* fs_mod_ctor() static void fs_init() { - ftp_data_app_id = SnortConfig::get_conf()->proto_ref->add("ftp-data"); FtpFlowData::init(); } diff --git a/src/service_inspectors/ftp_telnet/pp_ftp.cc b/src/service_inspectors/ftp_telnet/pp_ftp.cc index 940ae1371..463634c1b 100644 --- a/src/service_inspectors/ftp_telnet/pp_ftp.cc +++ b/src/service_inspectors/ftp_telnet/pp_ftp.cc @@ -1077,11 +1077,11 @@ static int do_stateful_checks(FTP_SESSION* session, Packet* p, session->datassn = ftpdata; /* Call into Streams to mark data channel as ftp-data */ - result = Stream::set_application_protocol_id_expected( + result = Stream::set_snort_protocol_id_expected( p, PktType::TCP, IpProtocol::TCP, &session->clientIP, session->clientPort, &session->serverIP, session->serverPort, - ftp_data_app_id, fd); + ftp_data_snort_protocol_id, fd); if (result < 0) { @@ -1155,11 +1155,11 @@ static int do_stateful_checks(FTP_SESSION* session, Packet* p, session->datassn = ftpdata; /* Call into Streams to mark data channel as ftp-data */ - result = Stream::set_application_protocol_id_expected( + result = Stream::set_snort_protocol_id_expected( p, PktType::TCP, IpProtocol::TCP, &session->clientIP, session->clientPort, &session->serverIP, session->serverPort, - ftp_data_app_id, fd); + ftp_data_snort_protocol_id, fd); if (result < 0) { diff --git a/src/stream/base/stream_ha.cc b/src/stream/base/stream_ha.cc index 52e2d773c..c0a519fdc 100644 --- a/src/stream/base/stream_ha.cc +++ b/src/stream/base/stream_ha.cc @@ -211,7 +211,7 @@ static void update_flags(Flow* flow) } if( ( old_state->ipprotocol != cur_state->ipprotocol ) || - ( old_state->application_protocol != cur_state->application_protocol ) || + ( old_state->snort_protocol_id != cur_state->snort_protocol_id ) || ( old_state->direction != cur_state->direction ) ) { flow->ha_state->add(FlowHAState::MODIFIED); diff --git a/src/stream/file/file_session.cc b/src/stream/file/file_session.cc index e5f2c9d0a..039c3d724 100644 --- a/src/stream/file/file_session.cc +++ b/src/stream/file/file_session.cc @@ -73,7 +73,7 @@ int FileSession::process(Packet* p) { Profile profile(file_ssn_stats); - p->flow->ssn_state.application_protocol = SNORT_PROTO_USER; + p->flow->ssn_state.snort_protocol_id = SNORT_PROTO_USER; StreamFileConfig* c = get_file_cfg(p->flow->ssn_server); FileFlows* file_flows = FileFlows::get_file_flows(p->flow); diff --git a/src/stream/stream.cc b/src/stream/stream.cc index 461f863c6..6e47f2a04 100644 --- a/src/stream/stream.cc +++ b/src/stream/stream.cc @@ -379,28 +379,28 @@ bool Stream::expected_flow(Flow* f, Packet* p) // app proto id foo //------------------------------------------------------------------------- -int Stream::set_application_protocol_id_expected( +int Stream::set_snort_protocol_id_expected( const Packet* ctrlPkt, PktType type, IpProtocol ip_proto, const SfIp* srcIP, uint16_t srcPort, const SfIp* dstIP, uint16_t dstPort, - int16_t appId, FlowData* fd) + SnortProtocolId snort_protocol_id, FlowData* fd) { assert(flow_con); return flow_con->add_expected( - ctrlPkt, type, ip_proto, srcIP, srcPort, dstIP, dstPort, appId, fd); + ctrlPkt, type, ip_proto, srcIP, srcPort, dstIP, dstPort, snort_protocol_id, fd); } -void Stream::set_application_protocol_id( +void Stream::set_snort_protocol_id( Flow* flow, const HostAttributeEntry* host_entry, int /*direction*/) { - int16_t application_protocol; + SnortProtocolId snort_protocol_id; if (!flow || !host_entry) return; /* Cool, its already set! */ - if (flow->ssn_state.application_protocol != 0) + if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) return; if (flow->ssn_state.ipprotocol == 0) @@ -408,7 +408,7 @@ void Stream::set_application_protocol_id( set_ip_protocol(flow); } - application_protocol = getApplicationProtocolId( + snort_protocol_id = get_snort_protocol_id_from_host_table( host_entry, flow->ssn_state.ipprotocol, flow->server_port, SFAT_SERVICE); @@ -416,32 +416,32 @@ void Stream::set_application_protocol_id( // FIXIT-M from client doesn't imply need to swap if (direction == FROM_CLIENT) { - if ( application_protocol && + if ( snort_protocol_id && (flow->ssn_state.session_flags & SSNFLAG_MIDSTREAM) ) flow->ssn_state.session_flags |= SSNFLAG_CLIENT_SWAP; } #endif - if (flow->ssn_state.application_protocol != application_protocol) + if (flow->ssn_state.snort_protocol_id != snort_protocol_id) { - flow->ssn_state.application_protocol = application_protocol; + flow->ssn_state.snort_protocol_id = snort_protocol_id; } } -int16_t Stream::get_application_protocol_id(Flow* flow) +SnortProtocolId Stream::get_snort_protocol_id(Flow* flow) { /* Not caching the source and dest host_entry in the session so we can * swap the table out after processing this packet if we need * to. */ if (!flow) - return 0; + return UNKNOWN_PROTOCOL_ID; - if ( flow->ssn_state.application_protocol == -1 ) - return 0; + if ( flow->ssn_state.snort_protocol_id == INVALID_PROTOCOL_ID ) + return UNKNOWN_PROTOCOL_ID; - if (flow->ssn_state.application_protocol != 0) - return flow->ssn_state.application_protocol; + if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) + return flow->ssn_state.snort_protocol_id; if (flow->ssn_state.ipprotocol == 0) { @@ -450,32 +450,32 @@ int16_t Stream::get_application_protocol_id(Flow* flow) if ( HostAttributeEntry* host_entry = SFAT_LookupHostEntryByIP(&flow->server_ip) ) { - set_application_protocol_id(flow, host_entry, FROM_SERVER); + set_snort_protocol_id(flow, host_entry, FROM_SERVER); - if (flow->ssn_state.application_protocol != 0) - return flow->ssn_state.application_protocol; + if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) + return flow->ssn_state.snort_protocol_id; } if ( HostAttributeEntry* host_entry = SFAT_LookupHostEntryByIP(&flow->client_ip) ) { - set_application_protocol_id(flow, host_entry, FROM_CLIENT); + set_snort_protocol_id(flow, host_entry, FROM_CLIENT); - if (flow->ssn_state.application_protocol != 0) - return flow->ssn_state.application_protocol; + if (flow->ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) + return flow->ssn_state.snort_protocol_id; } - flow->ssn_state.application_protocol = -1; - return 0; + flow->ssn_state.snort_protocol_id = INVALID_PROTOCOL_ID; + return UNKNOWN_PROTOCOL_ID; } -int16_t Stream::set_application_protocol_id(Flow* flow, int16_t id) +SnortProtocolId Stream::set_snort_protocol_id(Flow* flow, SnortProtocolId id) { if (!flow) - return 0; + return UNKNOWN_PROTOCOL_ID; - if (flow->ssn_state.application_protocol != id) + if (flow->ssn_state.snort_protocol_id != id) { - flow->ssn_state.application_protocol = id; + flow->ssn_state.snort_protocol_id = id; } if (!flow->ssn_state.ipprotocol) diff --git a/src/stream/stream.h b/src/stream/stream.h index 7d2a52d18..03ce1dfab 100644 --- a/src/stream/stream.h +++ b/src/stream/stream.h @@ -160,10 +160,10 @@ public: static bool missed_packets(Flow*, uint8_t dir); // Get the protocol identifier from a stream - static int16_t get_application_protocol_id(Flow*); + static SnortProtocolId get_snort_protocol_id(Flow*); // Set the protocol identifier for a stream - static int16_t set_application_protocol_id(Flow*, int16_t appId); + static SnortProtocolId set_snort_protocol_id(Flow*, SnortProtocolId); // initialize response count and expiration time static void init_active_response(const Packet*, Flow*); @@ -173,9 +173,9 @@ public: // Turn off inspection for potential session. Adds session identifiers to a hash table. // TCP only. - static int set_application_protocol_id_expected( + static int set_snort_protocol_id_expected( const Packet* ctrlPkt, PktType, IpProtocol, const snort::SfIp* srcIP, uint16_t srcPort, - const snort::SfIp* dstIP, uint16_t dstPort, int16_t appId, FlowData*); + const snort::SfIp* dstIP, uint16_t dstPort, SnortProtocolId, FlowData*); // Get pointer to application data for a flow based on the lookup tuples for cases where // Snort does not have an active packet that is relevant. @@ -206,7 +206,7 @@ public: static void update_direction(Flow*, char dir, const snort::SfIp* ip, uint16_t port); - static void set_application_protocol_id( + static void set_snort_protocol_id( Flow*, const HostAttributeEntry*, int direction); static bool is_midstream(Flow* flow) diff --git a/src/stream/tcp/tcp_reassembler.cc b/src/stream/tcp/tcp_reassembler.cc index a31de3de1..e04581942 100644 --- a/src/stream/tcp/tcp_reassembler.cc +++ b/src/stream/tcp/tcp_reassembler.cc @@ -631,13 +631,13 @@ int TcpReassembler::_flush_to_seq(uint32_t bytes, Packet* p, uint32_t pkt_flags) else pdu->packet_flags |= ( PKT_REBUILT_STREAM | PKT_STREAM_EST ); - pdu->set_application_protocol(p->get_application_protocol()); + pdu->set_snort_protocol_id(p->get_snort_protocol_id()); show_rebuilt_packet(pdu); tcpStats.rebuilt_packets++; tcpStats.rebuilt_bytes += flushed_bytes; ProfileExclude profile_exclude(s5TcpFlushPerfStats); - snort::Snort::inspect(pdu); + Snort::inspect(pdu); } else { @@ -723,12 +723,12 @@ int TcpReassembler::do_zero_byte_flush(Packet* p, uint32_t pkt_flags) pdu->data = sb.data; pdu->dsize = sb.length; pdu->packet_flags |= ( PKT_REBUILT_STREAM | PKT_STREAM_EST | PKT_PDU_HEAD | PKT_PDU_TAIL ); - pdu->set_application_protocol(p->get_application_protocol()); + pdu->set_snort_protocol_id(p->get_snort_protocol_id()); flush_count++; show_rebuilt_packet(pdu); ProfileExclude profile_exclude(s5TcpFlushPerfStats); - snort::Snort::inspect(pdu); + Snort::inspect(pdu); if ( tracker->splitter ) tracker->splitter->update(); } diff --git a/src/target_based/sftarget_data.h b/src/target_based/sftarget_data.h index d458cf7af..b202d240a 100644 --- a/src/target_based/sftarget_data.h +++ b/src/target_based/sftarget_data.h @@ -23,6 +23,7 @@ #define SFTARGET_DATA_H #include "sfip/sf_cidr.h" +#include "target_based/snort_protocols.h" #define SFAT_OK 0 #define SFAT_ERROR (-1) @@ -46,7 +47,7 @@ struct ApplicationEntry uint16_t port; uint16_t ipproto; - uint16_t protocol; + SnortProtocolId snort_protocol_id; uint8_t fields; }; diff --git a/src/target_based/sftarget_hostentry.cc b/src/target_based/sftarget_hostentry.cc index c007754f5..bf9d3fe98 100644 --- a/src/target_based/sftarget_hostentry.cc +++ b/src/target_based/sftarget_hostentry.cc @@ -103,7 +103,7 @@ bool hasProtocol(const HostAttributeEntry* host_entry, } #endif -int getApplicationProtocolId(const HostAttributeEntry* host_entry, +SnortProtocolId get_snort_protocol_id_from_host_table(const HostAttributeEntry* host_entry, int ipprotocol, uint16_t port, char direction) @@ -121,7 +121,7 @@ int getApplicationProtocolId(const HostAttributeEntry* host_entry, { if ((uint16_t)application->port == port) { - return application->protocol; + return application->snort_protocol_id; } } } diff --git a/src/target_based/sftarget_hostentry.h b/src/target_based/sftarget_hostentry.h index fa3ec9d93..97f628146 100644 --- a/src/target_based/sftarget_hostentry.h +++ b/src/target_based/sftarget_hostentry.h @@ -30,7 +30,7 @@ bool hasProtocol(const HostAttributeEntry*, int ipprotocol, int protocol, int application); #endif -int getApplicationProtocolId( +SnortProtocolId get_snort_protocol_id_from_host_table( const HostAttributeEntry*, int ipprotocol, uint16_t port, char direction); #endif diff --git a/src/target_based/sftarget_reader.cc b/src/target_based/sftarget_reader.cc index 1c480aa9e..38707daf3 100644 --- a/src/target_based/sftarget_reader.cc +++ b/src/target_based/sftarget_reader.cc @@ -204,8 +204,8 @@ static void PrintHostAttributeEntry(HostAttributeEntry* host) for (i=0, app = host->services; app; app = app->next,i++) { DebugFormat(DEBUG_ATTRIBUTE, "\tService #%d:\n", i); - DebugFormat(DEBUG_ATTRIBUTE, "\t\tIPProtocol: %d\tPort: %d\tProtocol %d\n", - app->ipproto, app->port, app->protocol); + DebugFormat(DEBUG_ATTRIBUTE, "\t\tIPProtocol: %d\tPort: %d\tSnortProtocolId %hu\n", + app->ipproto, app->port, app->snort_protocol_id); } if (i==0) DebugMessage(DEBUG_ATTRIBUTE, "\t\tNone\n"); @@ -214,8 +214,8 @@ static void PrintHostAttributeEntry(HostAttributeEntry* host) for (i=0, app = host->clients; app; app = app->next,i++) { DebugFormat(DEBUG_ATTRIBUTE, "\tClient #%d:\n", i); - DebugFormat(DEBUG_ATTRIBUTE, "\t\tIPProtocol: %d\tProtocol %d\n", - app->ipproto, app->protocol); + DebugFormat(DEBUG_ATTRIBUTE, "\t\tIPProtocol: %d\tSnortProtocolId %hu\n", + app->ipproto, app->snort_protocol_id); if (app->fields & APPLICATION_ENTRY_PORT) { @@ -347,7 +347,7 @@ tTargetBasedConfig* SFAT_Swap() return curr_cfg; } -void SFAT_UpdateApplicationProtocol(SfIp* ipAddr, uint16_t port, uint16_t protocol, uint16_t id) +void SFAT_UpdateApplicationProtocol(SfIp* ipAddr, uint16_t port, uint16_t protocol, SnortProtocolId snort_protocol_id) { HostAttributeEntry* host_entry; ApplicationEntry* service; @@ -394,11 +394,11 @@ void SFAT_UpdateApplicationProtocol(SfIp* ipAddr, uint16_t port, uint16_t protoc service->ipproto = protocol; service->next = host_entry->services; host_entry->services = service; - service->protocol = id; + service->snort_protocol_id = snort_protocol_id; } - else if (service->protocol != id) + else if (service->snort_protocol_id != snort_protocol_id) { - service->protocol = id; + service->snort_protocol_id = snort_protocol_id; } } diff --git a/src/target_based/snort_protocols.cc b/src/target_based/snort_protocols.cc index 5c16915ac..41d9a9f48 100644 --- a/src/target_based/snort_protocols.cc +++ b/src/target_based/snort_protocols.cc @@ -38,10 +38,12 @@ using namespace snort; using namespace std; -int16_t ProtocolReference::get_count() -{ return protocol_number; } +SnortProtocolId ProtocolReference::get_count() +{ + return protocol_number; +} -const char* ProtocolReference::get_name(uint16_t id) +const char* ProtocolReference::get_name(SnortProtocolId id) { if ( id >= id_map.size() ) id = 0; @@ -51,18 +53,18 @@ const char* ProtocolReference::get_name(uint16_t id) struct Compare { - bool operator()(uint16_t a, uint16_t b) + bool operator()(SnortProtocolId a, SnortProtocolId b) { return map[a] < map[b]; } vector& map; }; -const char* ProtocolReference::get_name_sorted(uint16_t id) +const char* ProtocolReference::get_name_sorted(SnortProtocolId id) { if ( ind_map.size() < id_map.size() ) { while ( ind_map.size() < id_map.size() ) - ind_map.push_back((uint16_t)ind_map.size()); + ind_map.push_back((SnortProtocolId)ind_map.size()); Compare c { id_map }; sort(ind_map.begin(), ind_map.end(), c); @@ -73,10 +75,10 @@ const char* ProtocolReference::get_name_sorted(uint16_t id) return id_map[ind_map[id]].c_str(); } -int16_t ProtocolReference::add(const char* protocol) +SnortProtocolId ProtocolReference::add(const char* protocol) { if (!protocol) - return SFTARGET_UNKNOWN_PROTOCOL; + return UNKNOWN_PROTOCOL_ID; auto protocol_ref = ref_table.find(protocol); if ( protocol_ref != ref_table.end() ) @@ -87,14 +89,14 @@ int16_t ProtocolReference::add(const char* protocol) return protocol_ref->second; } - int16_t ordinal = protocol_number++; + SnortProtocolId snort_protocol_id = protocol_number++; id_map.push_back(protocol); - ref_table[protocol] = ordinal; + ref_table[protocol] = snort_protocol_id; - return ordinal; + return snort_protocol_id; } -int16_t ProtocolReference::find(const char* protocol) +SnortProtocolId ProtocolReference::find(const char* protocol) { auto protocol_ref = ref_table.find(protocol); if ( protocol_ref != ref_table.end() ) @@ -105,20 +107,41 @@ int16_t ProtocolReference::find(const char* protocol) return protocol_ref->second; } - return SFTARGET_UNKNOWN_PROTOCOL; + return UNKNOWN_PROTOCOL_ID; } -ProtocolReference::ProtocolReference() +void ProtocolReference::init(ProtocolReference* old_proto_ref) { id_map.push_back("unknown"); - bool ok = ( add("ip") == SNORT_PROTO_IP ); - ok = ( add("icmp") == SNORT_PROTO_ICMP ) and ok; - ok = ( add("tcp") == SNORT_PROTO_TCP ) and ok; - ok = ( add("udp") == SNORT_PROTO_UDP ) and ok; - ok = ( add("user") == SNORT_PROTO_USER ) and ok; - ok = ( add("file") == SNORT_PROTO_FILE ) and ok; - assert(ok); + if(!old_proto_ref) + { + bool ok = ( add("ip") == SNORT_PROTO_IP ); + ok = ( add("icmp") == SNORT_PROTO_ICMP ) and ok; + ok = ( add("tcp") == SNORT_PROTO_TCP ) and ok; + ok = ( add("udp") == SNORT_PROTO_UDP ) and ok; + ok = ( add("user") == SNORT_PROTO_USER ) and ok; + ok = ( add("file") == SNORT_PROTO_FILE ) and ok; + assert(ok); + } + else + { + // Copy old ProtocolReference ID/name pairs to new ProtocolReference + for(SnortProtocolId id = 1; id < old_proto_ref->get_count(); id++) + { + add(old_proto_ref->get_name(id)); + } + } +} + +ProtocolReference::ProtocolReference() +{ + init(nullptr); +} + +ProtocolReference::ProtocolReference(ProtocolReference* old_proto_ref) +{ + init(old_proto_ref); } ProtocolReference::~ProtocolReference() diff --git a/src/target_based/snort_protocols.h b/src/target_based/snort_protocols.h index 511cac6d4..ca331aef3 100644 --- a/src/target_based/snort_protocols.h +++ b/src/target_based/snort_protocols.h @@ -28,13 +28,12 @@ #include "main/snort_types.h" -// FIXIT-L use logical type instead of int16_t -// for all reference protocols +using SnortProtocolId = uint16_t; // these protocols are always defined because // they are used as consts in switch statements // other protos are added dynamically as used -enum SnortProtocols +enum SnortProtocols : SnortProtocolId { // The is_*_protocol functions depend on the order of these enums. SNORT_PROTO_IP = 1, @@ -46,36 +45,49 @@ enum SnortProtocols SNORT_PROTO_MAX }; -inline bool is_network_protocol(int16_t proto) +constexpr SnortProtocolId UNKNOWN_PROTOCOL_ID = 0; +constexpr SnortProtocolId INVALID_PROTOCOL_ID = 0xffff; + +inline bool is_network_protocol(SnortProtocolId proto) { return (proto >= SNORT_PROTO_IP and proto <= SNORT_PROTO_UDP); } -inline bool is_builtin_protocol(int16_t proto) +inline bool is_builtin_protocol(SnortProtocolId proto) { return proto < SNORT_PROTO_MAX; } -inline bool is_service_protocol(int16_t proto) +inline bool is_service_protocol(SnortProtocolId proto) { return proto > SNORT_PROTO_UDP; } +// A mapping between names and IDs. class SO_PUBLIC ProtocolReference { public: ProtocolReference(); ~ProtocolReference(); - int16_t get_count(); + ProtocolReference(ProtocolReference* old_proto_ref); + + ProtocolReference(const ProtocolReference&) = delete; + ProtocolReference& operator=(const ProtocolReference&) = delete; + + SnortProtocolId get_count(); - const char* get_name(uint16_t id); - const char* get_name_sorted(uint16_t id); + const char* get_name(SnortProtocolId id); + const char* get_name_sorted(SnortProtocolId id); - int16_t add(const char* protocol); - int16_t find(const char* protocol); + SnortProtocolId add(const char* protocol); + SnortProtocolId find(const char* protocol); - bool operator()(uint16_t a, uint16_t b); + bool operator()(SnortProtocolId a, SnortProtocolId b); private: std::vector id_map; - std::vector ind_map; - std::unordered_map ref_table; - int16_t protocol_number = 1; + std::vector ind_map; + std::unordered_map ref_table; + + // Start at 1 since 0 will be "unknown". + SnortProtocolId protocol_number = 1; + + void init(ProtocolReference* old_proto_ref); }; #endif -- 2.47.3