From: Russ Combs (rucombs) Date: Tue, 24 Oct 2017 22:00:06 +0000 (-0400) Subject: Merge pull request #1048 in SNORT/snort3 from appid_get_inspector_no_mas to master X-Git-Tag: 3.0.0-240~4 X-Git-Url: http://git.ipfire.org/gitweb.cgi?a=commitdiff_plain;h=0c244a80fe6a4034ea5fa0b126ba164bb63378bd;p=thirdparty%2Fsnort3.git Merge pull request #1048 in SNORT/snort3 from appid_get_inspector_no_mas to master Squashed commit of the following: commit 20c0eab95890d1027e4cc1de348616f21ef6547a Author: davis mcpherson Date: Mon Oct 23 21:51:19 2017 -0400 fix appid statistics counts to verify id is valid and also add a count for unknown app id (should not happen) commit b125a3db7994f7ae59790544e2d235f16b862fbb Author: davis mcpherson Date: Sat Oct 21 16:34:27 2017 -0400 refactor Lua app detectors to eliminate need for multipl inheritance commit 7018a0ea007728f8aa0792e39d8f7491090d96b1 Author: davis mcpherson Date: Wed Oct 18 13:19:24 2017 -0400 refactor appid to eliminate need to call get_inspector method refactor appid so that detectors, discovery handlers, etc. have a pointer to the AppId inspector or config instance when they need it refactor unit tests to work with changes to appid inspector handle management use static_cast instead of dynamic_cast to cast lua detector object to its correct type --- diff --git a/src/network_inspectors/appid/app_info_table.cc b/src/network_inspectors/appid/app_info_table.cc index 61d5ee9f2..60aecce95 100644 --- a/src/network_inspectors/appid/app_info_table.cc +++ b/src/network_inspectors/appid/app_info_table.cc @@ -34,7 +34,9 @@ #include "appid_peg_counts.h" #include "log/messages.h" #include "log/unified2.h" +#include "main/snort_config.h" #include "main/snort_debug.h" +#include "target_based/snort_protocols.h" #include "utils/util_cstring.h" static AppInfoTable app_info_table; @@ -499,6 +501,15 @@ void AppInfoManager::load_appid_config(AppIdModuleConfig* config, const char* pa fclose(config_file); } +int16_t AppInfoManager::add_appid_protocol_reference(const char* protocol) +{ + static std::mutex apr_mutex; + + std::lock_guard lock(apr_mutex); + int16_t id = snort_conf->proto_ref->add(protocol); + return id; +} + void AppInfoManager::init_appid_info_table(AppIdModuleConfig* mod_config) { char buf[MAX_TABLE_LINE_LEN]; @@ -565,7 +576,7 @@ void AppInfoManager::init_appid_info_table(AppIdModuleConfig* mod_config) /* snort service key, if it exists */ token = strtok_r(nullptr, CONF_SEPARATORS, &context); if (token) - entry->snortId = AppIdInspector::get_inspector()->add_appid_protocol_reference(token); + entry->snortId = add_appid_protocol_reference(token); if ((app_id = get_static_app_info_entry(entry->appId))) { diff --git a/src/network_inspectors/appid/app_info_table.h b/src/network_inspectors/appid/app_info_table.h index 3d1532ced..604c3f5ff 100644 --- a/src/network_inspectors/appid/app_info_table.h +++ b/src/network_inspectors/appid/app_info_table.h @@ -141,6 +141,7 @@ public: void init_appid_info_table(AppIdModuleConfig*); void cleanup_appid_info_table(); void dump_app_info_table(); + int16_t add_appid_protocol_reference(const char* protocol); private: AppInfoManager() = default; diff --git a/src/network_inspectors/appid/appid_api.cc b/src/network_inspectors/appid/appid_api.cc index 3d7ca1096..10c5d8f58 100644 --- a/src/network_inspectors/appid/appid_api.cc +++ b/src/network_inspectors/appid/appid_api.cc @@ -550,8 +550,10 @@ uint32_t AppIdApi::produce_ha_state(Flow* flow, uint8_t* buf) return sizeof(*appHA); } -uint32_t AppIdApi::consume_ha_state(Flow* flow, const uint8_t* buf, uint8_t, IpProtocol proto, - SfIp* ip, uint16_t port) +// FIXIT-H last param AppIdSession ctor is appid inspector, we need that but no good way to get it +// at the moment...code to allocate session ifdef'ed out until this is resolved... +uint32_t AppIdApi::consume_ha_state(Flow* flow, const uint8_t* buf, uint8_t, IpProtocol /*proto*/, + SfIp* /*ip*/, uint16_t /*port*/) { const AppIdSessionHA* appHA = (const AppIdSessionHA*)buf; if (appHA->flags & APPID_HA_FLAGS_APP) @@ -559,9 +561,10 @@ uint32_t AppIdApi::consume_ha_state(Flow* flow, const uint8_t* buf, uint8_t, IpP AppIdSession* asd = (AppIdSession*)(flow->get_flow_data(AppIdSession::inspector_id)); +#ifdef APPID_HA_SUPPORT_ENABLED if (!asd) { - asd = new AppIdSession(proto, ip, port); + asd = new AppIdSession(proto, ip, port, nullptr); flow->set_flow_data(asd); asd->service.set_id(appHA->appId[1]); if (asd->service.get_id() == APP_ID_FTP_CONTROL) @@ -580,14 +583,23 @@ uint32_t AppIdApi::consume_ha_state(Flow* flow, const uint8_t* buf, uint8_t, IpP if (thirdparty_appid_module) thirdparty_appid_module->session_state_set(asd->tpsession, TP_STATE_HA); } +#else + if ( !asd ) + { + assert(false); + return sizeof(*appHA); + } +#endif if ( ( appHA->flags & APPID_HA_FLAGS_TP_DONE ) && thirdparty_appid_module ) { thirdparty_appid_module->session_state_set(asd->tpsession, TP_STATE_TERMINATED); asd->set_session_flags(APPID_SESSION_NO_TPI); } + if (appHA->flags & APPID_HA_FLAGS_SVC_DONE) asd->set_service_detected(); + if (appHA->flags & APPID_HA_FLAGS_HTTP) asd->set_session_flags(APPID_SESSION_HTTP_SESSION); diff --git a/src/network_inspectors/appid/appid_app_descriptor.h b/src/network_inspectors/appid/appid_app_descriptor.h index 9ad03e744..74bdd5807 100644 --- a/src/network_inspectors/appid/appid_app_descriptor.h +++ b/src/network_inspectors/appid/appid_app_descriptor.h @@ -119,8 +119,11 @@ public: void set_port_service_id(AppId id) { - port_service_id = id; - AppIdPegCounts::inc_service_count(id); + if ( id != port_service_id ) + { + port_service_id = id; + AppIdPegCounts::inc_service_count(id); + } } private: diff --git a/src/network_inspectors/appid/appid_detector.h b/src/network_inspectors/appid/appid_detector.h index 6c4fa3058..3cb05a4d2 100644 --- a/src/network_inspectors/appid/appid_detector.h +++ b/src/network_inspectors/appid/appid_detector.h @@ -30,7 +30,7 @@ #include "flow/flow.h" class AppIdConfig; -class LuaDetector; +class LuaStateDescriptor; struct Packet; #define STATE_ID_MAX_VALID_COUNT 5 @@ -120,48 +120,39 @@ public: virtual void add_app(AppIdSession*, AppId, AppId, const char*); const std::string& get_name() const - { - return name; - } + { return name; } unsigned get_minimum_matches() const - { - return minimum_matches; - } + { return minimum_matches; } void set_minimum_matches(unsigned minimumMatches = 0) - { - minimum_matches = minimumMatches; - } + { minimum_matches = minimumMatches; } unsigned int get_precedence() const - { - return precedence; - } + { return precedence; } unsigned get_flow_data_index() const - { - return flow_data_index; - } + { return flow_data_index; } bool is_custom_detector() const - { - return custom_detector; - } + { return custom_detector; } void set_custom_detector(bool isCustom = false) - { - this->custom_detector = isCustom; - } + { this->custom_detector = isCustom; } AppIdDiscovery& get_handler() const - { - return *handler; - } + { return *handler; } + + bool is_client() const + { return client; } + + virtual LuaStateDescriptor* validate_lua_state(bool /*packet_context*/) + { return nullptr; } protected: AppIdDiscovery* handler = nullptr; std::string name; + bool client = false; bool enabled = true; bool custom_detector = false; IpProtocol proto = IpProtocol::PROTO_NOT_SET; diff --git a/src/network_inspectors/appid/appid_discovery.cc b/src/network_inspectors/appid/appid_discovery.cc index 987af761e..47949f480 100644 --- a/src/network_inspectors/appid/appid_discovery.cc +++ b/src/network_inspectors/appid/appid_discovery.cc @@ -45,7 +45,8 @@ #include "protocols/packet.h" #include "protocols/tcp.h" -AppIdDiscovery::AppIdDiscovery() +AppIdDiscovery::AppIdDiscovery(AppIdInspector& ins) + : inspector(ins) { tcp_patterns = new SearchTool("ac_full", true); udp_patterns = new SearchTool("ac_full", true); @@ -68,10 +69,10 @@ AppIdDiscovery::~AppIdDiscovery() delete kv.second; } -void AppIdDiscovery::initialize_plugins() +void AppIdDiscovery::initialize_plugins(AppIdInspector* ins) { - ServiceDiscovery::get_instance(); - ClientDiscovery::get_instance(); + ServiceDiscovery::get_instance(ins); + ClientDiscovery::get_instance(ins); } void AppIdDiscovery::finalize_plugins() @@ -136,14 +137,14 @@ static inline int match_pe_network(const SfIp* pktAddr, const PortExclusion* pe) && ((pkt[3] & nm[3]) == peIP[3])); } -static inline int check_port_exclusion(const Packet* pkt, bool reversed) +static inline int check_port_exclusion(const Packet* pkt, bool reversed, AppIdInspector& inspector) { AppIdPortExclusions* src_port_exclusions; AppIdPortExclusions* dst_port_exclusions; SF_LIST* pe_list; PortExclusion* pe; const SfIp* s_ip; - AppIdConfig* config = AppIdInspector::get_inspector()->get_appid_config(); + AppIdConfig* config = inspector.get_appid_config(); if ( pkt->is_tcp() ) { @@ -350,7 +351,8 @@ static bool is_packet_ignored(AppIdSession* asd, Packet* p, int& direction) return false; } -static uint64_t is_session_monitored(AppIdSession& asd, const Packet* p, int dir) +static uint64_t is_session_monitored(AppIdSession& asd, const Packet* p, int dir, + AppIdInspector& inspector) { uint64_t flags = 0; uint64_t flow_flags = APPID_SESSION_DISCOVER_APP; @@ -365,7 +367,7 @@ static uint64_t is_session_monitored(AppIdSession& asd, const Packet* p, int dir // accordingly if ( asd.common.policyId != asd.config->appIdPolicyId ) { - if (check_port_exclusion(p, dir == APP_ID_FROM_RESPONDER)) + if ( check_port_exclusion(p, dir == APP_ID_FROM_RESPONDER, inspector) ) { flow_flags |= APPID_SESSION_INITIATOR_SEEN | APPID_SESSION_RESPONDER_SEEN | APPID_SESSION_INITIATOR_CHECKED | APPID_SESSION_RESPONDER_CHECKED; @@ -483,7 +485,7 @@ static uint64_t is_session_monitored(AppIdSession& asd, const Packet* p, int dir return flow_flags; } -static uint64_t is_session_monitored(const Packet* p, int dir) +static uint64_t is_session_monitored(const Packet* p, int dir, AppIdInspector& inspector) { uint64_t flags = 0; uint64_t flow_flags = APPID_SESSION_DISCOVER_APP; @@ -491,7 +493,7 @@ static uint64_t is_session_monitored(const Packet* p, int dir) flow_flags |= (dir == APP_ID_FROM_INITIATOR) ? APPID_SESSION_INITIATOR_SEEN : APPID_SESSION_RESPONDER_SEEN; - if (check_port_exclusion(p, false)) + if ( check_port_exclusion(p, false, inspector) ) { flow_flags |= APPID_SESSION_INITIATOR_SEEN | APPID_SESSION_RESPONDER_SEEN | APPID_SESSION_INITIATOR_CHECKED | APPID_SESSION_RESPONDER_CHECKED; @@ -589,7 +591,7 @@ static void lookup_appid_by_host_port(AppIdSession* asd, Packet* p, IpProtocol p } } -void AppIdDiscovery::do_application_discovery(Packet* p) +void AppIdDiscovery::do_application_discovery(Packet* p, AppIdInspector& inspector) { IpProtocol protocol = IpProtocol::PROTO_NOT_SET; bool isTpAppidDiscoveryDone = false; @@ -607,9 +609,9 @@ void AppIdDiscovery::do_application_discovery(Packet* p) uint64_t flow_flags; if (asd) - flow_flags = is_session_monitored(*asd, p, direction); + flow_flags = is_session_monitored(*asd, p, direction, inspector); else - flow_flags = is_session_monitored(p, direction); + flow_flags = is_session_monitored(p, direction, inspector); if ( !( flow_flags & (APPID_SESSION_DISCOVER_APP | APPID_SESSION_SPECIAL_MONITORED) ) ) { @@ -625,7 +627,7 @@ void AppIdDiscovery::do_application_discovery(Packet* p) port = (direction == APP_ID_FROM_INITIATOR) ? p->ptrs.sp : p->ptrs.dp; } - AppIdSession* tmp_session = new AppIdSession(protocol, ip, port); + AppIdSession* tmp_session = new AppIdSession(protocol, ip, port, inspector); if ((flow_flags & APPID_SESSION_BIDIRECTIONAL_CHECKED) == APPID_SESSION_BIDIRECTIONAL_CHECKED) @@ -633,8 +635,7 @@ void AppIdDiscovery::do_application_discovery(Packet* p) else tmp_session->common.flow_type = APPID_FLOW_TYPE_TMP; tmp_session->common.flags = flow_flags; - tmp_session->common.policyId = - AppIdInspector::get_inspector()->get_appid_config()->appIdPolicyId; + tmp_session->common.policyId = inspector.get_appid_config()->appIdPolicyId; p->flow->set_flow_data(tmp_session); } else @@ -651,7 +652,8 @@ void AppIdDiscovery::do_application_discovery(Packet* p) if ( !asd || asd->common.flow_type == APPID_FLOW_TYPE_TMP ) { - asd = AppIdSession::allocate_session(p, protocol, direction); + asd = AppIdSession::allocate_session(p, protocol, direction, inspector); + if (asd->session_logging_enabled) LogMessage("AppIdDbg %s new session\n", asd->session_logging_id); } diff --git a/src/network_inspectors/appid/appid_discovery.h b/src/network_inspectors/appid/appid_discovery.h index ebf445b35..7f2d2b22c 100644 --- a/src/network_inspectors/appid/appid_discovery.h +++ b/src/network_inspectors/appid/appid_discovery.h @@ -33,6 +33,7 @@ #include "flow/flow.h" #include "utils/util.h" +class AppIdInspector; class AppIdSession; class AppIdDetector; class ServiceDetector; @@ -82,9 +83,9 @@ typedef AppIdDetectors::iterator AppIdDetectorsIterator; class AppIdDiscovery { public: - AppIdDiscovery(); + AppIdDiscovery(AppIdInspector& ins); virtual ~AppIdDiscovery(); - static void initialize_plugins(); + static void initialize_plugins(AppIdInspector* ins); static void finalize_plugins(); static void release_plugins(); @@ -98,7 +99,7 @@ public: int position, unsigned nocase); virtual int add_service_port(AppIdDetector*, const ServiceDetectorPort&); - static void do_application_discovery(Packet* p); + static void do_application_discovery(Packet* p, AppIdInspector&); AppIdDetectors* get_tcp_detectors() { @@ -110,7 +111,11 @@ public: return &udp_detectors; } + AppIdInspector& get_inspector() + { return inspector; } + protected: + AppIdInspector& inspector; AppIdDetectors tcp_detectors; AppIdDetectors udp_detectors; SearchTool* tcp_patterns = nullptr; diff --git a/src/network_inspectors/appid/appid_inspector.cc b/src/network_inspectors/appid/appid_inspector.cc index 362142ef7..67b8e2e6c 100644 --- a/src/network_inspectors/appid/appid_inspector.cc +++ b/src/network_inspectors/appid/appid_inspector.cc @@ -44,14 +44,10 @@ #include "detector_plugins/detector_pattern.h" #include "log/messages.h" #include "log/packet_tracer.h" -#include "main/snort_config.h" #include "managers/inspector_manager.h" #include "managers/module_manager.h" #include "protocols/packet.h" #include "profiler/profiler.h" -#include "target_based/snort_protocols.h" - -static THREAD_LOCAL AppIdStatistics* appid_stats_manager = nullptr; // FIXIT-L - appid cleans up openssl now as it is the primary (only) user... eventually this // should probably be done outside of appid @@ -92,31 +88,11 @@ AppIdInspector::~AppIdInspector() delete config; } -AppIdInspector* AppIdInspector::get_inspector() -{ - return (AppIdInspector*)InspectorManager::get_inspector(MOD_NAME); -} - AppIdConfig* AppIdInspector::get_appid_config() { return active_config; } -AppIdStatistics* AppIdInspector::get_stats_manager() -{ - return appid_stats_manager; -} - -int16_t AppIdInspector::add_appid_protocol_reference(const char* protocol) -{ - static std::mutex apr_mutex; - - apr_mutex.lock(); - int16_t id = snort_conf->proto_ref->add(protocol); - apr_mutex.unlock(); - return id; -} - bool AppIdInspector::configure(SnortConfig*) { assert(!active_config); @@ -158,12 +134,12 @@ void AppIdInspector::show(SnortConfig*) void AppIdInspector::tinit() { - appid_stats_manager = AppIdStatistics::initialize_manager(*config); + AppIdStatistics::initialize_manager(*config); HostPortCache::initialize(); AppIdServiceState::initialize(); init_appid_forecast(); HttpPatternMatchers* http_matchers = HttpPatternMatchers::get_instance(); - AppIdDiscovery::initialize_plugins(); + AppIdDiscovery::initialize_plugins(this); init_length_app_cache(); LuaDetectorManager::initialize(*active_config); PatternServiceDetector::finalize_service_port_patterns(); @@ -177,7 +153,7 @@ void AppIdInspector::tinit() void AppIdInspector::tterm() { - delete appid_stats_manager; + AppIdStatistics::cleanup(); HostPortCache::terminate(); clean_appid_forecast(); service_dns_host_clean(); @@ -197,7 +173,7 @@ void AppIdInspector::eval(Packet* p) AppIdPegCounts::inc_disco_peg(AppIdPegCounts::DiscoveryPegs::PACKETS); if (p->flow) { - AppIdDiscovery::do_application_discovery(p); + AppIdDiscovery::do_application_discovery(p, *this); if (PacketTracer::get_enable()) add_appid_to_packet_trace(p->flow); } @@ -306,13 +282,13 @@ int sslAppGroupIdLookup(void*, const char*, const char*, AppId*, AppId*, AppId*) if (commonName) { ssl_scan_cname((const uint8_t*)commonName, strlen(commonName), client_id, payload_app_id, - &AppIdInspector::get_inspector()->get_appid_config()->serviceSslConfig); + &get_appid_config()->serviceSslConfig); } if (serverName) { ssl_scan_hostname((const uint8_t*)serverName, strlen(serverName), client_id, payload_app_id, - &AppIdInspector::get_inspector()->get_appid_config()->serviceSslConfig); + &get_appid_config()->serviceSslConfig); } if (ssnptr && (asd = appid_api.get_appid_session(ssnptr))) diff --git a/src/network_inspectors/appid/appid_inspector.h b/src/network_inspectors/appid/appid_inspector.h index a098b16c1..97d8d5d7d 100644 --- a/src/network_inspectors/appid/appid_inspector.h +++ b/src/network_inspectors/appid/appid_inspector.h @@ -27,7 +27,6 @@ #include "application_ids.h" #include "flow/flow.h" -class AppIdStatistics; struct Packet; class SipEventHandler; struct SnortConfig; @@ -38,7 +37,6 @@ public: AppIdInspector(AppIdModule&); ~AppIdInspector() override; - static AppIdInspector* get_inspector(); bool configure(SnortConfig*) override; void show(SnortConfig*) override; @@ -46,8 +44,6 @@ public: void tterm() override; void eval(Packet*) override; AppIdConfig* get_appid_config(); - AppIdStatistics* get_stats_manager(); - int16_t add_appid_protocol_reference(const char* protocol); SipEventHandler& get_sip_event_handler() { @@ -57,7 +53,7 @@ public: private: const AppIdModuleConfig* config = nullptr; AppIdConfig* active_config = nullptr; - SipEventHandler* my_seh; + SipEventHandler* my_seh = nullptr; }; diff --git a/src/network_inspectors/appid/appid_peg_counts.cc b/src/network_inspectors/appid/appid_peg_counts.cc index b2da0f659..780098881 100644 --- a/src/network_inspectors/appid/appid_peg_counts.cc +++ b/src/network_inspectors/appid/appid_peg_counts.cc @@ -30,6 +30,7 @@ #include bool AppIdPegCounts::detectors_configured = false; +uint32_t AppIdPegCounts::unknown_app_idx = 0; std::map AppIdPegCounts::appid_detector_pegs_idx; std::vector AppIdPegCounts::appid_detectors_peg_info; std::vector AppIdPegCounts::appid_pegs = @@ -76,7 +77,7 @@ void AppIdPegCounts::add_app_peg_info(AppInfoTableEntry& entry, AppId app_id ) PegCount* AppIdPegCounts::get_peg_counts() { - if ( detectors_configured ) + if ( AppIdPegCounts::detectors_configured ) return appid_peg_counts->data(); else return nullptr; @@ -84,10 +85,17 @@ PegCount* AppIdPegCounts::get_peg_counts() PegInfo* AppIdPegCounts::get_peg_info() { - if ( detectors_configured ) + if ( AppIdPegCounts::detectors_configured ) { - if ( !appid_detectors_peg_info.empty() ) - appid_pegs.insert( appid_pegs.end(), appid_detectors_peg_info.begin(), appid_detectors_peg_info.end()); + std::string app_name = "unknown_app"; + + AppIdPegCounts::unknown_app_idx = appid_detectors_peg_info.size() + NUM_APPID_GLOBAL_PEGS; + init_detector_peg_info(app_name, "_flows", " services detected"); + init_detector_peg_info(app_name, "_clients", " clients detected"); + init_detector_peg_info(app_name, "_users", " users detected"); + init_detector_peg_info(app_name, "_payloads", " payloads detected"); + init_detector_peg_info(app_name, "_misc", " misc detected"); + appid_pegs.insert( appid_pegs.end(), appid_detectors_peg_info.begin(), appid_detectors_peg_info.end()); // add the sentinel entry at the end appid_pegs.push_back({ CountType::END, nullptr, nullptr }); @@ -97,3 +105,62 @@ PegInfo* AppIdPegCounts::get_peg_info() return nullptr; } +void AppIdPegCounts::cleanup_peg_info() +{ + for ( auto& app_info : appid_detectors_peg_info ) + { + snort_free((void*)app_info.name); + snort_free((void*)app_info.help); + } + appid_detectors_peg_info.clear(); +} + +void AppIdPegCounts::inc_disco_peg(enum DiscoveryPegs stat) + { + (*appid_peg_counts)[stat]++; + } + + PegCount AppIdPegCounts::get_disco_peg(enum DiscoveryPegs stat) + { + return (*appid_peg_counts)[stat]; + } + + void AppIdPegCounts::inc_service_count(AppId id) + { + (*appid_peg_counts)[get_stats_index(id) + DetectorPegs::SERVICE_DETECTS]++; + } + + void AppIdPegCounts::inc_client_count(AppId id) + { + (*appid_peg_counts)[get_stats_index(id) + DetectorPegs::CLIENT_DETECTS]++; + } + + void AppIdPegCounts::inc_user_count(AppId id) + { + (*appid_peg_counts)[get_stats_index(id) + DetectorPegs::USER_DETECTS]++; + } + + void AppIdPegCounts::inc_payload_count(AppId id) + { + (*appid_peg_counts)[get_stats_index(id)+ DetectorPegs::PAYLOAD_DETECTS]++; + } + + void AppIdPegCounts::inc_misc_count(AppId id) + { + (*appid_peg_counts)[get_stats_index(id) + DetectorPegs::MISC_DETECTS]++; + } + + void AppIdPegCounts::set_detectors_configured() + { + detectors_configured = true; + } + + uint32_t AppIdPegCounts::get_stats_index(AppId id) + { + std::map::iterator stats_idx_it = appid_detector_pegs_idx.find(id); + if ( stats_idx_it != appid_detector_pegs_idx.end() ) + return stats_idx_it->second; + else + return AppIdPegCounts::unknown_app_idx; + } + diff --git a/src/network_inspectors/appid/appid_peg_counts.h b/src/network_inspectors/appid/appid_peg_counts.h index b76f2eeab..0521a865d 100644 --- a/src/network_inspectors/appid/appid_peg_counts.h +++ b/src/network_inspectors/appid/appid_peg_counts.h @@ -62,70 +62,25 @@ public: NUM_APPID_DETECTOR_PEGS }; - AppIdPegCounts(); - ~AppIdPegCounts(); - static void add_app_peg_info(AppInfoTableEntry&, AppId); static PegCount* get_peg_counts(); static PegInfo* get_peg_info(); static void init_pegs(); static void cleanup_pegs(); - static void cleanup_peg_info() - { - for ( auto& app_info : appid_detectors_peg_info ) - { - snort_free((void*)app_info.name); - snort_free((void*)app_info.help); - } - } - - static void inc_disco_peg(enum DiscoveryPegs stat) - { - (*appid_peg_counts)[stat]++; - } - - static PegCount get_disco_peg(enum DiscoveryPegs stat) - { - return (*appid_peg_counts)[stat]; - } - - static void inc_service_count(AppId id) - { - if ( appid_detector_pegs_idx[id] ) - (*appid_peg_counts)[appid_detector_pegs_idx[id] + DetectorPegs::SERVICE_DETECTS]++; - } - - static void inc_client_count(AppId id) - { - if ( appid_detector_pegs_idx[id] ) - (*appid_peg_counts)[appid_detector_pegs_idx[id] + DetectorPegs::CLIENT_DETECTS]++; - } - - static void inc_user_count(AppId id) - { - if ( appid_detector_pegs_idx[id] ) - (*appid_peg_counts)[appid_detector_pegs_idx[id] + DetectorPegs::USER_DETECTS]++; - } - - static void inc_payload_count(AppId id) - { - if ( appid_detector_pegs_idx[id] ) - (*appid_peg_counts)[appid_detector_pegs_idx[id] + DetectorPegs::PAYLOAD_DETECTS]++; - } - - static void inc_misc_count(AppId id) - { - if ( appid_detector_pegs_idx[id] ) - (*appid_peg_counts)[appid_detector_pegs_idx[id] + DetectorPegs::MISC_DETECTS]++; - } + static void cleanup_peg_info(); - static void set_detectors_configured() - { - detectors_configured = true; - } + static void inc_disco_peg(enum DiscoveryPegs stat); + static PegCount get_disco_peg(enum DiscoveryPegs stat); + static void inc_service_count(AppId id); + static void inc_client_count(AppId id); + static void inc_user_count(AppId id); + static void inc_payload_count(AppId id); + static void inc_misc_count(AppId id); + static void set_detectors_configured(); private: static bool detectors_configured; + static uint32_t unknown_app_idx; static std::map appid_detector_pegs_idx; static std::vector appid_detectors_peg_info; static std::vector appid_pegs; @@ -133,5 +88,6 @@ private: static void init_detector_peg_info(const std::string& app_name, const std::string& name_suffix, const std::string& help_suffix); + static uint32_t get_stats_index(AppId id); }; #endif diff --git a/src/network_inspectors/appid/appid_session.cc b/src/network_inspectors/appid/appid_session.cc index a1727622c..c2a058314 100644 --- a/src/network_inspectors/appid/appid_session.cc +++ b/src/network_inspectors/appid/appid_session.cc @@ -103,7 +103,8 @@ void AppIdSession::set_session_logging_state(const Packet* pkt, int direction) } } -AppIdSession* AppIdSession::allocate_session(const Packet* p, IpProtocol proto, int direction) +AppIdSession* AppIdSession::allocate_session(const Packet* p, IpProtocol proto, int direction, + AppIdInspector& inspector) { uint16_t port = 0; @@ -112,8 +113,7 @@ AppIdSession* AppIdSession::allocate_session(const Packet* p, IpProtocol proto, if ( ( proto == IpProtocol::TCP || proto == IpProtocol::UDP ) && ( p->ptrs.sp != p->ptrs.dp ) ) port = (direction == APP_ID_FROM_INITIATOR) ? p->ptrs.sp : p->ptrs.dp; - AppIdSession* asd = new AppIdSession(proto, ip, port); - + AppIdSession* asd = new AppIdSession(proto, ip, port, inspector); asd->flow = p->flow; asd->stats.first_packet_second = p->pkth->ts.tv_sec; asd->set_session_logging_state(p, direction); @@ -122,15 +122,16 @@ AppIdSession* AppIdSession::allocate_session(const Packet* p, IpProtocol proto, return asd; } -AppIdSession::AppIdSession(IpProtocol proto, const SfIp* ip, uint16_t port) - : FlowData(inspector_id), protocol(proto) +AppIdSession::AppIdSession(IpProtocol proto, const SfIp* ip, uint16_t port, + AppIdInspector& inspector) + : FlowData(inspector_id, &inspector), config(inspector.get_appid_config()), + protocol(proto), inspector(inspector) { service_ip.clear(); session_id = ++appid_flow_data_id; common.flow_type = APPID_FLOW_TYPE_NORMAL; common.initiator_ip = *ip; common.initiator_port = port; - config = AppIdInspector::get_inspector()->get_appid_config(); app_info_mgr = &AppInfoManager::get_instance(); if (thirdparty_appid_module) if (!(tpsession = thirdparty_appid_module->session_create())) @@ -148,7 +149,7 @@ AppIdSession::~AppIdSession() { if ( !in_expected_cache ) { - AppIdStatistics* stats_mgr = AppIdInspector::get_inspector()->get_stats_manager(); + AppIdStatistics* stats_mgr = AppIdStatistics::get_stats_manager(); if ( stats_mgr ) stats_mgr->update(this); @@ -204,7 +205,7 @@ static inline PktType get_pkt_type_from_ip_proto(IpProtocol proto) AppIdSession* AppIdSession::create_future_session(const Packet* ctrlPkt, const SfIp* cliIp, uint16_t cliPort, const SfIp* srvIp, uint16_t srvPort, IpProtocol proto, - int16_t app_id, int /*flags*/) + int16_t app_id, int /*flags*/, AppIdInspector& inspector) { char src_ip[INET6_ADDRSTRLEN]; char dst_ip[INET6_ADDRSTRLEN]; @@ -216,7 +217,7 @@ AppIdSession* AppIdSession::create_future_session(const Packet* ctrlPkt, const S // FIXIT-M - port parameter passed in as 0 since we may not know client port, verify this is // correct - AppIdSession* asd = new AppIdSession(proto, cliIp, 0); + AppIdSession* asd = new AppIdSession(proto, cliIp, 0, inspector); asd->common.policyId = asd->config->appIdPolicyId; // FIXIT-M expect session control packet support not ported to snort3 yet @@ -460,7 +461,7 @@ void AppIdSession::examine_ssl_metadata(Packet* p) { set_client_appid_data(client_id, nullptr); set_payload_app_id_data((AppId)payload_id, nullptr); - setSSLSquelch(p, ret, (ret == 1 ? payload_id : client_id)); + setSSLSquelch(p, ret, (ret == 1 ? payload_id : client_id), inspector); } scan_flags &= ~SCAN_SSL_HOST_FLAG; } @@ -472,7 +473,7 @@ void AppIdSession::examine_ssl_metadata(Packet* p) { set_client_appid_data(client_id, nullptr); set_payload_app_id_data((AppId)payload_id, nullptr); - setSSLSquelch(p, ret, (ret == 1 ? payload_id : client_id)); + setSSLSquelch(p, ret, (ret == 1 ? payload_id : client_id), inspector); } snort_free(tsession->tls_cname); tsession->tls_cname = nullptr; @@ -485,7 +486,7 @@ void AppIdSession::examine_ssl_metadata(Packet* p) { set_client_appid_data(client_id, nullptr); set_payload_app_id_data((AppId)payload_id, nullptr); - setSSLSquelch(p, ret, (ret == 1 ? payload_id : client_id)); + setSSLSquelch(p, ret, (ret == 1 ? payload_id : client_id), inspector); } snort_free(tsession->tls_orgUnit); tsession->tls_orgUnit = nullptr; diff --git a/src/network_inspectors/appid/appid_session.h b/src/network_inspectors/appid/appid_session.h index 69bda6e2a..eae6a1764 100644 --- a/src/network_inspectors/appid/appid_session.h +++ b/src/network_inspectors/appid/appid_session.h @@ -162,16 +162,16 @@ struct TlsSession class AppIdSession : public FlowData { public: - AppIdSession(IpProtocol, const SfIp*, uint16_t port); + AppIdSession(IpProtocol, const SfIp*, uint16_t port, AppIdInspector&); ~AppIdSession() override; - static AppIdSession* allocate_session(const Packet*, IpProtocol, int); + static AppIdSession* allocate_session(const Packet*, IpProtocol, int, AppIdInspector&); static AppIdSession* create_future_session(const Packet*, const SfIp*, uint16_t, const SfIp*, - uint16_t, IpProtocol, int16_t, int); + uint16_t, IpProtocol, int16_t, int, AppIdInspector&); uint32_t session_id = 0; Flow* flow = nullptr; - AppIdConfig* config = nullptr; + AppIdConfig* config; std::map flow_data; AppInfoManager* app_info_mgr = nullptr; CommonAppIdData common; @@ -318,8 +318,7 @@ private: static THREAD_LOCAL uint32_t appid_flow_data_id; AppId application_ids[APP_PROTOID_MAX]; - - + AppIdInspector& inspector; }; #endif diff --git a/src/network_inspectors/appid/appid_stats.cc b/src/network_inspectors/appid/appid_stats.cc index 61f888b9d..17c6ec2bd 100644 --- a/src/network_inspectors/appid/appid_stats.cc +++ b/src/network_inspectors/appid/appid_stats.cc @@ -45,6 +45,8 @@ struct AppIdStatRecord static const char appid_stats_filename[] = "appid_stats.log"; +static THREAD_LOCAL AppIdStatistics* appid_stats_manager = nullptr; + static void delete_record(void* record) { snort_free(record); @@ -220,9 +222,16 @@ AppIdStatistics::~AppIdStatistics() AppIdStatistics* AppIdStatistics::initialize_manager(const AppIdModuleConfig& config) { - return new AppIdStatistics(config); + appid_stats_manager = new AppIdStatistics(config); + return appid_stats_manager; } +AppIdStatistics* AppIdStatistics::get_stats_manager() +{ return appid_stats_manager; } + +void AppIdStatistics::cleanup() +{ delete appid_stats_manager; } + static void update_stats(AppIdSession* asd, AppId app_id, StatsBucket* bucket) { AppIdStatRecord* record = (AppIdStatRecord*)(fwAvlLookup(app_id, bucket->appsTree)); diff --git a/src/network_inspectors/appid/appid_stats.h b/src/network_inspectors/appid/appid_stats.h index 0b895dba0..5f83fa30d 100644 --- a/src/network_inspectors/appid/appid_stats.h +++ b/src/network_inspectors/appid/appid_stats.h @@ -49,6 +49,8 @@ public: ~AppIdStatistics(); static AppIdStatistics* initialize_manager(const AppIdModuleConfig&); + static AppIdStatistics* get_stats_manager(); + static void cleanup(); void update(AppIdSession*); void flush(); diff --git a/src/network_inspectors/appid/client_plugins/client_detector.cc b/src/network_inspectors/appid/client_plugins/client_detector.cc index b47a61c5c..a36adf136 100644 --- a/src/network_inspectors/appid/client_plugins/client_detector.cc +++ b/src/network_inspectors/appid/client_plugins/client_detector.cc @@ -38,6 +38,7 @@ static THREAD_LOCAL unsigned client_module_index = 0; ClientDetector::ClientDetector() { flow_data_index = client_module_index++ | APPID_SESSION_DATA_CLIENT_MODSTATE_BIT; + client = true; } void ClientDetector::register_appid(AppId appId, unsigned extractsInfo) diff --git a/src/network_inspectors/appid/client_plugins/client_discovery.cc b/src/network_inspectors/appid/client_plugins/client_discovery.cc index beae437b0..df3446c80 100644 --- a/src/network_inspectors/appid/client_plugins/client_discovery.cc +++ b/src/network_inspectors/appid/client_plugins/client_discovery.cc @@ -54,7 +54,8 @@ ProfileStats clientMatchPerfStats; THREAD_LOCAL ClientAppMatch* match_free_list = nullptr; -ClientDiscovery::ClientDiscovery() +ClientDiscovery::ClientDiscovery(AppIdInspector& ins) + : AppIdDiscovery(ins) { initialize(); } @@ -69,11 +70,15 @@ ClientDiscovery::~ClientDiscovery() } } -ClientDiscovery& ClientDiscovery::get_instance() +ClientDiscovery& ClientDiscovery::get_instance(AppIdInspector* ins) { static THREAD_LOCAL ClientDiscovery* discovery_manager = nullptr; if (!discovery_manager) - discovery_manager = new ClientDiscovery; + { + assert(ins); + discovery_manager = new ClientDiscovery(*ins); + } + return *discovery_manager; } diff --git a/src/network_inspectors/appid/client_plugins/client_discovery.h b/src/network_inspectors/appid/client_plugins/client_discovery.h index f7053e7c9..23d31cbae 100644 --- a/src/network_inspectors/appid/client_plugins/client_discovery.h +++ b/src/network_inspectors/appid/client_plugins/client_discovery.h @@ -43,13 +43,13 @@ class ClientDiscovery : public AppIdDiscovery { public: ~ClientDiscovery() override; - static ClientDiscovery& get_instance(); + static ClientDiscovery& get_instance(AppIdInspector* ins = nullptr); void finalize_client_plugins(); bool do_client_discovery(AppIdSession&, Packet*, int direction); private: - ClientDiscovery(); + ClientDiscovery(AppIdInspector& ins); void initialize() override; int exec_client_detectors(AppIdSession&, Packet*, int direction); ClientAppMatch* find_detector_candidates(const Packet* pkt, IpProtocol); diff --git a/src/network_inspectors/appid/detector_plugins/detector_sip.cc b/src/network_inspectors/appid/detector_plugins/detector_sip.cc index f24db8fc9..ad17e9130 100644 --- a/src/network_inspectors/appid/detector_plugins/detector_sip.cc +++ b/src/network_inspectors/appid/detector_plugins/detector_sip.cc @@ -154,7 +154,7 @@ SipUdpClientDetector::SipUdpClientDetector(ClientDiscovery* cdm) { APP_ID_SIP, APPINFO_FLAG_CLIENT_ADDITIONAL | APPINFO_FLAG_CLIENT_USER }, }; - AppIdInspector::get_inspector()->get_sip_event_handler().set_client(this); + handler->get_inspector().get_sip_event_handler().set_client(this); handler->register_detector(name, this, proto); } @@ -333,7 +333,7 @@ void SipServiceDetector::createRtpFlow(AppIdSession* asd, const Packet* pkt, con AppIdSession* fp, * fp2; fp = AppIdSession::create_future_session(pkt, cliIp, cliPort, srvIp, srvPort, proto, app_id, - APPID_EARLY_SESSION_FLAG_FW_RULE); + APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if ( fp ) { fp->client.set_id(asd->client.get_id()); @@ -347,7 +347,7 @@ void SipServiceDetector::createRtpFlow(AppIdSession* asd, const Packet* pkt, con // create an RTCP flow as well fp2 = AppIdSession::create_future_session(pkt, cliIp, cliPort + 1, srvIp, srvPort + 1, proto, - app_id, APPID_EARLY_SESSION_FLAG_FW_RULE); + app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if ( fp2 ) { fp2->client.set_id(asd->client.get_id()); @@ -428,11 +428,12 @@ SipServiceDetector::SipServiceDetector(ServiceDiscovery* sd) { SIP_PORT, IpProtocol::TCP, false } }; - AppIdInspector::get_inspector()->get_sip_event_handler().set_service(this); + // FIXIT - detector instance in each packet thread is calling this single sip event handler, + // last guy end wins, works now because it is all the same but this is not right... + handler->get_inspector().get_sip_event_handler().set_service(this); handler->register_detector(name, this, proto); } - int SipServiceDetector::validate(AppIdDiscoveryArgs& args) { ServiceSIPData* ss; @@ -480,7 +481,8 @@ void SipEventHandler::handle(DataEvent& event, Flow* flow) const Packet* p = sip_event.get_packet(); IpProtocol protocol = p->is_tcp() ? IpProtocol::TCP : IpProtocol::UDP; int direction = p->is_from_client() ? APP_ID_FROM_INITIATOR : APP_ID_FROM_RESPONDER; - asd = AppIdSession::allocate_session(p, protocol, direction); + asd = AppIdSession::allocate_session(p, protocol, direction, + client->get_handler().get_inspector()); } client_handler(sip_event, asd); diff --git a/src/network_inspectors/appid/lua_detector_api.cc b/src/network_inspectors/appid/lua_detector_api.cc index 6fb9c2748..31c8c845c 100644 --- a/src/network_inspectors/appid/lua_detector_api.cc +++ b/src/network_inspectors/appid/lua_detector_api.cc @@ -197,19 +197,13 @@ static int common_register_application_id(lua_State* L) { int index = 1; - auto& ud = *UserData::check(L, DETECTOR, index); + auto& ud = *UserData::check(L, DETECTOR, index); AppId appId = lua_tonumber(L, ++index); - if ( ud->package_info.client_detector ) - { - auto& cd = *UserData::check(L, DETECTOR, ++index); - cd->register_appid(appId, APPINFO_FLAG_CLIENT_ADDITIONAL); - } + if ( ud->is_client() ) + ud->register_appid(appId, APPINFO_FLAG_CLIENT_ADDITIONAL); else - { - auto& sd = *UserData::check(L, DETECTOR, ++index); - sd->register_appid(appId, APPINFO_FLAG_SERVICE_ADDITIONAL); - } + ud->register_appid(appId, APPINFO_FLAG_SERVICE_ADDITIONAL); AppInfoManager::get_instance().set_app_info_active(appId); @@ -281,9 +275,9 @@ static int detector_log_message(lua_State* L) // 4 - flags/stack - any flags static int service_analyze_payload(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); - ud->validate_params.asd->payload.set_id(lua_tonumber(L, 2)); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); + lsd->ldp.asd->payload.set_id(lua_tonumber(L, 2)); return 0; } @@ -301,9 +295,10 @@ static int service_analyze_payload(lua_State* L) // @return service_id/stack - service_id if successful, -1 otherwise. static int service_get_service_id(lua_State* L) { - auto ud = *UserData::check(L, DETECTOR, 1); + auto ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(false); - lua_pushnumber(L, ud->service_id); + lua_pushnumber(L, lsd->service_id); return 1; } @@ -418,7 +413,7 @@ static int service_set_validator(lua_State* L) } lua_pop(L, 1); - ud->package_info.validateFunctionName = pValidator; + ud->lsd.package_info.validateFunctionName = pValidator; lua_pushnumber(L, 0); return 1; } @@ -434,9 +429,9 @@ static int service_set_validator(lua_State* L) static int service_add_data_id(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); uint16_t sport = lua_tonumber(L, 2); - ud->validate_params.asd->add_flow_data_id(sport, ud); + lsd->ldp.asd->add_flow_data_id(sport, ud); lua_pushnumber(L, 0); return 1; } @@ -454,15 +449,14 @@ static int service_add_data_id(lua_State* L) static int service_add_service(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt && ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); AppId service_id = lua_tonumber(L, 2); const char* vendor = luaL_optstring(L, 3, nullptr); const char* version = luaL_optstring(L, 4, nullptr); /*Phase2 - discuss AppIdServiceSubtype will be maintained on lua side therefore the last parameter on the following call is nullptr. Subtype is not displayed on DC at present. */ - unsigned int retValue = ud->add_service(ud->validate_params.asd, - ud->validate_params.pkt, ud->validate_params.dir, + unsigned int retValue = ud->add_service(lsd->ldp.asd, lsd->ldp.pkt, lsd->ldp.dir, AppInfoManager::get_instance().get_appid_by_service_id(service_id), vendor, version, nullptr); @@ -480,10 +474,10 @@ static int service_add_service(lua_State* L) static int service_fail_service(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt && ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); ServiceDiscovery& sdm = static_cast(ud->get_handler()); - unsigned int retValue = sdm.fail_service(ud->validate_params.asd, ud->validate_params.pkt, - ud->validate_params.dir, nullptr); + unsigned int retValue = sdm.fail_service(lsd->ldp.asd, lsd->ldp.pkt, + lsd->ldp.dir, nullptr); lua_pushnumber(L, retValue); return 1; @@ -499,10 +493,10 @@ static int service_fail_service(lua_State* L) static int service_in_process_service(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt && ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - unsigned int retValue = ud->service_inprocess(ud->validate_params.asd, - ud->validate_params.pkt, ud->validate_params.dir); + unsigned int retValue = ud->service_inprocess(lsd->ldp.asd, + lsd->ldp.pkt, lsd->ldp.dir); lua_pushnumber(L, retValue); return 1; @@ -518,10 +512,10 @@ static int service_in_process_service(lua_State* L) static int service_set_incompatible_data(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt && ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - unsigned int retValue = ud->incompatible_data(ud->validate_params.asd, - ud->validate_params.pkt, ud->validate_params.dir); + unsigned int retValue = ud->incompatible_data(lsd->ldp.asd, + lsd->ldp.pkt, lsd->ldp.dir); lua_pushnumber(L, retValue); return 1; } @@ -537,9 +531,10 @@ static int service_set_incompatible_data(lua_State* L) */ static int detector_get_packet_size(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - lua_pushnumber(L, ud->validate_params.size); + lua_pushnumber(L, lsd->ldp.size); return 1; } @@ -553,9 +548,10 @@ static int detector_get_packet_size(lua_State* L) */ static int detector_get_packet_direction(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - lua_pushnumber(L, ud->validate_params.dir); + lua_pushnumber(L, lsd->ldp.dir); return 1; } @@ -574,17 +570,18 @@ static int detector_get_pcre_groups(lua_State* L) const char* error; int erroffset; - auto& ud = *UserData::check(L, DETECTOR, 1); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); + const char* pattern = lua_tostring(L, 2); unsigned int offset = lua_tonumber(L, 3); /*offset can be zero, no check necessary. */ /*compile the regular expression pattern, and handle errors */ - pcre* re = pcre_compile(pattern, /*the pattern */ - PCRE_DOTALL, /*default options - dot matches everything including - newline */ - &error, /*for error message */ - &erroffset, /*for error offset */ - nullptr); /*use default character tables */ + pcre* re = pcre_compile(pattern, // the pattern + PCRE_DOTALL, // default options - dot matches all inc \n + &error, // for error message + &erroffset, // for error offset + nullptr); // use default character tables if (re == nullptr) { @@ -593,16 +590,14 @@ static int detector_get_pcre_groups(lua_State* L) } /*pattern match against the subject string. */ - int rc = pcre_exec(re, // compiled pattern - nullptr, // no extra data - (const char*)ud->validate_params.data, // subject string - ud->validate_params.size, // length of the subject - offset, // offset 0 - 0, // default options - ovector, // output vector for substring - // information - OVECCOUNT); // number of elements in the output - // vector + int rc = pcre_exec(re, // compiled pattern + nullptr, // no extra data + (const char*)lsd->ldp.data, // subject string + lsd->ldp.size, // length of the subject + offset, // offset 0 + 0, // default options + ovector, // output vector for substring information + OVECCOUNT); // number of elements in the output vector if ( rc >= 0 ) { @@ -616,7 +611,7 @@ static int detector_get_pcre_groups(lua_State* L) lua_checkstack(L, rc); for (int i = 0; i < rc; i++) { - lua_pushlstring(L, (const char*)ud->validate_params.data + ovector[2*i], ovector[2*i+1] - + lua_pushlstring(L, (const char*)lsd->ldp.data + ovector[2*i], ovector[2*i+1] - ovector[2*i]); } } @@ -645,12 +640,13 @@ static int detector_get_pcre_groups(lua_State* L) */ static int detector_memcmp(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); const char* pattern = lua_tostring(L, 2); unsigned int patternLen = lua_tonumber(L, 3); unsigned int offset = lua_tonumber(L, 4); /*offset can be zero, no check necessary. */ - int rc = memcmp(ud->validate_params.data + offset, pattern, patternLen); + int rc = memcmp(lsd->ldp.data + offset, pattern, patternLen); lua_checkstack (L, 1); lua_pushnumber(L, rc); return 1; @@ -664,10 +660,10 @@ static int detector_memcmp(lua_State* L) */ static int detector_get_protocol_type(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - if ( !ud->validate_params.pkt->has_ip() ) + if ( !lsd->ldp.pkt->has_ip() ) { // FIXIT-M J why the inconsistent use of checkstack? lua_checkstack (L, 1); @@ -677,7 +673,7 @@ static int detector_get_protocol_type(lua_State* L) lua_checkstack (L, 1); // FIXIT-M is this conversion to double valid? - lua_pushnumber(L, (double)ud->validate_params.pkt->get_ip_proto_next() ); + lua_pushnumber(L, (double)lsd->ldp.pkt->get_ip_proto_next() ); return 1; } @@ -690,10 +686,10 @@ static int detector_get_protocol_type(lua_State* L) */ static int detector_get_packet_src_addr(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - const SfIp* ipAddr = ud->validate_params.pkt->ptrs.ip_api.get_src(); + const SfIp* ipAddr = lsd->ldp.pkt->ptrs.ip_api.get_src(); lua_checkstack (L, 1); lua_pushnumber(L, ipAddr->get_ip4_value()); return 1; @@ -708,10 +704,10 @@ static int detector_get_packet_src_addr(lua_State* L) */ static int detector_get_packet_dst_addr(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - const SfIp* ipAddr = ud->validate_params.pkt->ptrs.ip_api.get_dst(); + const SfIp* ipAddr = lsd->ldp.pkt->ptrs.ip_api.get_dst(); lua_checkstack (L, 1); lua_pushnumber(L, ipAddr->get_ip4_value()); return 1; @@ -726,10 +722,10 @@ static int detector_get_packet_dst_addr(lua_State* L) */ static int detector_get_packet_src_port(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - unsigned int port = ud->validate_params.pkt->ptrs.sp; + unsigned int port = lsd->ldp.pkt->ptrs.sp; lua_checkstack (L, 1); lua_pushnumber(L, port); return 1; @@ -744,10 +740,10 @@ static int detector_get_packet_src_port(lua_State* L) */ static int detector_get_packet_dst_port(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); - unsigned int port = ud->validate_params.pkt->ptrs.dp; + unsigned int port = lsd->ldp.pkt->ptrs.dp; lua_checkstack (L, 1); lua_pushnumber(L, port); return 1; @@ -816,7 +812,7 @@ static int client_init(lua_State*) static int service_add_client(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); AppId client_id = lua_tonumber(L, 2); AppId service_id = lua_tonumber(L, 3); @@ -828,7 +824,7 @@ static int service_add_client(lua_State* L) return 1; } - ud->add_app(ud->validate_params.asd, service_id, client_id, version); + ud->add_app(lsd->ldp.asd, service_id, client_id, version); lua_pushnumber(L, 0); return 1; @@ -837,12 +833,12 @@ static int service_add_client(lua_State* L) static int client_add_application(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); unsigned int service_id = lua_tonumber(L, 2); unsigned int productId = lua_tonumber(L, 4); const char* version = lua_tostring(L, 5); - ud->add_app(ud->validate_params.asd, + ud->add_app(lsd->ldp.asd, AppInfoManager::get_instance().get_appid_by_service_id(service_id), AppInfoManager::get_instance().get_appid_by_client_id(productId), version); @@ -853,10 +849,10 @@ static int client_add_application(lua_State* L) static int client_add_info(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); const char* info = lua_tostring(L, 2); - ud->add_info(ud->validate_params.asd, info); + ud->add_info(lsd->ldp.asd, info); lua_pushnumber(L, 0); return 1; @@ -865,12 +861,12 @@ static int client_add_info(lua_State* L) static int client_add_user(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); const char* userName = lua_tostring(L, 2); unsigned int service_id = lua_tonumber(L, 3); - ud->add_user(ud->validate_params.asd, userName, + ud->add_user(lsd->ldp.asd, userName, AppInfoManager::get_instance().get_appid_by_service_id(service_id), true); lua_pushnumber(L, 0); @@ -880,10 +876,10 @@ static int client_add_user(lua_State* L) static int client_add_payload(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); unsigned int payloadId = lua_tonumber(L, 2); - ud->add_payload(ud->validate_params.asd, + ud->add_payload(lsd->ldp.asd, AppInfoManager::get_instance().get_appid_by_payload_id(payloadId)); lua_pushnumber(L, 0); @@ -902,11 +898,11 @@ static int client_add_payload(lua_State* L) */ static int detector_get_flow(lua_State* L) { - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + auto& ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); auto df = new DetectorFlow(); - df->asd = ud->validate_params.asd; + df->asd = lsd->ldp.asd; UserData::push(L, DETECTORFLOW, df); df->myLuaState = L; @@ -920,9 +916,9 @@ static int detector_get_flow(lua_State* L) static int detector_add_http_pattern(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); enum httpPatternType pat_type = (enum httpPatternType)lua_tointeger(L, ++index); if (pat_type < HTTP_PAYLOAD || pat_type > HTTP_URL) @@ -960,9 +956,9 @@ static int detector_add_http_pattern(lua_State* L) static int detector_add_ssl_cert_pattern(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint8_t type = lua_tointeger(L, ++index); AppId app_id = (AppId)lua_tointeger(L, ++index); @@ -990,9 +986,9 @@ static int detector_add_ssl_cert_pattern(lua_State* L) static int detector_add_dns_host_pattern(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint8_t type = lua_tointeger(L, ++index); AppId app_id = (AppId)lua_tointeger(L, ++index); @@ -1018,9 +1014,9 @@ static int detector_add_dns_host_pattern(lua_State* L) static int detector_add_ssl_cname_pattern(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint8_t type = lua_tointeger(L, ++index); AppId app_id = (AppId)lua_tointeger(L, ++index); @@ -1052,11 +1048,11 @@ static int detector_add_ssl_cname_pattern(lua_State* L) static int detector_add_host_port_application(lua_State* L) { - int index = 1; SfIp ip_addr; - + int index = 1; // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint8_t type = lua_tointeger(L, ++index); AppId app_id = (AppId)lua_tointeger(L, ++index); @@ -1084,11 +1080,11 @@ static int detector_add_host_port_application(lua_State* L) static int detector_add_content_type_pattern(lua_State* L) { - int index = 1; size_t stringSize = 0; - + int index = 1; // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); const char* tmp_string = lua_tolstring(L, ++index, &stringSize); if (!tmp_string || !stringSize) @@ -1109,20 +1105,6 @@ static int detector_add_content_type_pattern(lua_State* L) return 0; } -static inline int get_detector_user_data(lua_State* L, int index, - UserData** detector_user_data, const char* errorString) -{ - // Verify detector user data and that we are not in packet context - *detector_user_data = UserData::check(L, DETECTOR, index); - if (!*detector_user_data || (**detector_user_data)->validate_params.pkt) - { - ErrorMessage("%s", errorString); - return -1; - } - - return 0; -} - static int create_chp_application(AppId appIdInstance, unsigned app_type_flags, int num_matches) { CHPApp* new_app = (CHPApp*)snort_calloc(sizeof(CHPApp)); @@ -1142,18 +1124,16 @@ static int create_chp_application(AppId appIdInstance, unsigned app_type_flags, static int detector_chp_create_application(lua_State* L) { - UserData* ud; int index = 1; - - if (get_detector_user_data(L, index, &ud, - "LuaDetectorApi:Invalid HTTP detector user data in CHPCreateApp.")) - return 0; + // Verify detector user data and that we are not in packet context + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); AppId appId = lua_tointeger(L, ++index); AppId appIdInstance = CHP_APPID_SINGLE_INSTANCE(appId); // Last instance for the old API - unsigned app_type_flags = lua_tointeger(L, ++index); - int num_matches = lua_tointeger(L, ++index); + unsigned app_type_flags = lua_tointeger(L, ++index); + int num_matches = lua_tointeger(L, ++index); // We only want one of these for each appId. if (sfxhash_find(CHP_glossary, &appIdInstance)) @@ -1323,17 +1303,16 @@ static int add_chp_pattern_action(AppId appIdInstance, int isKeyPattern, Pattern static int detector_add_chp_action(lua_State* L) { - UserData* ud; PatternType ptype; size_t psize; char* pattern; ActionType action; char* action_data; int index = 1; + // Verify detector user data and that we are not in packet context + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); - if (get_detector_user_data(L, index, &ud, - "LuaDetectorApi:Invalid HTTP detector user data in CHPAddAction.")) - return 0; // Parameter 1 AppId appId = lua_tointeger(L, ++index); @@ -1370,14 +1349,12 @@ static int detector_add_chp_action(lua_State* L) static int detector_create_chp_multi_application(lua_State* L) { - UserData* ud; AppId appIdInstance = APP_ID_UNKNOWN; int instance; int index = 1; - - if (get_detector_user_data(L, index, &ud, - "LuaDetectorApi:Invalid HTTP detector user data in CHPMultiCreateApp.")) - return 0; + // Verify detector user data and that we are not in packet context + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); AppId appId = lua_tointeger(L, ++index); unsigned app_type_flags = lua_tointeger(L, ++index); @@ -1408,17 +1385,15 @@ static int detector_create_chp_multi_application(lua_State* L) static int detector_add_chp_multi_action(lua_State* L) { - UserData* ud; PatternType ptype; size_t psize; char* pattern; ActionType action; char* action_data; int index = 1; - - if (get_detector_user_data(L, index, &ud, - "LuaDetectorApi:Invalid HTTP detector user data in CHPMultiAddAction.")) - return 0; + // Verify detector user data and that we are not in packet context + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); // Parameter 1 AppId appIdInstance = lua_tointeger(L, ++index); @@ -1455,15 +1430,15 @@ static int detector_add_chp_multi_action(lua_State* L) static int detector_port_only_service(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); AppId appId = lua_tointeger(L, ++index); uint16_t port = lua_tointeger(L, ++index); uint8_t protocol = lua_tointeger(L, ++index); - AppIdConfig* config = AppIdInspector::get_inspector()->get_appid_config(); + AppIdConfig* config = ud->get_handler().get_inspector().get_appid_config(); if (port == 0) config->ip_protocol[protocol] = appId; else if (protocol == 6) @@ -1477,8 +1452,7 @@ static int detector_port_only_service(lua_State* L) } /* Add a length-based detector. This is done by adding a new length sequence - * to the cache. Note that this does not require a validate and is only used - * as a fallback identification. + * to the cache. * * @param lua_State* - Lua state variable. * @param appId/stack - App ID to use for this detector. @@ -1501,7 +1475,7 @@ static int detector_add_length_app_cache(lua_State* L) LengthKey length_sequence; int index = 1; - UserData::check(L, DETECTOR, index); + UserData::check(L, DETECTOR, index); AppId appId = lua_tonumber(L, ++index); IpProtocol proto = (IpProtocol)lua_tonumber(L, ++index); @@ -1594,9 +1568,9 @@ static int detector_add_length_app_cache(lua_State* L) static int detector_add_af_application(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); AppId indicator = (AppId)lua_tointeger(L, ++index); AppId forecast = (AppId)lua_tointeger(L, ++index); @@ -1609,9 +1583,9 @@ static int detector_add_af_application(lua_State* L) static int detector_add_url_application(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint32_t service_id = lua_tointeger(L, ++index); uint32_t client_app = lua_tointeger(L, ++index); @@ -1694,9 +1668,9 @@ static int detector_add_url_application(lua_State* L) static int detector_add_rtmp_url(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint32_t service_id = lua_tointeger(L, ++index); uint32_t client_app = lua_tointeger(L, ++index); @@ -1779,9 +1753,9 @@ static int detector_add_rtmp_url(lua_State* L) static int detector_add_sip_user_agent(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint32_t client_app = lua_tointeger(L, ++index); const char* clientVersion = lua_tostring(L, ++index); @@ -1809,9 +1783,9 @@ static int detector_add_sip_user_agent(lua_State* L) static int create_custom_application(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); /* Verify that host pattern is a valid string */ size_t appNameLen = 0; @@ -1834,11 +1808,11 @@ static int create_custom_application(lua_State* L) static int add_client_application(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); unsigned int service_id = lua_tonumber(L, 2); unsigned int client_id = lua_tonumber(L, 3); - ud->add_app(ud->validate_params.asd, service_id, client_id, ""); + ud->add_app(lsd->ldp.asd, service_id, client_id, ""); lua_pushnumber(L, 0); return 1; } @@ -1856,14 +1830,14 @@ static int add_client_application(lua_State* L) static int add_service_application(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt && ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); unsigned service_id = lua_tonumber(L, 2); /*Phase2 - discuss AppIdServiceSubtype will be maintained on lua side therefore the last parameter on the following call is nullptr. Subtype is not displayed on DC at present. */ - unsigned retValue = ud->add_service(ud->validate_params.asd, ud->validate_params.pkt, - ud->validate_params.dir, service_id); + unsigned retValue = ud->add_service(lsd->ldp.asd, lsd->ldp.pkt, + lsd->ldp.dir, service_id); lua_pushnumber(L, retValue); return 1; @@ -1872,10 +1846,10 @@ static int add_service_application(lua_State* L) static int add_payload_application(lua_State* L) { auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.asd); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); unsigned payload_id = lua_tonumber(L, 2); - ud->add_payload(ud->validate_params.asd, payload_id); + ud->add_payload(lsd->ldp.asd, payload_id); lua_pushnumber(L, 0); return 1; @@ -1884,9 +1858,9 @@ static int add_payload_application(lua_State* L) static int add_http_pattern(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); /* Verify valid pattern type */ enum httpPatternType pat_type = (enum httpPatternType)lua_tointeger(L, ++index); @@ -1921,9 +1895,9 @@ static int add_http_pattern(lua_State* L) static int add_url_pattern(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint32_t service_id = lua_tointeger(L, ++index); uint32_t clientAppId = lua_tointeger(L, ++index); @@ -2005,9 +1979,11 @@ static int add_url_pattern(lua_State* L) */ static int add_port_pattern_client(lua_State* L) { - int index = 1; size_t patternSize = 0; + int index = 1; + // Verify detector user data and that we are not in packet context auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); IpProtocol protocol = (IpProtocol)lua_tonumber(L, ++index); uint16_t port = 0; //port = lua_tonumber(L, ++index); FIXIT-L - why commented out? @@ -2052,9 +2028,11 @@ static int add_port_pattern_client(lua_State* L) */ static int add_port_pattern_service(lua_State* L) { - int index = 1; size_t patternSize = 0; + int index = 1; + // Verify detector user data and that we are not in packet context auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); IpProtocol protocol = (IpProtocol)lua_tonumber(L, ++index); uint16_t port = lua_tonumber(L, ++index); @@ -2081,9 +2059,9 @@ static int add_port_pattern_service(lua_State* L) static int detector_add_sip_server(lua_State* L) { int index = 1; - // Verify detector user data and that we are not in packet context - assert(!(*UserData::check(L, DETECTOR, index))->validate_params.pkt); + auto& ud = *UserData::check(L, DETECTOR, index); + ud->validate_lua_state(false); uint32_t client_app = lua_tointeger(L, ++index); const char* clientVersion = lua_tostring(L, ++index); @@ -2137,8 +2115,8 @@ static int create_future_flow(lua_State* L) SfIp client_addr; SfIp server_addr; int16_t snort_app_id = 0; - auto& ud = *UserData::check(L, DETECTOR, 1); - assert(ud->validate_params.pkt); + AppIdDetector* ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); const char* pattern = lua_tostring(L, 2); if (!convert_string_to_address(pattern, &client_addr)) @@ -2165,9 +2143,9 @@ static int create_future_flow(lua_State* L) snort_app_id = entry->snortId; } - AppIdSession* fp = AppIdSession::create_future_session(ud->validate_params.pkt, &client_addr, + AppIdSession* fp = AppIdSession::create_future_session(lsd->ldp.pkt, &client_addr, client_port, &server_addr, server_port, proto, snort_app_id, - APPID_EARLY_SESSION_FLAG_FW_RULE); + APPID_EARLY_SESSION_FLAG_FW_RULE, ud->get_handler().get_inspector()); if (fp) { fp->service.set_id(service_id); @@ -2318,7 +2296,7 @@ static int Detector_gc(lua_State*) /*convert detector to string for printing */ static int Detector_tostring(lua_State* L) { - lua_pushfstring(L, "Detector (%p)", UserData::check(L, DETECTOR, 1)); + lua_pushfstring(L, "Detector (%p)", UserData::check(L, DETECTOR, 1)); return 1; } @@ -2359,29 +2337,29 @@ int register_detector(lua_State* L) return 1; /* return methods on the stack */ } -LuaDetector::~LuaDetector() +LuaStateDescriptor::~LuaStateDescriptor() { // release the reference of the userdata on the lua side if ( detector_user_data_ref != LUA_REFNIL ) luaL_unref(my_lua_state, LUA_REGISTRYINDEX, detector_user_data_ref); } -int LuaDetector::lua_validate(AppIdDiscoveryArgs& args) +int LuaStateDescriptor::lua_validate(AppIdDiscoveryArgs& args) { Profile lua_detector_context(luaCustomPerfStats); - validate_params.data = args.data; - validate_params.size = args.size; - validate_params.dir = args.dir; - validate_params.asd = args.asd; - validate_params.pkt = args.pkt; + ldp.data = args.data; + ldp.size = args.size; + ldp.dir = args.dir; + ldp.asd = args.asd; + ldp.pkt = args.pkt; const char* validateFn = package_info.validateFunctionName.c_str(); if ( (!validateFn) || !lua_checkstack(my_lua_state, 1) ) { ErrorMessage("lua detector %s: invalid LUA %s\n", package_info.name.c_str(), lua_tostring(my_lua_state, -1)); - validate_params.pkt = nullptr; + ldp.pkt = nullptr; return APPID_ENULL; } @@ -2396,7 +2374,7 @@ int LuaDetector::lua_validate(AppIdDiscoveryArgs& args) // that don't impact processing by other detectors or future packets by the same detector. ErrorMessage("lua detector %s: error validating %s\n", package_info.name.c_str(), lua_tostring(my_lua_state, -1)); - validate_params.pkt = nullptr; + ldp.pkt = nullptr; return APPID_ENULL; } @@ -2407,24 +2385,92 @@ int LuaDetector::lua_validate(AppIdDiscoveryArgs& args) if ( !lua_isnumber(my_lua_state, -1) ) { ErrorMessage("lua detector %s: returned non-numeric value\n", package_info.name.c_str()); - validate_params.pkt = nullptr; + ldp.pkt = nullptr; return APPID_ENULL; } int rc = lua_tonumber(my_lua_state, -1); lua_pop(my_lua_state, 1); DebugFormat(DEBUG_APPID, "lua detector %s: status: %d\n", package_info.name.c_str(), rc); - validate_params.pkt = nullptr; + ldp.pkt = nullptr; return rc; } +static inline void init_lsd(LuaStateDescriptor* lsd, const std::string& detector_name, lua_State* L) +{ + lsd->service_id = APP_ID_UNKNOWN; + get_lua_field(L, -1, "init", lsd->package_info.initFunctionName); + get_lua_field(L, -1, "clean", lsd->package_info.cleanFunctionName); + get_lua_field(L, -1, "validate", lsd->package_info.validateFunctionName); + get_lua_field(L, -1, "minimum_matches", lsd->package_info.minimum_matches); + lsd->package_info.name = detector_name; + lua_pop(L, 1); // pop client table + lua_pop(L, 1); // pop DetectorPackageInfo table + lsd->my_lua_state = L; +} + +static inline bool lua_params_validator(LuaDetectorParameters& ldp, bool packet_context) +{ + if ( packet_context ) + { + assert(ldp.asd); + assert(ldp.pkt); + } + else + { + assert(!ldp.pkt); + } + + return true; +} + +LuaServiceDetector::LuaServiceDetector(AppIdDiscovery* sdm, const std::string& detector_name, + IpProtocol protocol, lua_State* L) +{ + handler = sdm; + name = detector_name; + proto = protocol; + handler->register_detector(name, this, proto); + init_lsd(&lsd, detector_name, L); + UserData::push(L, DETECTOR, this); + // add a lua reference so the detector doesn't get garbage-collected + lua_pushvalue(L, -1); + lsd.detector_user_data_ref = luaL_ref(L, LUA_REGISTRYINDEX); +} + +LuaStateDescriptor* LuaServiceDetector::validate_lua_state(bool packet_context) +{ + lua_params_validator(lsd.ldp, packet_context); + return &lsd; +} + int LuaServiceDetector::validate(AppIdDiscoveryArgs& args) { - return lua_validate(args); + return lsd.lua_validate(args); +} + +LuaClientDetector::LuaClientDetector(AppIdDiscovery* cdm, const std::string& detector_name, + IpProtocol protocol, lua_State* L) + { + handler = cdm; + name = detector_name; + proto = protocol; + handler->register_detector(name, this, proto); + init_lsd(&lsd, detector_name, L); + UserData::push(L, DETECTOR, this); + // add a lua reference so the detector doesn't get garbage-collected + lua_pushvalue(L, -1); + lsd.detector_user_data_ref = luaL_ref(L, LUA_REGISTRYINDEX); + } + +LuaStateDescriptor* LuaClientDetector::validate_lua_state(bool packet_context) +{ + lua_params_validator(lsd.ldp, packet_context); + return &lsd; } int LuaClientDetector::validate(AppIdDiscoveryArgs& args) { - return lua_validate(args); + return lsd.lua_validate(args); } diff --git a/src/network_inspectors/appid/lua_detector_api.h b/src/network_inspectors/appid/lua_detector_api.h index 5a802a7ec..32bbcd651 100644 --- a/src/network_inspectors/appid/lua_detector_api.h +++ b/src/network_inspectors/appid/lua_detector_api.h @@ -39,7 +39,6 @@ class AppIdSession; struct DetectorPackageInfo { - bool client_detector = false; std::string initFunctionName; std::string cleanFunctionName; std::string validateFunctionName; @@ -48,7 +47,7 @@ struct DetectorPackageInfo IpProtocol proto; }; -struct ValidateParameters +struct LuaDetectorParameters { const uint8_t* data = nullptr; uint16_t size = 0; @@ -58,50 +57,41 @@ struct ValidateParameters uint8_t macAddress[6] = { 0 }; }; -class LuaDetector +class LuaStateDescriptor { public: - LuaDetector() = default; - virtual ~LuaDetector(); + LuaStateDescriptor() = default; + virtual ~LuaStateDescriptor(); - ValidateParameters validate_params; + LuaDetectorParameters ldp; lua_State* my_lua_state= nullptr; int detector_user_data_ref = 0; // key into LUA_REGISTRYINDEX DetectorPackageInfo package_info; - bool is_client = false; unsigned int service_id = APP_ID_UNKNOWN; int lua_validate(AppIdDiscoveryArgs&); }; -class LuaServiceDetector : public LuaDetector, public ServiceDetector +class LuaServiceDetector : public ServiceDetector { public: - LuaServiceDetector(AppIdDiscovery* sdm, const std::string& detector_name, IpProtocol protocol) - { - handler = sdm; - name = detector_name; - proto = protocol; - handler->register_detector(name, this, proto); - } - - + LuaServiceDetector(AppIdDiscovery* sdm, const std::string& detector_name, IpProtocol protocol, + lua_State* L); int validate(AppIdDiscoveryArgs&) override; + LuaStateDescriptor* validate_lua_state(bool packet_context) override; + + LuaStateDescriptor lsd; }; -class LuaClientDetector : public LuaDetector, public ClientDetector +class LuaClientDetector : public ClientDetector { public: - LuaClientDetector(AppIdDiscovery* cdm, const std::string& detector_name, IpProtocol protocol) - { - handler = cdm; - name = detector_name; - proto = protocol; - handler->register_detector(name, this, proto); - } - - + LuaClientDetector(AppIdDiscovery* cdm, const std::string& detector_name, IpProtocol protocol, + lua_State* L); int validate(AppIdDiscoveryArgs&) override; + LuaStateDescriptor* validate_lua_state(bool packet_context) override; + + LuaStateDescriptor lsd; }; int register_detector(lua_State*); diff --git a/src/network_inspectors/appid/lua_detector_flow_api.cc b/src/network_inspectors/appid/lua_detector_flow_api.cc index bd12b4875..e549370f5 100644 --- a/src/network_inspectors/appid/lua_detector_flow_api.cc +++ b/src/network_inspectors/appid/lua_detector_flow_api.cc @@ -26,6 +26,7 @@ #include "lua_detector_flow_api.h" #include "appid_api.h" +#include "appid_inspector.h" #include "lua_detector_api.h" #include "lua_detector_module.h" #include "lua_detector_util.h" @@ -157,8 +158,8 @@ static int create_detector_flow(lua_State* L) SfIp saddr; SfIp daddr; - auto& detector_data = *UserData::check(L, DETECTOR, 1); - assert(detector_data->validate_params.pkt); + AppIdDetector* ud = *UserData::check(L, DETECTOR, 1); + LuaStateDescriptor* lsd = ud->validate_lua_state(true); const char* pattern = lua_tostring(L, 2); size_t patternLen = lua_strlen (L, 2); @@ -208,8 +209,8 @@ static int create_detector_flow(lua_State* L) LuaDetectorManager::add_detector_flow(detector_flow); - detector_flow->asd = AppIdSession::create_future_session(detector_data->validate_params.pkt, - &saddr, sport, &daddr, dport, proto, 0, 0); + detector_flow->asd = AppIdSession::create_future_session(lsd->ldp.pkt, &saddr, sport, + &daddr, dport, proto, 0, 0, ud->get_handler().get_inspector()); if (!detector_flow->asd) { diff --git a/src/network_inspectors/appid/lua_detector_module.cc b/src/network_inspectors/appid/lua_detector_module.cc index 203abe824..c1ac150fa 100644 --- a/src/network_inspectors/appid/lua_detector_module.cc +++ b/src/network_inspectors/appid/lua_detector_module.cc @@ -27,8 +27,6 @@ #include #include -#include -#include "lua/lua.h" #include "appid_config.h" #include "lua_detector_util.h" @@ -48,7 +46,7 @@ static THREAD_LOCAL LuaDetectorManager* lua_detector_mgr; static THREAD_LOCAL SF_LIST allocated_detector_flow_list; -static inline bool get_lua_field(lua_State* L, int table, const char* field, std::string& out) +bool get_lua_field(lua_State* L, int table, const char* field, std::string& out) { lua_getfield(L, table, field); bool result = lua_isstring(L, -1); @@ -59,7 +57,7 @@ static inline bool get_lua_field(lua_State* L, int table, const char* field, std return result; } -static inline bool get_lua_field(lua_State* L, int table, const char* field, int& out) +bool get_lua_field(lua_State* L, int table, const char* field, int& out) { lua_getfield(L, table, field); bool result = lua_isnumber(L, -1); @@ -72,7 +70,7 @@ static inline bool get_lua_field(lua_State* L, int table, const char* field, int return result; } -static inline bool get_lua_field(lua_State* L, int table, const char* field, IpProtocol& out) +bool get_lua_field(lua_State* L, int table, const char* field, IpProtocol& out) { lua_getfield(L, table, field); bool result = lua_isnumber(L, -1); @@ -147,17 +145,18 @@ LuaDetectorManager::~LuaDetectorManager() { for ( auto& detector : allocated_detectors ) { - auto L = detector->my_lua_state; + LuaStateDescriptor* lsd = detector->validate_lua_state(false); + auto L = lsd->my_lua_state; - lua_getglobal(L, detector->package_info.cleanFunctionName.c_str()); + lua_getglobal(L, lsd->package_info.cleanFunctionName.c_str()); if ( lua_isfunction(L, -1) ) { /*first parameter is DetectorUserData */ - lua_rawgeti(L, LUA_REGISTRYINDEX, detector->detector_user_data_ref); + lua_rawgeti(L, LUA_REGISTRYINDEX, lsd->detector_user_data_ref); if ( lua_pcall(L, 1, 1, 0) ) { ErrorMessage("Could not cleanup the %s client app element: %s\n", - detector->package_info.name.c_str(), lua_tostring(L, -1)); + lsd->package_info.name.c_str(), lua_tostring(L, -1)); } } } @@ -260,9 +259,8 @@ static inline uint32_t compute_lua_tracker_size(uint64_t rnaMemory, uint32_t num // FIXIT-M lifetime of detector is easy to misuse with this idiom // Leaves 1 value (the Detector userdata) at the top of the stack -static LuaDetector* create_lua_detector(lua_State* L, const char* detectorName, bool is_custom) +static AppIdDetector* create_lua_detector(lua_State* L, const char* detectorName, bool is_custom) { - LuaDetector* detector = nullptr; std::string detector_name; IpProtocol proto = IpProtocol::PROTO_NOT_SET; @@ -282,18 +280,9 @@ static LuaDetector* create_lua_detector(lua_State* L, const char* detectorName, if ( lua_istable(L, -1) ) { LuaClientDetector* cd = new LuaClientDetector(&ClientDiscovery::get_instance(), - detectorName, proto); - cd->is_client = true; + detectorName, proto, L); cd->set_custom_detector(is_custom); - cd->set_minimum_matches(cd->package_info.minimum_matches); - cd->package_info.client_detector = true; - get_lua_field(L, -1, "init", cd->package_info.initFunctionName); - get_lua_field(L, -1, "clean", cd->package_info.cleanFunctionName); - get_lua_field(L, -1, "validate", cd->package_info.validateFunctionName); - get_lua_field(L, -1, "minimum_matches", cd->package_info.minimum_matches); - cd->package_info.name = detector_name; - detector = cd; - lua_pop(L, 1); // pop client table + return cd; } else { @@ -303,17 +292,9 @@ static LuaDetector* create_lua_detector(lua_State* L, const char* detectorName, if ( lua_istable(L, -1) ) { LuaServiceDetector* sd = new LuaServiceDetector(&ServiceDiscovery::get_instance(), - detectorName, proto); - sd->is_client = false; + detectorName, proto, L); sd->set_custom_detector(is_custom); - sd->service_id = APP_ID_UNKNOWN; - sd->package_info.client_detector = false; - get_lua_field(L, -1, "init", sd->package_info.initFunctionName); - get_lua_field(L, -1, "clean", sd->package_info.cleanFunctionName); - get_lua_field(L, -1, "validate", sd->package_info.validateFunctionName); - get_lua_field(L, -1, "minimum_matches", sd->package_info.minimum_matches); - sd->package_info.name = detector_name; - detector = sd; + return sd; } lua_pop(L, 1); // pop server table @@ -321,17 +302,7 @@ static LuaDetector* create_lua_detector(lua_State* L, const char* detectorName, lua_pop(L, 1); // pop DetectorPackageInfo table - if ( detector ) - { - detector->my_lua_state = L; - UserData::push(L, DETECTOR, detector); - - // add a lua reference so the detector doesn't get garbage-collected - lua_pushvalue(L, -1); - detector->detector_user_data_ref = luaL_ref(L, LUA_REGISTRYINDEX); - } - - return detector; + return nullptr; } void LuaDetectorManager::load_detector(char* detector_filename, bool isCustom) @@ -355,7 +326,7 @@ void LuaDetectorManager::load_detector(char* detector_filename, bool isCustom) snprintf(detectorName, MAX_LUA_DETECTOR_FILENAME_LEN, "%s_%s", (isCustom ? "custom" : "cisco"), basename(detector_filename)); - LuaDetector* detector = create_lua_detector(L, detectorName, isCustom); + AppIdDetector* detector = create_lua_detector(L, detectorName, isCustom); allocated_detectors.push_front(detector); num_lua_detectors++; @@ -398,34 +369,36 @@ void LuaDetectorManager::activate_lua_detectors() { for ( auto ld : allocated_detectors ) { - auto detector = static_cast(ld); - auto L = detector->my_lua_state; - - lua_getglobal(L, detector->package_info.initFunctionName.c_str()); + LuaStateDescriptor* lsd = ld->validate_lua_state(false); + auto L = lsd->my_lua_state; + lua_getglobal(L, lsd->package_info.initFunctionName.c_str()); if (!lua_isfunction(L, -1)) { ErrorMessage("Detector %s: does not contain DetectorInit() function\n", - detector->get_name().c_str()); + ld->get_name().c_str()); return; } /*first parameter is DetectorUserData */ - lua_rawgeti(L, LUA_REGISTRYINDEX, detector->detector_user_data_ref); + lua_rawgeti(L, LUA_REGISTRYINDEX, lsd->detector_user_data_ref); /*second parameter is a table containing configuration stuff. */ // ... which is empty.??? lua_newtable(L); if ( lua_pcall(L, 2, 1, 0) ) ErrorMessage("Could not initialize the %s client app element: %s\n", - detector->get_name().c_str(), lua_tostring(L, -1)); + ld->get_name().c_str(), lua_tostring(L, -1)); ++num_active_lua_detectors; } lua_tracker_size = compute_lua_tracker_size(MAX_MEMORY_FOR_LUA_DETECTORS, num_active_lua_detectors); - for ( auto& detector : allocated_detectors ) - set_lua_tracker_size(detector->my_lua_state, lua_tracker_size); + for ( auto& ld : allocated_detectors ) + { + LuaStateDescriptor* lsd = ld->validate_lua_state(false); + set_lua_tracker_size(lsd->my_lua_state, lua_tracker_size); + } } void LuaDetectorManager::list_lua_detectors() @@ -441,15 +414,10 @@ void LuaDetectorManager::list_lua_detectors() for ( auto& ld : allocated_detectors ) { - const char* name; - mem = lua_gc(ld->my_lua_state, LUA_GCCOUNT, 0); + LuaStateDescriptor* lsd = ld->validate_lua_state(false); + mem = lua_gc(lsd->my_lua_state, LUA_GCCOUNT, 0); totalMem += mem; - if ( ld->is_client ) - name = static_cast(ld)->get_name().c_str(); - else - name = static_cast(ld)->get_name().c_str(); - - LogMessage("\tDetector %s: Lua Memory usage %zu kb\n", name, mem); + LogMessage("\tDetector %s: Lua Memory usage %zu kb\n", ld->get_name().c_str(), mem); } LogMessage("Lua Stats total detectors: %zu\n", allocated_detectors.size()); diff --git a/src/network_inspectors/appid/lua_detector_module.h b/src/network_inspectors/appid/lua_detector_module.h index f849e649c..481ba984f 100644 --- a/src/network_inspectors/appid/lua_detector_module.h +++ b/src/network_inspectors/appid/lua_detector_module.h @@ -24,11 +24,21 @@ #include #include +#include + +#include +#include + +#include "protocols/protocol_ids.h" class AppIdConfig; -class LuaDetector; +class AppIdDetector; struct DetectorFlow; +bool get_lua_field(lua_State* L, int table, const char* field, std::string& out); +bool get_lua_field(lua_State* L, int table, const char* field, int& out); +bool get_lua_field(lua_State* L, int table, const char* field, IpProtocol& out); + class LuaDetectorManager { public: @@ -47,7 +57,7 @@ private: void load_lua_detectors(const char* path, bool isCustom); AppIdConfig& config; - std::list allocated_detectors; + std::list allocated_detectors; // FIXIT-L make these perf counters uint32_t lua_tracker_size = 0; diff --git a/src/network_inspectors/appid/service_plugins/service_detector.cc b/src/network_inspectors/appid/service_plugins/service_detector.cc index 8c2e2bab3..9282e3f41 100644 --- a/src/network_inspectors/appid/service_plugins/service_detector.cc +++ b/src/network_inspectors/appid/service_plugins/service_detector.cc @@ -40,6 +40,7 @@ static THREAD_LOCAL unsigned service_module_index = 0; ServiceDetector::ServiceDetector() { flow_data_index = service_module_index++ | APPID_SESSION_DATA_SERVICE_MODSTATE_BIT; + client = false; } void ServiceDetector::register_appid(AppId appId, unsigned extractsInfo) diff --git a/src/network_inspectors/appid/service_plugins/service_detector.h b/src/network_inspectors/appid/service_plugins/service_detector.h index a54be545b..3190271de 100644 --- a/src/network_inspectors/appid/service_plugins/service_detector.h +++ b/src/network_inspectors/appid/service_plugins/service_detector.h @@ -57,7 +57,6 @@ public: private: int update_service_data(AppIdSession*, const Packet*, int dir, AppId, const char* vendor, const char* version); - }; #endif diff --git a/src/network_inspectors/appid/service_plugins/service_discovery.cc b/src/network_inspectors/appid/service_plugins/service_discovery.cc index 0d5a88b7c..aa888adfd 100644 --- a/src/network_inspectors/appid/service_plugins/service_discovery.cc +++ b/src/network_inspectors/appid/service_plugins/service_discovery.cc @@ -84,17 +84,22 @@ static THREAD_LOCAL ServiceDetector* ftp_service = nullptr; ProfileStats serviceMatchPerfStats; -ServiceDiscovery::ServiceDiscovery() +ServiceDiscovery::ServiceDiscovery(AppIdInspector& ins) + : AppIdDiscovery(ins) { initialize(); } -ServiceDiscovery& ServiceDiscovery::get_instance() +ServiceDiscovery& ServiceDiscovery::get_instance(AppIdInspector* ins) { static THREAD_LOCAL ServiceDiscovery* discovery_manager = nullptr; if (!discovery_manager) - discovery_manager = new ServiceDiscovery; + { + assert(ins); + discovery_manager = new ServiceDiscovery(*ins); + } + return *discovery_manager; } diff --git a/src/network_inspectors/appid/service_plugins/service_discovery.h b/src/network_inspectors/appid/service_plugins/service_discovery.h index 576586e34..dad60c9a2 100644 --- a/src/network_inspectors/appid/service_plugins/service_discovery.h +++ b/src/network_inspectors/appid/service_plugins/service_discovery.h @@ -61,7 +61,7 @@ enum SESSION_SERVICE_SEARCH_STATE class ServiceDiscovery : public AppIdDiscovery { public: - static ServiceDiscovery& get_instance(); + static ServiceDiscovery& get_instance(AppIdInspector* ins = nullptr); void finalize_service_patterns(); int add_service_port(AppIdDetector*, const ServiceDetectorPort&) override; @@ -77,7 +77,7 @@ public: static int add_ftp_service_state(AppIdSession&); private: - ServiceDiscovery(); + ServiceDiscovery(AppIdInspector& ins); void initialize() override; void get_next_service(const Packet*, const int dir, AppIdSession*, ServiceDiscoveryState*); void get_port_based_services(IpProtocol, uint16_t port, AppIdSession*); diff --git a/src/network_inspectors/appid/service_plugins/service_ftp.cc b/src/network_inspectors/appid/service_plugins/service_ftp.cc index cf38c6d75..fef5c59f2 100644 --- a/src/network_inspectors/appid/service_plugins/service_ftp.cc +++ b/src/network_inspectors/appid/service_plugins/service_ftp.cc @@ -89,7 +89,7 @@ FtpServiceDetector::FtpServiceDetector(ServiceDiscovery* sd) name = "ftp"; proto = IpProtocol::TCP; detectorType = DETECTOR_TYPE_DECODER; - ftp_data_app_id = AppIdInspector::get_inspector()->add_appid_protocol_reference("ftp-data"); + ftp_data_app_id = AppInfoManager::get_instance().add_appid_protocol_reference("ftp-data"); tcp_patterns = { @@ -793,13 +793,12 @@ static inline void WatchForCommandResult(ServiceFTPData* fd, AppIdSession* asd, fd->cmd = command; } -void FtpServiceDetector::create_expected_session(AppIdSession* asd,const Packet* pkt, const - SfIp* cliIp, - uint16_t cliPort, const SfIp* srvIp, uint16_t srvPort, IpProtocol proto, +void FtpServiceDetector::create_expected_session(AppIdSession* asd, const Packet* pkt, + const SfIp* cliIp, uint16_t cliPort, const SfIp* srvIp, uint16_t srvPort, IpProtocol proto, int flags, APPID_SESSION_DIRECTION dir) { AppIdSession* fp = AppIdSession::create_future_session(pkt, cliIp, cliPort, srvIp, srvPort, - proto, ftp_data_app_id, flags); + proto, ftp_data_app_id, flags, handler->get_inspector()); if (fp) // initialize data session { diff --git a/src/network_inspectors/appid/service_plugins/service_rexec.cc b/src/network_inspectors/appid/service_plugins/service_rexec.cc index 631264eb8..aa93c443c 100644 --- a/src/network_inspectors/appid/service_plugins/service_rexec.cc +++ b/src/network_inspectors/appid/service_plugins/service_rexec.cc @@ -63,7 +63,7 @@ RexecServiceDetector::RexecServiceDetector(ServiceDiscovery* sd) proto = IpProtocol::TCP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppIdInspector::get_inspector()->add_appid_protocol_reference("rexec"); + app_id = AppInfoManager::get_instance().add_appid_protocol_reference("rexec"); appid_registry = { @@ -143,7 +143,7 @@ int RexecServiceDetector::validate(AppIdDiscoveryArgs& args) dip = args.pkt->ptrs.ip_api.get_dst(); sip = args.pkt->ptrs.ip_api.get_src(); AppIdSession* pf = AppIdSession::create_future_session(args.pkt, dip, 0, sip, (uint16_t)port, - IpProtocol::TCP, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE); + IpProtocol::TCP, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if (pf) { ServiceREXECData* tmp_rd = (ServiceREXECData*)snort_calloc( diff --git a/src/network_inspectors/appid/service_plugins/service_rpc.cc b/src/network_inspectors/appid/service_plugins/service_rpc.cc index 1ba78cdae..b5e4530a5 100644 --- a/src/network_inspectors/appid/service_plugins/service_rpc.cc +++ b/src/network_inspectors/appid/service_plugins/service_rpc.cc @@ -183,7 +183,7 @@ RpcServiceDetector::RpcServiceDetector(ServiceDiscovery* sd) struct rpcent* rpc; RPCProgram* prog; - app_id = AppIdInspector::get_inspector()->add_appid_protocol_reference("sunrpc"); + app_id = AppInfoManager::get_instance().add_appid_protocol_reference("sunrpc"); if (!rpc_programs) { @@ -404,7 +404,8 @@ int RpcServiceDetector::validate_packet(const uint8_t* data, uint16_t size, int const SfIp* sip = pkt->ptrs.ip_api.get_src(); tmp = ntohl(pmr->port); pf = AppIdSession::create_future_session(pkt, dip, 0, sip, (uint16_t)tmp, - (IpProtocol)ntohl((uint32_t)rd->proto), app_id, 0); + (IpProtocol)ntohl((uint32_t)rd->proto), app_id, 0, + handler->get_inspector()); if (pf) { pf->add_flow_data_id((uint16_t)tmp, this); diff --git a/src/network_inspectors/appid/service_plugins/service_rshell.cc b/src/network_inspectors/appid/service_plugins/service_rshell.cc index 3690f190f..213c56d9f 100644 --- a/src/network_inspectors/appid/service_plugins/service_rshell.cc +++ b/src/network_inspectors/appid/service_plugins/service_rshell.cc @@ -58,7 +58,7 @@ RshellServiceDetector::RshellServiceDetector(ServiceDiscovery* sd) name = "rshell"; proto = IpProtocol::TCP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppIdInspector::get_inspector()->add_appid_protocol_reference("rsh-error"); + app_id = AppInfoManager::get_instance().add_appid_protocol_reference("rsh-error"); appid_registry = { @@ -146,7 +146,8 @@ int RshellServiceDetector::validate(AppIdDiscoveryArgs& args) const SfIp* dip = pkt->ptrs.ip_api.get_dst(); const SfIp* sip = pkt->ptrs.ip_api.get_src(); pf = AppIdSession::create_future_session(pkt, dip, 0, sip, (uint16_t)port, - IpProtocol::TCP, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE); + IpProtocol::TCP, app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, + handler->get_inspector()); if (pf) { pf->client_disco_state = APPID_DISCO_STATE_FINISHED; diff --git a/src/network_inspectors/appid/service_plugins/service_snmp.cc b/src/network_inspectors/appid/service_plugins/service_snmp.cc index 0b97d12a3..51b9434e3 100644 --- a/src/network_inspectors/appid/service_plugins/service_snmp.cc +++ b/src/network_inspectors/appid/service_plugins/service_snmp.cc @@ -95,7 +95,7 @@ SnmpServiceDetector::SnmpServiceDetector(ServiceDiscovery* sd) proto = IpProtocol::UDP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppIdInspector::get_inspector()->add_appid_protocol_reference("snmp"); + app_id = AppInfoManager::get_instance().add_appid_protocol_reference("snmp"); udp_patterns = { @@ -481,7 +481,7 @@ int SnmpServiceDetector::validate(AppIdDiscoveryArgs& args) const SfIp* dip = pkt->ptrs.ip_api.get_dst(); const SfIp* sip = pkt->ptrs.ip_api.get_src(); pf = AppIdSession::create_future_session(pkt, dip, 0, sip, pkt->ptrs.sp, asd->protocol, - app_id, 0); + app_id, 0, handler->get_inspector()); if (pf) { tmp_sd = (ServiceSNMPData*)snort_calloc(sizeof(ServiceSNMPData)); diff --git a/src/network_inspectors/appid/service_plugins/service_ssl.cc b/src/network_inspectors/appid/service_plugins/service_ssl.cc index 2498c056c..07179fbd5 100644 --- a/src/network_inspectors/appid/service_plugins/service_ssl.cc +++ b/src/network_inspectors/appid/service_plugins/service_ssl.cc @@ -1128,7 +1128,7 @@ void ssl_detector_free_patterns() ssl_patterns_free(&service_ssl_config.DetectorSSLCnamePatternList); } -bool setSSLSquelch(Packet* p, int type, AppId appId) +bool setSSLSquelch(Packet* p, int type, AppId appId, AppIdInspector& inspector) { AppIdSession* f = nullptr; @@ -1139,7 +1139,7 @@ bool setSSLSquelch(Packet* p, int type, AppId appId) const SfIp* sip = p->ptrs.ip_api.get_src(); if (!(f = AppIdSession::create_future_session(p, sip, 0, dip, p->ptrs.dp, IpProtocol::TCP, - appId, 0))) + appId, 0, inspector))) return false; switch (type) diff --git a/src/network_inspectors/appid/service_plugins/service_ssl.h b/src/network_inspectors/appid/service_plugins/service_ssl.h index 8351412c9..7d7800610 100644 --- a/src/network_inspectors/appid/service_plugins/service_ssl.h +++ b/src/network_inspectors/appid/service_plugins/service_ssl.h @@ -43,7 +43,7 @@ int ssl_scan_cname(const uint8_t*, size_t, AppId*, AppId*); int ssl_add_cert_pattern(uint8_t*, size_t, uint8_t, AppId); int ssl_add_cname_pattern(uint8_t*, size_t, uint8_t, AppId); void ssl_detector_free_patterns(); -bool setSSLSquelch(Packet*, int type, AppId); +bool setSSLSquelch(Packet*, int type, AppId, AppIdInspector& inspector); #endif diff --git a/src/network_inspectors/appid/service_plugins/service_tftp.cc b/src/network_inspectors/appid/service_plugins/service_tftp.cc index 8dc3b61a0..92143dacf 100644 --- a/src/network_inspectors/appid/service_plugins/service_tftp.cc +++ b/src/network_inspectors/appid/service_plugins/service_tftp.cc @@ -71,7 +71,7 @@ TftpServiceDetector::TftpServiceDetector(ServiceDiscovery* sd) proto = IpProtocol::UDP; detectorType = DETECTOR_TYPE_DECODER; - app_id = AppIdInspector::get_inspector()->add_appid_protocol_reference("tftp"); + app_id = AppInfoManager::get_instance().add_appid_protocol_reference("tftp"); appid_registry = { @@ -187,7 +187,7 @@ int TftpServiceDetector::validate(AppIdDiscoveryArgs& args) dip = pkt->ptrs.ip_api.get_dst(); sip = pkt->ptrs.ip_api.get_src(); pf = AppIdSession::create_future_session(pkt, dip, 0, sip, pkt->ptrs.sp, asd->protocol, - app_id, APPID_EARLY_SESSION_FLAG_FW_RULE); + app_id, APPID_EARLY_SESSION_FLAG_FW_RULE, handler->get_inspector()); if (pf) { data_add(pf, tmp_td, &snort_free); diff --git a/src/network_inspectors/appid/test/app_info_table_test.cc b/src/network_inspectors/appid/test/app_info_table_test.cc index 3e9d30643..3eb77ecac 100644 --- a/src/network_inspectors/appid/test/app_info_table_test.cc +++ b/src/network_inspectors/appid/test/app_info_table_test.cc @@ -75,11 +75,13 @@ TEST_GROUP(app_info_table) void setup() { MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); + mock_init_appid_pegs(); } void teardown() { app_info_mgr.cleanup_appid_info_table(); + mock_cleanup_appid_pegs(); MemoryLeakWarningPlugin::turnOnNewDeleteOverloads(); } }; diff --git a/src/network_inspectors/appid/test/appid_api_test.cc b/src/network_inspectors/appid/test/appid_api_test.cc index 7db0fdc6b..4557299b8 100644 --- a/src/network_inspectors/appid/test/appid_api_test.cc +++ b/src/network_inspectors/appid/test/appid_api_test.cc @@ -33,8 +33,11 @@ #include "appid_mock_definitions.h" #include "appid_mock_http_session.h" +#include "appid_mock_inspector.h" #include "appid_mock_session.h" +#include "network_inspectors/appid/appid_peg_counts.h" + #include #include @@ -73,8 +76,9 @@ TEST_GROUP(appid_api) void setup() override { MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); + mock_init_appid_pegs(); flow = new Flow; - mock_session = new AppIdSession(IpProtocol::TCP, nullptr, 1492); + mock_session = new AppIdSession(IpProtocol::TCP, nullptr, 1492, appid_inspector); mock_session->hsession = init_http_session(mock_session); flow->set_flow_data(mock_session); } @@ -83,6 +87,8 @@ TEST_GROUP(appid_api) { delete mock_session; delete flow; + mock_cleanup_appid_pegs(); + MemoryLeakWarningPlugin::turnOnNewDeleteOverloads(); } }; @@ -669,6 +675,8 @@ TEST(appid_api, is_http_inspection_done) CHECK_TRUE(val); } +// FIXIT - enable this test when consume ha appid api call is fixed +#ifdef APPID_HA_SUPPORT_ENABLED TEST(appid_api, produce_ha_state) { AppIdSessionHA appHA, cmp_buf; @@ -729,6 +737,7 @@ TEST(appid_api, produce_ha_state) CHECK_TRUE(mock_session->service_disco_state == APPID_DISCO_STATE_STATEFUL); CHECK_TRUE(mock_session->client_disco_state == APPID_DISCO_STATE_FINISHED); } +#endif int main(int argc, char** argv) { diff --git a/src/network_inspectors/appid/test/appid_detector_test.cc b/src/network_inspectors/appid/test/appid_detector_test.cc index b8004eed0..a4b6e3b63 100644 --- a/src/network_inspectors/appid/test/appid_detector_test.cc +++ b/src/network_inspectors/appid/test/appid_detector_test.cc @@ -28,26 +28,14 @@ #include "protocols/protocol_ids.h" +#include "appid_mock_definitions.h" #include "appid_mock_http_session.h" +#include "appid_mock_inspector.h" #include "appid_mock_session.h" #include #include -char* snort_strdup(const char* str) -{ - assert(str); - size_t n = strlen(str) + 1; - char* p = (char*)snort_alloc(n); - memcpy(p, str, n); - return p; -} - -void ErrorMessage(const char*,...) { } -void WarningMessage(const char*,...) { } -void LogMessage(const char*,...) { } -void ParseWarning(WarningGroup, const char*, ...) { } - Flow* flow = nullptr; AppIdSession* mock_session = nullptr; @@ -66,8 +54,9 @@ TEST_GROUP(appid_detector_tests) void setup() override { MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); + mock_init_appid_pegs(); flow = new Flow; - mock_session = new AppIdSession(IpProtocol::TCP, nullptr, 1492); + mock_session = new AppIdSession(IpProtocol::TCP, nullptr, 1492, appid_inspector); mock_session->hsession = init_http_session(mock_session); flow->set_flow_data(mock_session); } @@ -76,6 +65,7 @@ TEST_GROUP(appid_detector_tests) { delete mock_session; delete flow; + mock_cleanup_appid_pegs(); MemoryLeakWarningPlugin::turnOnNewDeleteOverloads(); } }; diff --git a/src/network_inspectors/appid/test/appid_expected_flags_test.cc b/src/network_inspectors/appid/test/appid_expected_flags_test.cc index c32dac913..9670a3ee4 100644 --- a/src/network_inspectors/appid/test/appid_expected_flags_test.cc +++ b/src/network_inspectors/appid/test/appid_expected_flags_test.cc @@ -22,6 +22,7 @@ #include "network_inspectors/appid/service_plugins/service_detector.cc" #include "appid_mock_definitions.h" +#include "appid_mock_inspector.h" #include "appid_mock_session.h" #include @@ -53,14 +54,16 @@ TEST_GROUP(appid_expected_flags) void setup() { MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); - parent = new AppIdSession(IpProtocol::TCP, nullptr, 1492); - expected = new AppIdSession(IpProtocol::TCP, nullptr, 1492); + mock_init_appid_pegs(); + parent = new AppIdSession(IpProtocol::TCP, nullptr, 1492, appid_inspector); + expected = new AppIdSession(IpProtocol::TCP, nullptr, 1492, appid_inspector); } void teardown() { delete parent; delete expected; + mock_cleanup_appid_pegs(); MemoryLeakWarningPlugin::turnOnNewDeleteOverloads(); } }; diff --git a/src/network_inspectors/appid/test/appid_http_event_test.cc b/src/network_inspectors/appid/test/appid_http_event_test.cc index 401a85b77..5b5a73c5e 100644 --- a/src/network_inspectors/appid/test/appid_http_event_test.cc +++ b/src/network_inspectors/appid/test/appid_http_event_test.cc @@ -33,6 +33,7 @@ #include "appid_mock_definitions.h" #include "appid_mock_http_session.h" +#include "appid_mock_inspector.h" #include "appid_mock_session.h" #include @@ -172,8 +173,9 @@ TEST_GROUP(appid_http_event) void setup() override { MemoryLeakWarningPlugin::turnOffNewDeleteOverloads(); + mock_init_appid_pegs(); flow = new Flow; - mock_session = new AppIdSession(IpProtocol::TCP, nullptr, 1492); + mock_session = new AppIdSession(IpProtocol::TCP, nullptr, 1492, appid_inspector); flow->set_flow_data(mock_session); } @@ -182,6 +184,7 @@ TEST_GROUP(appid_http_event) fake_msg_header = nullptr; delete mock_session; delete flow; + mock_cleanup_appid_pegs(); mock().clear(); MemoryLeakWarningPlugin::turnOnNewDeleteOverloads(); } diff --git a/src/network_inspectors/appid/test/appid_mock_definitions.h b/src/network_inspectors/appid/test/appid_mock_definitions.h index c6a2dbf18..74df7941c 100644 --- a/src/network_inspectors/appid/test/appid_mock_definitions.h +++ b/src/network_inspectors/appid/test/appid_mock_definitions.h @@ -56,7 +56,10 @@ Field global_field; void Debug::print(const char*, int, uint64_t, const char*, ...) { } #endif -void ParseWarning(WarningGroup, char const*, ...) { } +void ErrorMessage(const char*,...) { } +void WarningMessage(const char*,...) { } +void LogMessage(const char*,...) { } +void ParseWarning(WarningGroup, const char*, ...) { } int ServiceDiscovery::add_ftp_service_state(AppIdSession&) { @@ -93,5 +96,17 @@ int ServiceDiscovery::fail_service(AppIdSession*, Packet const*, int, ServiceDet return 0; } +void mock_init_appid_pegs() +{ + AppIdPegCounts::set_detectors_configured(); + AppIdPegCounts::get_peg_info(); + AppIdPegCounts::init_pegs(); +} + +void mock_cleanup_appid_pegs() +{ + AppIdPegCounts::cleanup_pegs(); + AppIdPegCounts::cleanup_peg_info(); +} #endif diff --git a/src/network_inspectors/appid/test/appid_mock_inspector.h b/src/network_inspectors/appid/test/appid_mock_inspector.h index c64e5165a..e58f41759 100644 --- a/src/network_inspectors/appid/test/appid_mock_inspector.h +++ b/src/network_inspectors/appid/test/appid_mock_inspector.h @@ -18,6 +18,9 @@ // appid_mock_inspector.h author davis mcpherson +typedef uint64_t Trace; +class Value; + Inspector::Inspector() { set_api(nullptr); @@ -28,27 +31,25 @@ bool Inspector::likes(Packet*) { return true; } bool Inspector::get_buf(const char*, Packet*, InspectionBuffer&) { return true; } class StreamSplitter* Inspector::get_splitter(bool) { return nullptr; } -Module::Module(const char*, const char*) { } -void Module::sum_stats(bool ) {} -void Module::show_interval_stats(IndexVec&, FILE*) {} -void Module::show_stats() {} -void Module::reset_stats() {} - -AppIdModule::~AppIdModule() {} -AppIdModule::AppIdModule() : Module(nullptr, nullptr), config(nullptr) {} -bool AppIdModule::begin(char const*, int, SnortConfig*) { return true; } -bool AppIdModule::end(char const*, int, SnortConfig*) { return true; } -bool AppIdModule::set(char const*, Value&, SnortConfig*) { return true; } -const PegInfo* AppIdModule::get_pegs() const { return nullptr; } -PegCount* AppIdModule::get_counts() const { return nullptr; } -ProfileStats* AppIdModule::get_profile() const { return nullptr; } - -AppIdInspector::AppIdInspector(AppIdModule& m) : appid_mod(m), my_seh(nullptr) { } -AppIdInspector::~AppIdInspector() { } -AppIdInspector* AppIdInspector::get_inspector() { AppIdModule aim; return new AppIdInspector(aim); } -void AppIdInspector::eval(Packet*) { } -int16_t AppIdInspector::add_appid_protocol_reference(char const*) { return 1066; } -bool AppIdInspector::configure(SnortConfig*) { return true; } -void AppIdInspector::show(SnortConfig*) { } -void AppIdInspector::tinit() { } -void AppIdInspector::tterm() { } +class AppIdModule +{ +public: + AppIdModule() {} + ~AppIdModule() {} + +}; + +class AppIdInspector : public Inspector +{ +public: + AppIdInspector(AppIdModule& ) { } + ~AppIdInspector() { } + void eval(Packet*) { } + bool configure(SnortConfig*) { return true; } + void show(SnortConfig*) { } + void tinit() { } + void tterm() { } +}; + +AppIdModule appid_mod; +AppIdInspector appid_inspector( appid_mod ); diff --git a/src/network_inspectors/appid/test/appid_mock_session.h b/src/network_inspectors/appid/test/appid_mock_session.h index 5eb70d667..c15ab91a2 100644 --- a/src/network_inspectors/appid/test/appid_mock_session.h +++ b/src/network_inspectors/appid/test/appid_mock_session.h @@ -53,9 +53,9 @@ AppIdServiceSubtype APPID_UT_SERVICE_SUBTYPE = { nullptr, APPID_UT_SERVICE, unsigned AppIdSession::inspector_id = 0; -AppIdSession::AppIdSession(IpProtocol, const SfIp*, uint16_t) : FlowData(inspector_id, nullptr) +AppIdSession::AppIdSession(IpProtocol, const SfIp*, uint16_t, AppIdInspector& inspector) + : FlowData(inspector_id, &inspector), inspector(inspector) { - service.set_port_service_id(APPID_UT_ID); common.flow_type = APPID_FLOW_TYPE_NORMAL; service_port = APPID_UT_SERVICE_PORT;