From: Ron Dempster (rdempste) Date: Mon, 24 Jun 2024 16:26:37 +0000 (+0000) Subject: Pull request #4138: appid: restructure the appid code to make it easier to follow... X-Git-Tag: 3.3.1.0~8 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=6f1c98ac237d58f8d6fbe8c9bb1daa4a27e12ac0;p=thirdparty%2Fsnort3.git Pull request #4138: appid: restructure the appid code to make it easier to follow and maintain Merge in SNORT/snort3 from ~RDEMPSTE/snort3:appid to master Squashed commit of the following: commit 1195b37a59928477641dc2efbf606eb2eaca857b Author: Ron Dempster (rdempste) Date: Tue Sep 19 10:42:40 2023 -0400 appid: restructure the appid code to make it easier to follow and maintain --- diff --git a/src/network_inspectors/appid/appid_api.cc b/src/network_inspectors/appid/appid_api.cc index 623a99c56..8d3420d99 100644 --- a/src/network_inspectors/appid/appid_api.cc +++ b/src/network_inspectors/appid/appid_api.cc @@ -261,7 +261,7 @@ bool AppIdApi::is_inspection_needed(const Inspector& inspector) const return false; SnortProtocolId id = inspector.get_service(); - const AppIdConfig& config = appid_inspector->get_ctxt().config; + const AppIdConfig& config = appid_inspector->get_config(); if (id == config.snort_proto_ids[PROTO_INDEX_HTTP2] or id == config.snort_proto_ids[PROTO_INDEX_SSH] or id == config.snort_proto_ids[PROTO_INDEX_CIP]) return true; diff --git a/src/network_inspectors/appid/appid_config.cc b/src/network_inspectors/appid/appid_config.cc index a6a6c90b8..305d2e6bd 100644 --- a/src/network_inspectors/appid/appid_config.cc +++ b/src/network_inspectors/appid/appid_config.cc @@ -57,25 +57,25 @@ ThirdPartyAppIdContext* AppIdContext::tp_appid_ctxt = nullptr; OdpContext* AppIdContext::odp_ctxt = nullptr; uint32_t OdpContext::next_version = 0; -static void map_app_names_to_snort_ids(SnortConfig* sc, AppIdConfig& config) +AppIdConfig::~AppIdConfig() { - // Have to create SnortProtocolIds during configuration initialization. - config.snort_proto_ids[PROTO_INDEX_UNSYNCHRONIZED] = sc->proto_ref->add("unsynchronized"); - config.snort_proto_ids[PROTO_INDEX_FTP_DATA] = sc->proto_ref->add("ftp-data"); - config.snort_proto_ids[PROTO_INDEX_HTTP2] = sc->proto_ref->add("http2"); - config.snort_proto_ids[PROTO_INDEX_REXEC] = sc->proto_ref->add("rexec"); - config.snort_proto_ids[PROTO_INDEX_RSH_ERROR] = sc->proto_ref->add("rsh-error"); - config.snort_proto_ids[PROTO_INDEX_SNMP] = sc->proto_ref->add("snmp"); - config.snort_proto_ids[PROTO_INDEX_SUNRPC] = sc->proto_ref->add("sunrpc"); - config.snort_proto_ids[PROTO_INDEX_TFTP] = sc->proto_ref->add("tftp"); - config.snort_proto_ids[PROTO_INDEX_SIP] = sc->proto_ref->add("sip"); - config.snort_proto_ids[PROTO_INDEX_SSH] = sc->proto_ref->add("ssh"); - config.snort_proto_ids[PROTO_INDEX_CIP] = sc->proto_ref->add("cip"); + snort_free((void*)app_detector_dir); } -AppIdConfig::~AppIdConfig() +void AppIdConfig::map_app_names_to_snort_ids(SnortConfig& sc) { - snort_free((void*)app_detector_dir); + // Have to create SnortProtocolIds during configuration initialization. + snort_proto_ids[PROTO_INDEX_UNSYNCHRONIZED] = sc.proto_ref->add("unsynchronized"); + snort_proto_ids[PROTO_INDEX_FTP_DATA] = sc.proto_ref->add("ftp-data"); + snort_proto_ids[PROTO_INDEX_HTTP2] = sc.proto_ref->add("http2"); + snort_proto_ids[PROTO_INDEX_REXEC] = sc.proto_ref->add("rexec"); + snort_proto_ids[PROTO_INDEX_RSH_ERROR] = sc.proto_ref->add("rsh-error"); + snort_proto_ids[PROTO_INDEX_SNMP] = sc.proto_ref->add("snmp"); + snort_proto_ids[PROTO_INDEX_SUNRPC] = sc.proto_ref->add("sunrpc"); + snort_proto_ids[PROTO_INDEX_TFTP] = sc.proto_ref->add("tftp"); + snort_proto_ids[PROTO_INDEX_SIP] = sc.proto_ref->add("sip"); + snort_proto_ids[PROTO_INDEX_SSH] = sc.proto_ref->add("ssh"); + snort_proto_ids[PROTO_INDEX_CIP] = sc.proto_ref->add("cip"); } void AppIdConfig::show() const @@ -102,11 +102,8 @@ static bool once = false; void AppIdContext::pterm() { - if (odp_thread_local_ctxt) - { - delete odp_thread_local_ctxt; - odp_thread_local_ctxt = nullptr; - } + delete odp_control_thread_ctxt; + odp_control_thread_ctxt = nullptr; if (odp_ctxt) { @@ -131,14 +128,10 @@ void AppIdContext::pterm() bool AppIdContext::init_appid(SnortConfig* sc, AppIdInspector& inspector) { // do not reload ODP on reload_config() - if (!odp_ctxt) - odp_ctxt = new OdpContext(config, sc); - - if (!odp_thread_local_ctxt) - odp_thread_local_ctxt = new OdpThreadContext; - if (!once) { + assert(!odp_ctxt); + odp_ctxt = new OdpContext(config, sc); odp_ctxt->get_client_disco_mgr().initialize(inspector); odp_ctxt->get_service_disco_mgr().initialize(inspector); odp_ctxt->set_client_and_service_detectors(); @@ -149,7 +142,10 @@ bool AppIdContext::init_appid(SnortConfig* sc, AppIdInspector& inspector) appidDebug->set_enabled(config.log_all_sessions); } - odp_thread_local_ctxt->initialize(sc, *this, true); + assert(!odp_control_thread_ctxt); + odp_control_thread_ctxt = new OdpControlContext; + odp_control_thread_ctxt->initialize(sc, *this); + odp_ctxt->initialize(inspector); // do not reload third party on reload_config() @@ -159,12 +155,12 @@ bool AppIdContext::init_appid(SnortConfig* sc, AppIdInspector& inspector) } else { + assert(odp_ctxt); odp_ctxt->get_client_disco_mgr().reload(); odp_ctxt->get_service_disco_mgr().reload(); odp_ctxt->reload(); } - map_app_names_to_snort_ids(sc, config); if (config.enable_rna_filter) discovery_filter = new DiscoveryFilter(config.rna_conf_path); return true; @@ -224,18 +220,18 @@ void OdpContext::dump_appid_config() appid_log(nullptr, TRACE_INFO_LEVEL, "Appid Config: max_packet_before_service_fail %" PRIu16" \n", max_packet_before_service_fail); appid_log(nullptr, TRACE_INFO_LEVEL, "Appid Config: max_packet_service_fail_ignore_bytes %" PRIu16" \n", max_packet_service_fail_ignore_bytes); appid_log(nullptr, TRACE_INFO_LEVEL, "Appid Config: eve_http_client %s\n", (eve_http_client ? "True" : "False")); - appid_log(nullptr, TRACE_INFO_LEVEL, "Appid Config: appid_cpu_profiler %s\n", (appid_cpu_profiler ? "True" : "False")); + appid_log(nullptr, TRACE_INFO_LEVEL, "Appid Config: appid_cpu_profiler %s\n", (appid_cpu_profiler ? "True" : "False")); } bool OdpContext::is_appid_cpu_profiler_running() { return (TimeProfilerStats::is_enabled() and appid_cpu_profiler); -} +} bool OdpContext::is_appid_cpu_profiler_enabled() { return appid_cpu_profiler; -} +} OdpContext::OdpContext(const AppIdConfig& config, SnortConfig* sc) { @@ -367,17 +363,15 @@ AppId OdpContext::get_protocol_service_id(IpProtocol proto) return ip_protocol[(uint16_t)proto]; } -void OdpThreadContext::initialize(const SnortConfig* sc, AppIdContext& ctxt, bool is_control, - bool reload_odp) +void OdpControlContext::initialize(const SnortConfig* sc, AppIdContext& ctxt) { - if (!is_control and reload_odp) - LuaDetectorManager::init_thread_manager(sc, ctxt); - else - LuaDetectorManager::initialize(sc, ctxt, is_control, reload_odp); + lua_detector_mgr = std::make_shared(ctxt); + lua_detector_mgr->initialize(sc); } -OdpThreadContext::~OdpThreadContext() +void OdpPacketThreadContext::initialize(const SnortConfig* sc) { + lua_detector_mgr = ControlLuaDetectorManager::get_packet_lua_detector_manager(); assert(lua_detector_mgr); - delete lua_detector_mgr; + lua_detector_mgr->initialize(sc); } diff --git a/src/network_inspectors/appid/appid_config.h b/src/network_inspectors/appid/appid_config.h index 1a840083a..fefed2635 100644 --- a/src/network_inspectors/appid/appid_config.h +++ b/src/network_inspectors/appid/appid_config.h @@ -23,6 +23,7 @@ #define APP_ID_CONFIG_H #include +#include #include #include "helpers/discovery_filter.h" @@ -88,6 +89,8 @@ public: AppIdConfig() = default; ~AppIdConfig(); + void map_app_names_to_snort_ids(snort::SnortConfig&); + // FIXIT-L: DECRYPT_DEBUG - Move this to ssl-module #ifdef REG_TEST // To manually restart appid detection for an SSL-decrypted flow (single session only), @@ -299,23 +302,65 @@ private: class OdpThreadContext { public: - ~OdpThreadContext(); - void initialize(const snort::SnortConfig*, AppIdContext& ctxt, bool is_control=false, - bool reload_odp=false); + virtual ~OdpThreadContext() = default; - void set_lua_detector_mgr(LuaDetectorManager& mgr) + lua_State* get_lua_state() const { - lua_detector_mgr = &mgr; + assert(lua_detector_mgr); + return lua_detector_mgr->L; } - LuaDetectorManager& get_lua_detector_mgr() const + bool insert_cb_detector(AppId app_id, LuaObject* ud) { assert(lua_detector_mgr); - return *lua_detector_mgr; + return lua_detector_mgr->insert_cb_detector(app_id, ud); } -private: - LuaDetectorManager* lua_detector_mgr = nullptr; + LuaObject* get_cb_detector(AppId app_id) + { + assert(lua_detector_mgr); + return lua_detector_mgr->get_cb_detector(app_id); + } + +protected: + std::shared_ptr lua_detector_mgr; +}; + +class OdpControlContext : public OdpThreadContext +{ +public: + ~OdpControlContext() override = default; + void initialize(const snort::SnortConfig*, AppIdContext&); + void set_ignore_chp_cleanup() + { + assert(lua_detector_mgr); + static_cast(lua_detector_mgr.get())->set_ignore_chp_cleanup(); + } +}; + +class OdpPacketThreadContext : public OdpThreadContext +{ +public: + ~OdpPacketThreadContext() override = default; + void initialize(const snort::SnortConfig*); + + void set_detector_flow(DetectorFlow* df) + { + assert(lua_detector_mgr); + static_cast(lua_detector_mgr.get())->set_detector_flow(df); + } + + DetectorFlow* get_detector_flow() + { + assert(lua_detector_mgr); + return static_cast(lua_detector_mgr.get())->get_detector_flow(); + } + + void free_detector_flow() + { + assert(lua_detector_mgr); + static_cast(lua_detector_mgr.get())->free_detector_flow(); + } }; class AppIdContext diff --git a/src/network_inspectors/appid/appid_detector.cc b/src/network_inspectors/appid/appid_detector.cc index b3384cdf9..7316d2f27 100644 --- a/src/network_inspectors/appid/appid_detector.cc +++ b/src/network_inspectors/appid/appid_detector.cc @@ -60,12 +60,7 @@ int AppIdDetector::initialize(AppIdInspector& inspector) return APPID_SUCCESS; } -void AppIdDetector::reload() -{ - do_custom_reload(); -} - -void* AppIdDetector::data_get(AppIdSession& asd) +void* AppIdDetector::data_get(const AppIdSession& asd) { return asd.get_flow_data(flow_data_index); } diff --git a/src/network_inspectors/appid/appid_detector.h b/src/network_inspectors/appid/appid_detector.h index 74695522d..1ea28095b 100644 --- a/src/network_inspectors/appid/appid_detector.h +++ b/src/network_inspectors/appid/appid_detector.h @@ -111,18 +111,17 @@ public: AppIdDetector() = default; virtual ~AppIdDetector() = default; - virtual int initialize(AppIdInspector&); - virtual void reload(); + int initialize(AppIdInspector&); virtual void do_custom_init() { } virtual void do_custom_reload() { } virtual int validate(AppIdDiscoveryArgs&) = 0; virtual void register_appid(AppId, unsigned extractsInfo, OdpContext& odp_ctxt) = 0; - virtual void* data_get(AppIdSession&); - virtual int data_add(AppIdSession&, void*, AppIdFreeFCN); - virtual void add_user(AppIdSession&, const char*, AppId, bool, AppidChangeBits&); - virtual void add_payload(AppIdSession&, AppId); - virtual void add_app(AppIdSession& asd, AppId service_id, AppId client_id, const char* version, AppidChangeBits& change_bits) + void* data_get(const AppIdSession&); + int data_add(AppIdSession&, void*, AppIdFreeFCN); + void add_user(AppIdSession&, const char*, AppId, bool, AppidChangeBits&); + void add_payload(AppIdSession&, AppId); + void add_app(AppIdSession& asd, AppId service_id, AppId client_id, const char* version, AppidChangeBits& change_bits) { if ( version ) asd.set_client_version(version, change_bits); @@ -131,7 +130,7 @@ public: asd.client_inferred_service_id = service_id; asd.set_client_id(client_id); } - virtual void add_app(const snort::Packet&, AppIdSession&, AppidSessionDirection, AppId, AppId, const char*, AppidChangeBits&); + void add_app(const snort::Packet&, AppIdSession&, AppidSessionDirection, AppId, AppId, const char*, AppidChangeBits&); const char* get_code_string(APPID_STATUS_CODE) const; const std::string& get_name() const diff --git a/src/network_inspectors/appid/appid_discovery.cc b/src/network_inspectors/appid/appid_discovery.cc index 680b6ac2a..93a3f9b7c 100644 --- a/src/network_inspectors/appid/appid_discovery.cc +++ b/src/network_inspectors/appid/appid_discovery.cc @@ -242,9 +242,9 @@ bool AppIdDiscovery::do_pre_discovery(Packet* p, AppIdSession*& asd, AppIdInspec return false; } + const AppIdConfig& config = inspector.get_config(); if (appidDebug->is_enabled()) - appidDebug->activate(p->flow, asd, - inspector.get_ctxt().config.log_all_sessions); + appidDebug->activate(p->flow, asd, config.log_all_sessions); if (is_packet_ignored(p)) return false; @@ -386,7 +386,7 @@ bool AppIdDiscovery::do_pre_discovery(Packet* p, AppIdSession*& asd, AppIdInspec // FIXIT-L: DECRYPT_DEBUG - Move set_proxied and first_decrypted_packet_debug to ssl-module // after ssl-module's decryption capability is implemented #ifdef REG_TEST - uint32_t fdpd = inspector.get_ctxt().config.first_decrypted_packet_debug; + uint32_t fdpd = config.first_decrypted_packet_debug; if (fdpd and (fdpd == asd->session_packet_count)) { p->flow->set_proxied(); diff --git a/src/network_inspectors/appid/appid_discovery.h b/src/network_inspectors/appid/appid_discovery.h index 060fc278e..9248cdbb5 100644 --- a/src/network_inspectors/appid/appid_discovery.h +++ b/src/network_inspectors/appid/appid_discovery.h @@ -116,12 +116,12 @@ public: virtual void initialize(AppIdInspector&) = 0; virtual void reload() = 0; - virtual void register_detector(const std::string&, AppIdDetector*, IpProtocol); - virtual void add_pattern_data(AppIdDetector*, snort::SearchTool&, int position, + void register_detector(const std::string&, AppIdDetector*, IpProtocol); + void add_pattern_data(AppIdDetector*, snort::SearchTool&, int position, const uint8_t* const pattern, unsigned size, unsigned nocase); - virtual void register_tcp_pattern(AppIdDetector*, const uint8_t* const pattern, unsigned size, + void register_tcp_pattern(AppIdDetector*, const uint8_t* const pattern, unsigned size, int position, unsigned nocase); - virtual void register_udp_pattern(AppIdDetector*, const uint8_t* const pattern, unsigned size, + void register_udp_pattern(AppIdDetector*, const uint8_t* const pattern, unsigned size, int position, unsigned nocase); virtual int add_service_port(AppIdDetector*, const ServiceDetectorPort&); diff --git a/src/network_inspectors/appid/appid_http_event_handler.cc b/src/network_inspectors/appid/appid_http_event_handler.cc index e21d532b7..541f72f17 100644 --- a/src/network_inspectors/appid/appid_http_event_handler.cc +++ b/src/network_inspectors/appid/appid_http_event_handler.cc @@ -54,6 +54,7 @@ void HttpEventHandler::handle(DataEvent& event, Flow* flow) auto direction = event_type == REQUEST_EVENT ? APP_ID_FROM_INITIATOR : APP_ID_FROM_RESPONDER; bool is_debug_active = false; + const AppIdConfig& config = inspector.get_config(); if ( !asd ) { // The event is received before appid has seen any packet, e.g., data on SYN @@ -90,7 +91,7 @@ void HttpEventHandler::handle(DataEvent& event, Flow* flow) per_appid_event_cpu_timer.start(); if (appidDebug->is_enabled() and !is_debug_active) - appidDebug->activate(flow, asd, inspector.get_ctxt().config.log_all_sessions); + appidDebug->activate(flow, asd, config.log_all_sessions); appid_log(p, TRACE_DEBUG_LEVEL, "Processing HTTP metadata from HTTP Inspector for stream %" PRId64 "\n", http_event->get_httpx_stream_id()); diff --git a/src/network_inspectors/appid/appid_inspector.cc b/src/network_inspectors/appid/appid_inspector.cc index ad1a5ff94..0138dbff0 100644 --- a/src/network_inspectors/appid/appid_inspector.cc +++ b/src/network_inspectors/appid/appid_inspector.cc @@ -63,7 +63,8 @@ using namespace snort; THREAD_LOCAL ThirdPartyAppIdContext* pkt_thread_tp_appid_ctxt = nullptr; -THREAD_LOCAL OdpThreadContext* odp_thread_local_ctxt = nullptr; +OdpControlContext* odp_control_thread_ctxt = nullptr; +THREAD_LOCAL OdpPacketThreadContext* odp_thread_local_ctxt = nullptr; THREAD_LOCAL OdpContext* pkt_thread_odp_ctxt = nullptr; unsigned AppIdInspector::cached_global_pub_id = 0; @@ -95,31 +96,22 @@ static void add_appid_to_packet_trace(const Flow& flow, const OdpContext& odp_co (misc_name ? misc_name : ""), misc_id); } -AppIdInspector::AppIdInspector(AppIdModule& mod) +AppIdInspector::AppIdInspector(AppIdModule& mod) : config(mod.get_data()), ctxt(*config) { - config = mod.get_data(); - assert(config); } AppIdInspector::~AppIdInspector() { - delete ctxt; delete config; } -AppIdContext& AppIdInspector::get_ctxt() const -{ - assert(ctxt); - return *ctxt; -} -unsigned AppIdInspector::get_pub_id() +unsigned AppIdInspector::get_pub_id() { return appid_pub_id; } bool AppIdInspector::configure(SnortConfig* sc) { - assert(!ctxt); // cppcheck-suppress unreadVariable Profile profile(appid_perf_stats); struct rusage ru; @@ -134,8 +126,9 @@ bool AppIdInspector::configure(SnortConfig* sc) } #endif - ctxt = new AppIdContext(const_cast(*config)); - ctxt->init_appid(sc, *this); + assert(sc); + config->map_app_names_to_snort_ids(*sc); + ctxt.init_appid(sc, *this); #ifdef REG_TEST if ( config->log_memory_and_pattern_count ) @@ -145,7 +138,7 @@ bool AppIdInspector::configure(SnortConfig* sc) appid_log(nullptr, TRACE_ERROR_LEVEL, "appid: fetching memory usage failed\n"); else appid_log(nullptr, TRACE_INFO_LEVEL, "appid: MaxRss diff: %li\n", ru.ru_maxrss - prev_maxrss); - appid_log(nullptr, TRACE_INFO_LEVEL, "appid: patterns loaded: %u\n", ctxt->get_odp_ctxt().get_pattern_count()); + appid_log(nullptr, TRACE_INFO_LEVEL, "appid: patterns loaded: %u\n", ctxt.get_odp_ctxt().get_pattern_count()); #ifdef REG_TEST } #endif @@ -190,18 +183,18 @@ void AppIdInspector::tinit() AppIdStatistics::initialize_manager(*config); assert(!pkt_thread_odp_ctxt); - pkt_thread_odp_ctxt = &(ctxt->get_odp_ctxt()); + pkt_thread_odp_ctxt = &ctxt.get_odp_ctxt(); assert(!odp_thread_local_ctxt); - odp_thread_local_ctxt = new OdpThreadContext(); - odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), *ctxt); + odp_thread_local_ctxt = new OdpPacketThreadContext; + odp_thread_local_ctxt->initialize(SnortConfig::get_conf()); AppIdServiceState::initialize(config->memcap); assert(!pkt_thread_tp_appid_ctxt); - pkt_thread_tp_appid_ctxt = ctxt->get_tp_appid_ctxt(); + pkt_thread_tp_appid_ctxt = ctxt.get_tp_appid_ctxt(); if (pkt_thread_tp_appid_ctxt) pkt_thread_tp_appid_ctxt->tinit(); - if (ctxt->config.log_all_sessions) + if (config->log_all_sessions) appidDebug->set_enabled(true); if ( snort::HighAvailabilityManager::active() ) AppIdHAManager::tinit(); diff --git a/src/network_inspectors/appid/appid_inspector.h b/src/network_inspectors/appid/appid_inspector.h index d2928b392..09a8bbec1 100644 --- a/src/network_inspectors/appid/appid_inspector.h +++ b/src/network_inspectors/appid/appid_inspector.h @@ -46,20 +46,24 @@ public: void tterm() override; void tear_down(snort::SnortConfig*) override; void eval(snort::Packet*) override; - AppIdContext& get_ctxt() const; - const AppIdConfig& get_config() const { return *config; } + AppIdContext& get_ctxt() + { return ctxt; } + + const AppIdConfig& get_config() const + { return *config; } static unsigned get_pub_id(); private: - const AppIdConfig* config = nullptr; - AppIdContext* ctxt = nullptr; + AppIdConfig* config = nullptr; + AppIdContext ctxt; static unsigned cached_global_pub_id; }; extern const snort::InspectApi appid_inspector_api; -extern THREAD_LOCAL OdpThreadContext* odp_thread_local_ctxt; +extern OdpControlContext* odp_control_thread_ctxt; +extern THREAD_LOCAL OdpPacketThreadContext* odp_thread_local_ctxt; extern THREAD_LOCAL OdpContext* pkt_thread_odp_ctxt; extern THREAD_LOCAL ThirdPartyAppIdContext* pkt_thread_tp_appid_ctxt; diff --git a/src/network_inspectors/appid/appid_module.cc b/src/network_inspectors/appid/appid_module.cc index 3fc1249f0..d981580a7 100644 --- a/src/network_inspectors/appid/appid_module.cc +++ b/src/network_inspectors/appid/appid_module.cc @@ -159,7 +159,7 @@ class ACThirdPartyAppIdContextSwap : public AnalyzerCommand { public: bool execute(Analyzer&, void**) override; - ACThirdPartyAppIdContextSwap(const AppIdInspector& inspector, ControlConn* conn) + ACThirdPartyAppIdContextSwap(AppIdInspector& inspector, ControlConn* conn) : AnalyzerCommand(conn), inspector(inspector) { appid_log(nullptr, TRACE_INFO_LEVEL, "== swapping third-party configuration\n"); @@ -168,7 +168,7 @@ public: ~ACThirdPartyAppIdContextSwap() override; const char* stringify() override { return "THIRD-PARTY_CONTEXT_SWAP"; } private: - const AppIdInspector& inspector; + AppIdInspector& inspector; }; bool ACThirdPartyAppIdContextSwap::execute(Analyzer&, void**) @@ -195,13 +195,13 @@ class ACThirdPartyAppIdContextUnload : public AnalyzerCommand { public: bool execute(Analyzer&, void**) override; - ACThirdPartyAppIdContextUnload(const AppIdInspector& inspector, ThirdPartyAppIdContext* tp_ctxt, + ACThirdPartyAppIdContextUnload(AppIdInspector& inspector, ThirdPartyAppIdContext* tp_ctxt, ControlConn* conn): AnalyzerCommand(conn), inspector(inspector), tp_ctxt(tp_ctxt) { } ~ACThirdPartyAppIdContextUnload() override; const char* stringify() override { return "THIRD-PARTY_CONTEXT_UNLOAD"; } private: - const AppIdInspector& inspector; + AppIdInspector& inspector; ThirdPartyAppIdContext* tp_ctxt = nullptr; }; @@ -235,34 +235,34 @@ class ACOdpContextSwap : public AnalyzerCommand { public: bool execute(Analyzer&, void**) override; - ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* conn) : + ACOdpContextSwap(AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* conn) : AnalyzerCommand(conn), inspector(inspector), odp_ctxt(odp_ctxt) { } ~ACOdpContextSwap() override; const char* stringify() override { return "ODP_CONTEXT_SWAP"; } private: - const AppIdInspector& inspector; + AppIdInspector& inspector; OdpContext& odp_ctxt; }; bool ACOdpContextSwap::execute(Analyzer&, void**) { - AppIdContext& ctxt = inspector.get_ctxt(); - OdpContext& current_odp_ctxt = ctxt.get_odp_ctxt(); - assert(pkt_thread_odp_ctxt != ¤t_odp_ctxt); - HostAttributesManager::clear_appid_services(); AppIdServiceState::clean(); AppIdPegCounts::cleanup_pegs(); - AppIdServiceState::initialize(ctxt.config.memcap); + const AppIdConfig& config = inspector.get_config(); + AppIdServiceState::initialize(config.memcap); AppIdPegCounts::init_pegs(); ServiceDiscovery::set_thread_local_ftp_service(); + AppIdContext& ctxt = inspector.get_ctxt(); + OdpContext& current_odp_ctxt = ctxt.get_odp_ctxt(); + assert(pkt_thread_odp_ctxt != ¤t_odp_ctxt); pkt_thread_odp_ctxt = ¤t_odp_ctxt; assert(odp_thread_local_ctxt); delete odp_thread_local_ctxt; - odp_thread_local_ctxt = new OdpThreadContext; - odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), ctxt, false, true); + odp_thread_local_ctxt = new OdpPacketThreadContext; + odp_thread_local_ctxt->initialize(SnortConfig::get_conf()); return true; } @@ -273,7 +273,7 @@ ACOdpContextSwap::~ACOdpContextSwap() delete &odp_ctxt; AppIdContext& ctxt = inspector.get_ctxt(); - LuaDetectorManager::cleanup_after_swap(); + ControlLuaDetectorManager::cleanup_after_swap(); if (ctxt.config.app_detector_dir) { std::string file_path = std::string(ctxt.config.app_detector_dir) + "/custom/userappid.conf"; @@ -492,19 +492,19 @@ static int reload_detectors(lua_State* L) clear_dynamic_host_cache_services(); AppIdPegCounts::cleanup_peg_info(); AppIdPegCounts::init_peg_info(); - LuaDetectorManager::clear_lua_detector_mgrs(); + ControlLuaDetectorManager::clear_lua_detector_mgrs(); ctxt.create_odp_ctxt(); - assert(odp_thread_local_ctxt); - odp_thread_local_ctxt->get_lua_detector_mgr().set_ignore_chp_cleanup(true); - delete odp_thread_local_ctxt; - odp_thread_local_ctxt = new OdpThreadContext; + assert(odp_control_thread_ctxt); + odp_control_thread_ctxt->set_ignore_chp_cleanup(); + delete odp_control_thread_ctxt; + odp_control_thread_ctxt = new OdpControlContext; OdpContext& odp_ctxt = ctxt.get_odp_ctxt(); odp_ctxt.get_client_disco_mgr().initialize(*inspector); odp_ctxt.get_service_disco_mgr().initialize(*inspector); odp_ctxt.set_client_and_service_detectors(); - odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), ctxt, true, true); + odp_control_thread_ctxt->initialize(SnortConfig::get_conf(), ctxt); odp_ctxt.initialize(*inspector); ctrlcon->respond("== swapping detectors configuration\n"); @@ -575,7 +575,14 @@ static const PegInfo appid_pegs[] = }; AppIdModule::AppIdModule() : Module(MOD_NAME, MOD_HELP, s_params) -{ config = nullptr; } +{ +} + +AppIdModule::~AppIdModule() +{ + AppIdPegCounts::cleanup_peg_info(); + delete config; +} void AppIdModule::set_trace(const Trace* trace) const { appid_trace = trace; } @@ -607,7 +614,7 @@ snort::ProfileStats* AppIdModule::get_profile( return nullptr; } -const AppIdConfig* AppIdModule::get_data() +AppIdConfig* AppIdModule::get_data() { AppIdConfig* temp = config; config = nullptr; diff --git a/src/network_inspectors/appid/appid_module.h b/src/network_inspectors/appid/appid_module.h index 1e08d5563..186985538 100644 --- a/src/network_inspectors/appid/appid_module.h +++ b/src/network_inspectors/appid/appid_module.h @@ -77,7 +77,7 @@ class AppIdModule : public snort::Module { public: AppIdModule(); - ~AppIdModule() override = default; + ~AppIdModule() override; bool begin(const char*, int, snort::SnortConfig*) override; bool set(const char*, snort::Value&, snort::SnortConfig*) override; @@ -89,7 +89,7 @@ public: snort::ProfileStats* get_profile( unsigned i, const char*& name, const char*& parent) const override; - const AppIdConfig* get_data(); + AppIdConfig* get_data(); void reset_stats() override; @@ -102,7 +102,7 @@ public: const snort::TraceOption* get_trace_options() const override; private: - AppIdConfig* config; + AppIdConfig* config = nullptr; }; class ACThirdPartyAppIdCleanup : public snort::AnalyzerCommand diff --git a/src/network_inspectors/appid/client_plugins/client_discovery.cc b/src/network_inspectors/appid/client_plugins/client_discovery.cc index 2d1f64ec9..febbde921 100644 --- a/src/network_inspectors/appid/client_plugins/client_discovery.cc +++ b/src/network_inspectors/appid/client_plugins/client_discovery.cc @@ -75,9 +75,9 @@ void ClientDiscovery::initialize(AppIdInspector& inspector) void ClientDiscovery::reload() { for ( auto& kv : tcp_detectors ) - kv.second->reload(); + kv.second->do_custom_reload(); for ( auto& kv : udp_detectors ) - kv.second->reload(); + kv.second->do_custom_reload(); } void ClientDiscovery::finalize_client_patterns() diff --git a/src/network_inspectors/appid/client_plugins/test/eve_ca_patterns_test.cc b/src/network_inspectors/appid/client_plugins/test/eve_ca_patterns_test.cc index deed4a0b9..f1bb8fe60 100644 --- a/src/network_inspectors/appid/client_plugins/test/eve_ca_patterns_test.cc +++ b/src/network_inspectors/appid/client_plugins/test/eve_ca_patterns_test.cc @@ -48,8 +48,6 @@ Inspector* InspectorManager::get_inspector(char const*, bool, const snort::Snort return nullptr; } -AppIdContext* ctxt; -AppIdContext& AppIdInspector::get_ctxt() const { return *ctxt; } void appid_log(const snort::Packet*, unsigned char, char const*, ...) { } TEST_GROUP(eve_ca_patterns_tests) diff --git a/src/network_inspectors/appid/detector_plugins/test/detector_plugins_mock.h b/src/network_inspectors/appid/detector_plugins/test/detector_plugins_mock.h index da39e9882..b765143eb 100644 --- a/src/network_inspectors/appid/detector_plugins/test/detector_plugins_mock.h +++ b/src/network_inspectors/appid/detector_plugins/test/detector_plugins_mock.h @@ -104,6 +104,7 @@ private: AppIdConfig::~AppIdConfig() = default; AppIdModule::AppIdModule() : Module("a", "b") { } +AppIdModule::~AppIdModule() = default; // LCOV_EXCL_START bool AppIdModule::begin(const char*, int, snort::SnortConfig*) diff --git a/src/network_inspectors/appid/detector_plugins/test/detector_sip_test.cc b/src/network_inspectors/appid/detector_plugins/test/detector_sip_test.cc index 398625aad..f916e9922 100644 --- a/src/network_inspectors/appid/detector_plugins/test/detector_sip_test.cc +++ b/src/network_inspectors/appid/detector_plugins/test/detector_sip_test.cc @@ -43,8 +43,8 @@ #include #include -static AppIdConfig config; -static AppIdContext context(config); +static AppIdConfig s_config; +static AppIdContext context(s_config); OdpContext* AppIdContext::odp_ctxt = nullptr; static AppIdModule appid_mod; static AppIdInspector appid_inspector(appid_mod); @@ -83,11 +83,11 @@ unsigned get_instance_id() unsigned ThreadConfig::get_instance_max() { return 1; } } -AppIdInspector::AppIdInspector(AppIdModule&) { } +AppIdInspector::AppIdInspector(AppIdModule&) : config(&s_config), ctxt(s_config) +{ } bool AppIdInspector::configure(snort::SnortConfig*) { - ctxt = &context; return true; } @@ -97,14 +97,13 @@ void AppIdInspector::show(const SnortConfig*) const { } void AppIdInspector::tinit() { } void AppIdInspector::tterm() { } void AppIdInspector::tear_down(SnortConfig*) { } -AppIdContext& AppIdInspector::get_ctxt() const { return *ctxt; } // LCOV_EXCL_STOP AppIdInspector::~AppIdInspector() = default; void AppIdContext::create_odp_ctxt() { - odp_ctxt = new OdpContext(config, nullptr); + odp_ctxt = new OdpContext(s_config, nullptr); } void AppIdContext::pterm() { delete odp_ctxt; } @@ -170,7 +169,6 @@ ClientDetector::ClientDetector() { } // LCOV_EXCL_START void ClientDetector::register_appid(int, unsigned int, OdpContext&) { } int AppIdDetector::initialize(AppIdInspector&) { return 1; } -void AppIdDetector::reload() { } int AppIdDetector::data_add(AppIdSession&, void*, void (*)(void*)) { return 1; } void AppIdDetector::add_user(AppIdSession&, char const*, int, bool, AppidChangeBits&) { } void AppIdDetector::add_payload(AppIdSession&, int) { } @@ -186,7 +184,7 @@ bool SipEvent::is_dialog_established() const { return false; } int SipPatternMatchers::get_client_from_ua(char const*, unsigned int, int&, char*&) { return 0; } // LCOV_EXCL_LINE void SipEventHandler::service_handler(SipEvent&, AppIdSession&, AppidChangeBits&) { } -void* AppIdDetector::data_get(AppIdSession&) +void* AppIdDetector::data_get(const AppIdSession&) { sip_data = new ClientSIPData(); sip_data->from = ""; diff --git a/src/network_inspectors/appid/lua_detector_api.cc b/src/network_inspectors/appid/lua_detector_api.cc index dadad937c..6be05c9dc 100644 --- a/src/network_inspectors/appid/lua_detector_api.cc +++ b/src/network_inspectors/appid/lua_detector_api.cc @@ -77,27 +77,25 @@ static CHPGlossary* old_CHP_glossary = nullptr; void init_chp_glossary() { - if(CHP_glossary) - old_CHP_glossary = CHP_glossary; + assert(!old_CHP_glossary); + old_CHP_glossary = CHP_glossary; CHP_glossary = new CHPGlossary; } static void free_chp_glossary(CHPGlossary*& glossary) { - if (!glossary) - return; - - for (auto& entry : *glossary) + if (glossary) { - if (entry.second) - snort_free(entry.second); + for (auto& entry : *glossary) + delete entry.second; + delete glossary; + glossary = nullptr; } - delete glossary; - glossary = nullptr; } -void free_current_chp_glossary(){ +void free_current_chp_glossary() +{ free_chp_glossary(CHP_glossary); } @@ -1254,11 +1252,11 @@ static int detector_get_flow(lua_State* L) // Verify detector user data and that we are in packet context LuaStateDescriptor* lsd = ud->validate_lua_state(true); - auto df = odp_thread_local_ctxt->get_lua_detector_mgr().get_detector_flow(); + auto df = odp_thread_local_ctxt->get_detector_flow(); if (!df) { df = new DetectorFlow(L, lsd->ldp.asd); - odp_thread_local_ctxt->get_lua_detector_mgr().set_detector_flow(df); + odp_thread_local_ctxt->set_detector_flow(df); } UserData::push(L, DETECTORFLOW, df); lua_pushvalue(L, -1); @@ -1769,7 +1767,7 @@ static int register_callback(lua_State* L, LuaObject& ud, AppInfoFlags flag) // Note that Lua detector objects are thread local ud.set_cb_fn_name(callback); - if (!odp_thread_local_ctxt->get_lua_detector_mgr().insert_cb_detector(app_id, &ud)) + if (!odp_thread_local_ctxt->insert_cb_detector(app_id, &ud)) { appid_log(nullptr, TRACE_ERROR_LEVEL, "AppId: detector callback already registered for app %d\n", app_id); return 1; @@ -1803,8 +1801,7 @@ static int detector_callback(const uint8_t* data, uint16_t size, AppidSessionDir return -10; } - LuaDetectorManager& lua_detector_mgr = odp_thread_local_ctxt->get_lua_detector_mgr(); - auto my_lua_state = lua_detector_mgr.L; + auto my_lua_state = odp_thread_local_ctxt->get_lua_state(); // when an ODP detector triggers the detector callback to be called, there are some elements // in the stack. Checking here to make sure the number of elements is not too many if (lua_gettop(my_lua_state) > 20) @@ -1834,8 +1831,7 @@ static int detector_callback(const uint8_t* data, uint16_t size, AppidSessionDir } // detector flows must be destroyed after each packet is processed - if (lua_detector_mgr.get_detector_flow()) - lua_detector_mgr.free_detector_flow(); + odp_thread_local_ctxt->free_detector_flow(); // retrieve result if (!lua_isnumber(my_lua_state, -1)) @@ -1865,7 +1861,7 @@ void check_detector_callback(const Packet& p, AppIdSession& asd, AppidSessionDir if (entry->flags & APPINFO_FLAG_CLIENT_DETECTOR_CALLBACK or entry->flags & APPINFO_FLAG_SERVICE_DETECTOR_CALLBACK) { - LuaObject* ud = odp_thread_local_ctxt->get_lua_detector_mgr().get_cb_detector(app_id); + LuaObject* ud = odp_thread_local_ctxt->get_cb_detector(app_id); assert(ud); if (ud->is_running()) @@ -1882,7 +1878,7 @@ void check_detector_callback(const Packet& p, AppIdSession& asd, AppidSessionDir static int create_chp_application(AppId appIdInstance, unsigned app_type_flags, int num_matches) { - CHPApp* new_app = (CHPApp*)snort_calloc(sizeof(CHPApp)); + CHPApp* new_app = new CHPApp(); new_app->appIdInstance = appIdInstance; new_app->app_type_flags = app_type_flags; new_app->num_matches = num_matches; @@ -1891,7 +1887,7 @@ static int create_chp_application(AppId appIdInstance, unsigned app_type_flags, { appid_log(nullptr, TRACE_ERROR_LEVEL, "LuaDetectorApi:Failed to add CHP for appId %d, instance %d", CHP_APPIDINSTANCE_TO_ID(appIdInstance), CHP_APPIDINSTANCE_TO_INSTANCE(appIdInstance)); - snort_free(new_app); + delete new_app; return -1; } return 0; @@ -3514,8 +3510,7 @@ int register_detector(lua_State* L) int LuaStateDescriptor::lua_validate(AppIdDiscoveryArgs& args) { - LuaDetectorManager& lua_detector_mgr = odp_thread_local_ctxt->get_lua_detector_mgr(); - auto my_lua_state = lua_detector_mgr.L; + auto my_lua_state = odp_thread_local_ctxt->get_lua_state(); if (!my_lua_state) { appid_log(args.pkt, TRACE_ERROR_LEVEL, "lua detector %s: no LUA state\n", package_info.name.c_str()); @@ -3550,14 +3545,13 @@ int LuaStateDescriptor::lua_validate(AppIdDiscoveryArgs& args) appid_log(args.pkt, TRACE_ERROR_LEVEL, "lua detector %s: error validating %s\n", package_info.name.c_str(), lua_tostring(my_lua_state, -1)); ldp.pkt = nullptr; - lua_detector_mgr.free_detector_flow(); + odp_thread_local_ctxt->free_detector_flow(); lua_settop(my_lua_state, 0); return APPID_ENULL; } /**detectorFlows must be destroyed after each packet is processed.*/ - if (lua_detector_mgr.get_detector_flow()) - lua_detector_mgr.free_detector_flow(); + odp_thread_local_ctxt->free_detector_flow(); /* retrieve result */ if (!lua_isnumber(my_lua_state, -1)) @@ -3654,7 +3648,7 @@ LuaServiceObject::LuaServiceObject(AppIdDiscovery* sdm, const std::string& detec int LuaServiceDetector::validate(AppIdDiscoveryArgs& args) { - auto my_lua_state = odp_thread_local_ctxt->get_lua_detector_mgr().L; + auto my_lua_state = odp_thread_local_ctxt->get_lua_state(); if (lua_gettop(my_lua_state)) appid_log(args.pkt, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before service validate\n", lua_gettop(my_lua_state)); @@ -3730,7 +3724,7 @@ LuaStateDescriptor* LuaObject::validate_lua_state(bool packet_context) int LuaClientDetector::validate(AppIdDiscoveryArgs& args) { - auto my_lua_state = odp_thread_local_ctxt->get_lua_detector_mgr().L; + auto my_lua_state = odp_thread_local_ctxt->get_lua_state(); if (lua_gettop(my_lua_state)) appid_log(args.pkt, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before client validate\n", lua_gettop(my_lua_state)); diff --git a/src/network_inspectors/appid/lua_detector_module.cc b/src/network_inspectors/appid/lua_detector_module.cc index 439e4d465..bae975858 100644 --- a/src/network_inspectors/appid/lua_detector_module.cc +++ b/src/network_inspectors/appid/lua_detector_module.cc @@ -54,8 +54,7 @@ using namespace std; #define OPEN_DETECTOR_PACKAGE_VERSION_FILE "version.conf" #define OPEN_DETECTOR_PACKAGE_VERSION "VERSION=" -static vector lua_detector_mgr_list; -static unordered_set lua_detectors_w_validate; +vector> ControlLuaDetectorManager::lua_detector_mgr_list; bool get_lua_field(lua_State* L, int table, const char* field, string& out) { @@ -179,26 +178,27 @@ static void scan_and_print_odp_version(const char* app_detector_dir) version_file.close(); } -LuaDetectorManager::LuaDetectorManager(AppIdContext& ctxt, bool is_control) : - ctxt(ctxt) +LuaDetectorManager::LuaDetectorManager(AppIdContext& ctxt, bool is_control) : ctxt(ctxt) { - allocated_objects.clear(); - cb_detectors.clear(); L = create_lua_state(ctxt.config, is_control); - if (is_control) - init_chp_glossary(); + if (!L) + { + if (is_control) + appid_log(nullptr, TRACE_CRITICAL_LEVEL, + "Error - appid: can not create new luaState, control instance\n"); + else + appid_log(nullptr, TRACE_ERROR_LEVEL, + "Error - appid: can not create new luaState, instance=%u\n", get_instance_id()); + } } LuaDetectorManager::~LuaDetectorManager() { - if (lua_gettop(L)) - appid_log(nullptr, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before detector unload\n", - lua_gettop(L)); - if (L) { - if (init(L) and !ignore_chp_cleanup) - free_current_chp_glossary(); + if (lua_gettop(L)) + appid_log(nullptr, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before detector unload\n", + lua_gettop(L)); for ( auto& lua_object : allocated_objects ) { @@ -223,70 +223,18 @@ LuaDetectorManager::~LuaDetectorManager() lua_close(L); } - if (detector_flow) - free_detector_flow(); - allocated_objects.clear(); - cb_detectors.clear(); // do not free Lua objects in cb_detectors } -void LuaDetectorManager::initialize(const SnortConfig* sc, AppIdContext& ctxt, bool is_control, - bool reload) +void LuaDetectorManager::initialize(const SnortConfig* sc) { - LuaDetectorManager* lua_detector_mgr = new LuaDetectorManager(ctxt, is_control); - odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr); - - if (!lua_detector_mgr->L) - appid_log(nullptr, is_control? TRACE_CRITICAL_LEVEL : TRACE_ERROR_LEVEL, - "Error - appid: can not create new luaState, instance=%u\n", get_instance_id()); - - if (reload) - { - appid_log(nullptr, TRACE_INFO_LEVEL, "AppId Lua-Detectors : loading lua detectors in control thread\n"); - unsigned max_threads = ThreadConfig::get_instance_max(); - for (unsigned i = 0 ; i < max_threads; i++) - { - lua_detector_mgr_list.emplace_back(new LuaDetectorManager(ctxt, 0)); - - if (!lua_detector_mgr_list[i]->L) - appid_log(nullptr, TRACE_CRITICAL_LEVEL, "Error - appid: can not create new luaState, instance=%u\n", i); - - } - } - - lua_detector_mgr->initialize_lua_detectors(is_control, reload); - lua_detector_mgr->activate_lua_detectors(sc); - + activate_lua_detectors(sc); + if (SnortConfig::log_verbose()) scan_and_print_odp_version(ctxt.config.app_detector_dir); if (ctxt.config.list_odp_detectors or SnortConfig::log_verbose()) - lua_detector_mgr->list_lua_detectors(); -} - -void LuaDetectorManager::init_thread_manager(const SnortConfig* sc, const AppIdContext& ctxt) -{ - LuaDetectorManager* lua_detector_mgr = lua_detector_mgr_list[get_instance_id()]; - odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr); - lua_detector_mgr->activate_lua_detectors(sc); - if (ctxt.config.list_odp_detectors) - lua_detector_mgr->list_lua_detectors(); -} - -void LuaDetectorManager::cleanup_after_swap() -{ - free_old_chp_glossary(); -} - -void LuaDetectorManager::clear_lua_detector_mgrs() -{ - lua_detector_mgr_list.clear(); -} - -void LuaDetectorManager::free_detector_flow() -{ - delete detector_flow; - detector_flow = nullptr; + list_lua_detectors(); } bool LuaDetectorManager::insert_cb_detector(AppId app_id, LuaObject* cb_detector) @@ -447,9 +395,9 @@ static int dump(lua_State*, const void* buf,size_t size, void* data) return 0; } -bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, bool is_control, bool reload, string& buf) +bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, string& buf) { - if (reload and !buf.empty()) + if (!buf.empty()) { if (luaL_loadbuffer(L, buf.c_str(), buf.length(), detector_filename)) { @@ -461,13 +409,6 @@ bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, } else { - if (!is_control) - { - auto iter = lua_detectors_w_validate.find(detector_filename); - if (iter == lua_detectors_w_validate.end()) - return false; - } - if (luaL_loadfile(L, detector_filename)) { if (init(L)) @@ -475,7 +416,7 @@ bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, lua_pop(L, 1); return false; } - if (reload and lua_dump(L, dump, &buf)) + if (lua_dump(L, dump, &buf)) { if (init(L)) appid_log(nullptr, TRACE_ERROR_LEVEL, "Error - appid: can not compile Lua detector, %s\n", lua_tostring(L, -1)); @@ -521,7 +462,62 @@ bool LuaDetectorManager::load_detector(char* detector_filename, bool is_custom, return has_validate; } -void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bool is_control, bool reload) +void LuaDetectorManager::activate_lua_detectors(const SnortConfig* sc) +{ + if (lua_gettop(L)) + appid_log(nullptr, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before detector activate\n", + lua_gettop(L)); + + uint32_t lua_tracker_size = compute_lua_tracker_size(MAX_MEMORY_FOR_LUA_DETECTORS, allocated_objects.size()); + list::iterator lo = allocated_objects.begin(); + while (lo != allocated_objects.end()) + { + LuaStateDescriptor* lsd = (*lo)->validate_lua_state(false); + lua_getfield(L, LUA_REGISTRYINDEX, lsd->package_info.name.c_str()); + lua_getfield(L, -1, lsd->package_info.initFunctionName.c_str()); + if (!lua_isfunction(L, -1)) + { + if (init(L)) + appid_log(nullptr, TRACE_ERROR_LEVEL, "Error - appid: can not load DetectorInit function from %s\n", + (*lo)->get_detector()->get_name().c_str()); + if (!(*lo)->get_detector()->is_custom_detector()) + num_odp_detectors--; + lua_settop(L, 0); + delete *lo; + lo = allocated_objects.erase(lo); + continue; + } + + /*first parameter is DetectorUserData */ + string name = lsd->package_info.name + "_"; + lua_getglobal(L, name.c_str()); + + /*second parameter is a table containing configuration stuff. */ + lua_newtable(L); + const SnortConfig** sc_ud = static_cast(lua_newuserdata(L, sizeof(const SnortConfig*))); + *(sc_ud) = sc; + lua_setglobal(L, LUA_STATE_GLOBAL_SC_ID); + if (lua_pcall(L, 2, 1, 0)) + { + if (init(L)) + appid_log(nullptr, TRACE_ERROR_LEVEL, "Error - appid: can not run DetectorInit, %s\n", lua_tostring(L, -1)); + if (!(*lo)->get_detector()->is_custom_detector()) + num_odp_detectors--; + lua_settop(L, 0); + delete *lo; + lo = allocated_objects.erase(lo); + continue; + } + *(sc_ud) = nullptr; + + lua_getfield(L, LUA_REGISTRYINDEX, lsd->package_info.name.c_str()); + set_lua_tracker_size(L, lua_tracker_size); + lua_settop(L, 0); + ++lo; + } +} + +void ControlLuaDetectorManager::load_lua_detectors(const char* path, bool is_custom) { char pattern[PATH_MAX]; snprintf(pattern, sizeof(pattern), "%s/*", path); @@ -535,7 +531,6 @@ void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bo appid_log(nullptr, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before detector load\n", lua_gettop(L)); - string buf; for (unsigned n = 0; n < globs.gl_pathc; n++) { ifstream file(globs.gl_pathv[n], ios::ate); @@ -572,19 +567,14 @@ void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bo // During reload, load_lua_detectors() gets called only for control thread. This // function loads detectors for all the packet threads too during reload. It skips // loading detectors that don't have validate for packet threads. - bool has_validate = load_detector(globs.gl_pathv[n], is_custom, is_control, reload, buf); + string buf; + bool has_validate = load_detector(globs.gl_pathv[n], is_custom, buf); - if (reload) + for (auto& lua_detector_mgr : lua_detector_mgr_list) { - for (auto& lua_detector_mgr : lua_detector_mgr_list) - { - if (has_validate) - lua_detector_mgr->load_detector(globs.gl_pathv[n], is_custom, is_control, reload, buf); - } - buf.clear(); + if (has_validate) + lua_detector_mgr->load_detector(globs.gl_pathv[n], is_custom, buf); } - else if (is_control and has_validate) - lua_detectors_w_validate.insert(globs.gl_pathv[n]); lua_settop(L, 0); } @@ -598,7 +588,7 @@ void LuaDetectorManager::load_lua_detectors(const char* path, bool is_custom, bo pattern, rval); } -void LuaDetectorManager::initialize_lua_detectors(bool is_control, bool reload) +void ControlLuaDetectorManager::initialize_lua_detectors() { char path[PATH_MAX]; const char* dir = ctxt.config.app_detector_dir; @@ -607,76 +597,67 @@ void LuaDetectorManager::initialize_lua_detectors(bool is_control, bool reload) return; snprintf(path, sizeof(path), "%s/odp/lua", dir); - load_lua_detectors(path, false, is_control, reload); - num_odp_detectors = allocated_objects.size(); + load_lua_detectors(path, false); + set_num_odp_detectors(); + for (auto& mgr : lua_detector_mgr_list) + mgr->set_num_odp_detectors(); - if (reload) - { - for (auto& lua_detector_mgr : lua_detector_mgr_list) - lua_detector_mgr->num_odp_detectors = lua_detector_mgr->allocated_objects.size(); - } snprintf(path, sizeof(path), "%s/custom/lua", dir); - load_lua_detectors(path, true, is_control, reload); + load_lua_detectors(path, true); } -void LuaDetectorManager::activate_lua_detectors(const SnortConfig* sc) +ControlLuaDetectorManager::ControlLuaDetectorManager(AppIdContext& appid_ctxt) : LuaDetectorManager(appid_ctxt, true) +{ init_chp_glossary(); } + +ControlLuaDetectorManager::~ControlLuaDetectorManager() { - uint32_t lua_tracker_size = compute_lua_tracker_size(MAX_MEMORY_FOR_LUA_DETECTORS, - allocated_objects.size()); - list::iterator lo = allocated_objects.begin(); + clear_lua_detector_mgrs(); + if (!ignore_chp_cleanup) + free_current_chp_glossary(); +} - if (lua_gettop(L)) - appid_log(nullptr, TRACE_WARNING_LEVEL, "appid: leak of %d lua stack elements before detector activate\n", - lua_gettop(L)); +void ControlLuaDetectorManager::initialize(const SnortConfig* sc) +{ + unsigned max_threads = ThreadConfig::get_instance_max(); + for (unsigned i = 0 ; i < max_threads; i++) + lua_detector_mgr_list.emplace_back(make_shared(ctxt)); - while (lo != allocated_objects.end()) - { - LuaStateDescriptor* lsd = (*lo)->validate_lua_state(false); - lua_getfield(L, LUA_REGISTRYINDEX, lsd->package_info.name.c_str()); - lua_getfield(L, -1, lsd->package_info.initFunctionName.c_str()); - if (!lua_isfunction(L, -1)) - { - if (init(L)) - appid_log(nullptr, TRACE_ERROR_LEVEL, "Error - appid: can not load DetectorInit function from %s\n", - (*lo)->get_detector()->get_name().c_str()); - if (!(*lo)->get_detector()->is_custom_detector()) - num_odp_detectors--; - lua_settop(L, 0); - delete *lo; - lo = allocated_objects.erase(lo); - continue; - } + initialize_lua_detectors(); + LuaDetectorManager::initialize(sc); +} - /*first parameter is DetectorUserData */ - string name = lsd->package_info.name + "_"; - lua_getglobal(L, name.c_str()); +void ControlLuaDetectorManager::list_lua_detectors() +{ - /*second parameter is a table containing configuration stuff. */ - lua_newtable(L); - const SnortConfig** sc_ud = static_cast(lua_newuserdata(L, sizeof(const SnortConfig*))); - *(sc_ud) = sc; - lua_setglobal(L, LUA_STATE_GLOBAL_SC_ID); - if (lua_pcall(L, 2, 1, 0)) - { - if (init(L)) - appid_log(nullptr, TRACE_ERROR_LEVEL, "Error - appid: can not run DetectorInit, %s\n", lua_tostring(L, -1)); - if (!(*lo)->get_detector()->is_custom_detector()) - num_odp_detectors--; - lua_settop(L, 0); - delete *lo; - lo = allocated_objects.erase(lo); - continue; - } - *(sc_ud) = nullptr; + #ifdef REG_TEST + // Lua memory usage is inconsistent, for ease of testing lets print 0 instead. + int memory_used_by_lua = 0; + #else + int memory_used_by_lua = lua_gc(L, LUA_GCCOUNT, 0); + #endif - lua_getfield(L, LUA_REGISTRYINDEX, lsd->package_info.name.c_str()); - set_lua_tracker_size(L, lua_tracker_size); - lua_settop(L, 0); - ++lo; - } + appid_log(nullptr, TRACE_INFO_LEVEL, "AppId Lua-Detector Stats: control instance, odp detectors %zu, custom detectors %zu," + " total memory %d kb\n", num_odp_detectors, (allocated_objects.size() - num_odp_detectors), memory_used_by_lua); } -void LuaDetectorManager::list_lua_detectors() +void ControlLuaDetectorManager::cleanup_after_swap() +{ + free_old_chp_glossary(); +} + +void ControlLuaDetectorManager::clear_lua_detector_mgrs() +{ + lua_detector_mgr_list.clear(); +} + +std::shared_ptr ControlLuaDetectorManager::get_packet_lua_detector_manager() +{ + unsigned instance_id = get_instance_id(); + std::shared_ptr mgr = lua_detector_mgr_list[instance_id]; + return static_cast>(mgr); +} + +void PacketLuaDetectorManager::list_lua_detectors() { #ifdef REG_TEST @@ -691,3 +672,9 @@ void LuaDetectorManager::list_lua_detectors() (allocated_objects.size() - num_odp_detectors), memory_used_by_lua); } +void PacketLuaDetectorManager::free_detector_flow() +{ + delete detector_flow; + detector_flow = nullptr; +} + diff --git a/src/network_inspectors/appid/lua_detector_module.h b/src/network_inspectors/appid/lua_detector_module.h index 85582240f..2ce5118cd 100644 --- a/src/network_inspectors/appid/lua_detector_module.h +++ b/src/network_inspectors/appid/lua_detector_module.h @@ -25,7 +25,9 @@ #include #include #include +#include #include +#include #include #include @@ -52,48 +54,75 @@ bool get_lua_field(lua_State* L, int table, const char* field, IpProtocol& out); class LuaDetectorManager { public: - LuaDetectorManager(AppIdContext&, bool); - ~LuaDetectorManager(); - static void initialize(const snort::SnortConfig*, AppIdContext&, bool is_control=false, - bool reload=false); - static void init_thread_manager(const snort::SnortConfig*, const AppIdContext&); - static void cleanup_after_swap(); - static void clear_lua_detector_mgrs(); - - void set_detector_flow(DetectorFlow* df) - { - detector_flow = df; - } - - DetectorFlow* get_detector_flow() - { - return detector_flow; - } - - void set_ignore_chp_cleanup(bool value) - { - ignore_chp_cleanup = value; - } + LuaDetectorManager(AppIdContext&, bool is_control); + virtual ~LuaDetectorManager(); + virtual void initialize(const snort::SnortConfig*); - void free_detector_flow(); - lua_State* L; + bool load_detector(char* detector_name, bool is_custom, std::string& buf); + void set_num_odp_detectors() + { num_odp_detectors = allocated_objects.size(); } bool insert_cb_detector(AppId app_id, LuaObject* ud); LuaObject* get_cb_detector(AppId app_id); -private: - void initialize_lua_detectors(bool is_control, bool reload = false); + lua_State* L; + +protected: void activate_lua_detectors(const snort::SnortConfig*); - void list_lua_detectors(); - bool load_detector(char* detector_name, bool is_custom, bool is_control, bool reload, std::string& buf); - void load_lua_detectors(const char* path, bool is_custom, bool is_control, bool reload = false); LuaObject* create_lua_detector(const char* detector_name, bool is_custom, const char* detector_filename, bool& has_validate); + virtual void list_lua_detectors() = 0; AppIdContext& ctxt; std::list allocated_objects; size_t num_odp_detectors = 0; std::map cb_detectors; +}; + +class PacketLuaDetectorManager : public LuaDetectorManager +{ +public: + explicit PacketLuaDetectorManager(AppIdContext& appid_ctxt) : LuaDetectorManager(appid_ctxt, false) + { } + ~PacketLuaDetectorManager() override + { free_detector_flow(); } + + void set_detector_flow(DetectorFlow* df) + { detector_flow = df; } + + DetectorFlow* get_detector_flow() const + { return detector_flow; } + + void free_detector_flow(); + +private: + void list_lua_detectors() override; + DetectorFlow* detector_flow = nullptr; +}; + +class ControlLuaDetectorManager : public LuaDetectorManager +{ +public: + explicit ControlLuaDetectorManager(AppIdContext&); + ~ControlLuaDetectorManager() override; + void initialize(const snort::SnortConfig*) override; + + static std::shared_ptr get_packet_lua_detector_manager(); + static void clear_lua_detector_mgrs(); + static void cleanup_after_swap(); + + void set_ignore_chp_cleanup() + { + ignore_chp_cleanup = true; + } + +private: + static std::vector> lua_detector_mgr_list; + + void initialize_lua_detectors(); + void load_lua_detectors(const char* path, bool is_custom); + void list_lua_detectors() override; + bool ignore_chp_cleanup = false; }; diff --git a/src/network_inspectors/appid/service_plugins/service_discovery.cc b/src/network_inspectors/appid/service_plugins/service_discovery.cc index db9ec46b7..3f4390e79 100644 --- a/src/network_inspectors/appid/service_plugins/service_discovery.cc +++ b/src/network_inspectors/appid/service_plugins/service_discovery.cc @@ -141,9 +141,9 @@ void ServiceDiscovery::initialize(AppIdInspector& inspector) void ServiceDiscovery::reload() { for ( auto& kv : tcp_detectors ) - kv.second->reload(); + kv.second->do_custom_reload(); for ( auto& kv : udp_detectors ) - kv.second->reload(); + kv.second->do_custom_reload(); } void ServiceDiscovery::finalize_service_patterns() diff --git a/src/network_inspectors/appid/service_plugins/test/alpn_patterns_tests.cc b/src/network_inspectors/appid/service_plugins/test/alpn_patterns_tests.cc index ebde9bed1..b1a19e687 100644 --- a/src/network_inspectors/appid/service_plugins/test/alpn_patterns_tests.cc +++ b/src/network_inspectors/appid/service_plugins/test/alpn_patterns_tests.cc @@ -48,8 +48,6 @@ Inspector* InspectorManager::get_inspector(char const*, bool, const snort::Snort return nullptr; } -AppIdContext* ctxt; -AppIdContext& AppIdInspector::get_ctxt() const { return *ctxt; } void appid_log(const snort::Packet*, unsigned char, char const*, ...) { } TEST_GROUP(alpn_patterns_tests) diff --git a/src/network_inspectors/appid/service_plugins/test/service_plugin_mock.h b/src/network_inspectors/appid/service_plugins/test/service_plugin_mock.h index 2e1dc7197..a76a6a976 100644 --- a/src/network_inspectors/appid/service_plugins/test/service_plugin_mock.h +++ b/src/network_inspectors/appid/service_plugins/test/service_plugin_mock.h @@ -90,9 +90,8 @@ void ClientDiscovery::reload() {} FpSMBData* smb_data = nullptr; int AppIdDetector::initialize(AppIdInspector&){return 0;} -void AppIdDetector::reload() { } int AppIdDetector::data_add(AppIdSession&, void*, AppIdFreeFCN){return 0;} -void* AppIdDetector::data_get(AppIdSession&) {return nullptr;} +void* AppIdDetector::data_get(const AppIdSession&) {return nullptr;} void AppIdDetector::add_user(AppIdSession&, const char*, AppId, bool, AppidChangeBits&){} void AppIdDetector::add_payload(AppIdSession&, AppId){} void AppIdDetector::add_app(const snort::Packet&, AppIdSession&, AppidSessionDirection, AppId, AppId, const char*, AppidChangeBits&){} diff --git a/src/network_inspectors/appid/test/appid_api_test.cc b/src/network_inspectors/appid/test/appid_api_test.cc index a547d6239..31e8a4b1f 100644 --- a/src/network_inspectors/appid/test/appid_api_test.cc +++ b/src/network_inspectors/appid/test/appid_api_test.cc @@ -425,11 +425,6 @@ TEST(appid_api, is_service_http_type) CHECK_FALSE(appid_api.is_service_http_type(APP_ID_SMTP)); } -TEST(appid_api, get_appid_detector_directory) -{ - STRCMP_EQUAL(appid_api.get_appid_detector_directory(), "/path/to/appid/detectors/"); -} - int main(int argc, char** argv) { int rc = CommandLineTestRunner::RunAllTests(argc, argv); diff --git a/src/network_inspectors/appid/test/appid_discovery_test.cc b/src/network_inspectors/appid/test/appid_discovery_test.cc index 92ae8dd1d..f6f7fbeb1 100644 --- a/src/network_inspectors/appid/test/appid_discovery_test.cc +++ b/src/network_inspectors/appid/test/appid_discovery_test.cc @@ -165,6 +165,7 @@ void ClientAppDescriptor::update_user(AppId, const char*, AppidChangeBits&){} // Stubs for AppIdModule AppIdModule::AppIdModule(): Module("appid_mock", "appid_mock_help") {} +AppIdModule::~AppIdModule() = default; void AppIdModule::sum_stats(bool) {} void AppIdModule::show_dynamic_stats() {} bool AppIdModule::begin(char const*, int, SnortConfig*) { return true; } @@ -180,7 +181,6 @@ THREAD_LOCAL bool ThirdPartyAppIdContext::tp_reload_in_progress = false; // Stubs for config static AppIdConfig app_config; -static AppIdContext app_ctxt(app_config); AppId OdpContext::get_port_service_id(IpProtocol, uint16_t) { return APP_ID_NONE; @@ -195,7 +195,8 @@ AppId OdpContext::get_protocol_service_id(IpProtocol) } // Stubs for AppIdInspector -AppIdInspector::AppIdInspector(AppIdModule&) { ctxt = &stub_ctxt; } +AppIdInspector::AppIdInspector(AppIdModule&) : config(&app_config), ctxt(app_config) +{ } AppIdInspector::~AppIdInspector() = default; void AppIdInspector::eval(Packet*) { } bool AppIdInspector::configure(SnortConfig*) { return true; } @@ -203,11 +204,6 @@ void AppIdInspector::show(const SnortConfig*) const { } void AppIdInspector::tinit() { } void AppIdInspector::tterm() { } void AppIdInspector::tear_down(SnortConfig*) { } -AppIdContext& AppIdInspector::get_ctxt() const -{ - assert(ctxt); - return *ctxt; -} bool DiscoveryFilter::is_app_monitored(const snort::Packet*, uint8_t*){return true;} // Stubs for AppInfoManager @@ -399,6 +395,7 @@ TEST(appid_discovery_tests, event_published_when_ignoring_flow) p.ptrs.ip_api.set(ip, ip); AppIdModule app_module; AppIdInspector ins(app_module); + AppIdContext& app_ctxt = ins.get_ctxt(); AppIdSession* asd = new AppIdSession(IpProtocol::TCP, &ip, 21, ins, app_ctxt.get_odp_ctxt(), 0, 0); asd->flags |= APPID_SESSION_SPECIAL_MONITORED | APPID_SESSION_DISCOVER_USER | APPID_SESSION_DISCOVER_APP; @@ -434,6 +431,7 @@ TEST(appid_discovery_tests, event_published_when_processing_flow) p.ptrs.tcph = nullptr; AppIdModule app_module; AppIdInspector ins(app_module); + AppIdContext& app_ctxt = ins.get_ctxt(); AppIdSession* asd = new AppIdSession(IpProtocol::TCP, &ip, 21, ins, app_ctxt.get_odp_ctxt(), 0, 0); asd->flags |= APPID_SESSION_SPECIAL_MONITORED | APPID_SESSION_DISCOVER_USER | APPID_SESSION_DISCOVER_APP; @@ -459,6 +457,7 @@ TEST(appid_discovery_tests, change_bits_for_client_version) AppIdModule app_module; AppIdInspector ins(app_module); SfIp ip; + AppIdContext app_ctxt(app_config); AppIdSession* asd = new AppIdSession(IpProtocol::TCP, &ip, 21, ins, app_ctxt.get_odp_ctxt(), 0, 0); const char* version = "3.0"; asd->set_client_version(version, change_bits); @@ -494,6 +493,7 @@ TEST(appid_discovery_tests, change_bits_for_non_http_appid) p.ptrs.ip_api.set(ip, ip); AppIdModule app_module; AppIdInspector ins(app_module); + AppIdContext& app_ctxt = ins.get_ctxt(); AppIdSession* asd = new AppIdSession(IpProtocol::TCP, &ip, 21, ins, app_ctxt.get_odp_ctxt(), 0, 0); asd->flags |= APPID_SESSION_SPECIAL_MONITORED | APPID_SESSION_DISCOVER_USER | APPID_SESSION_DISCOVER_APP; diff --git a/src/network_inspectors/appid/test/appid_mock_inspector.h b/src/network_inspectors/appid/test/appid_mock_inspector.h index be225ef58..5c6dc3a52 100644 --- a/src/network_inspectors/appid/test/appid_mock_inspector.h +++ b/src/network_inspectors/appid/test/appid_mock_inspector.h @@ -48,7 +48,8 @@ PegCount Module::get_global_count(char const*) const { return 0; } } -AppIdModule::AppIdModule(): snort::Module("appid_mock", "appid_mock_help") {} +AppIdModule::AppIdModule(): snort::Module("appid_mock", "appid_mock_help") { } +AppIdModule::~AppIdModule() = default; void AppIdModule::sum_stats(bool) {} void AppIdModule::show_dynamic_stats() {} bool AppIdModule::begin(char const*, int, snort::SnortConfig*) { return true; } @@ -62,6 +63,9 @@ snort::ProfileStats* AppIdModule::get_profile( void AppIdModule::set_trace(const Trace*) const { } const TraceOption* AppIdModule::get_trace_options() const { return nullptr; } +AppIdConfig appid_config; +AppIdInspector::AppIdInspector(AppIdModule&) : config(&appid_config), ctxt(appid_config) +{ } AppIdInspector::~AppIdInspector() = default; void AppIdInspector::eval(snort::Packet*) { } bool AppIdInspector::configure(snort::SnortConfig*) { return true; } @@ -69,19 +73,9 @@ void AppIdInspector::show(const SnortConfig*) const { } void AppIdInspector::tinit() { } void AppIdInspector::tterm() { } void AppIdInspector::tear_down(snort::SnortConfig*) { } -AppIdContext& AppIdInspector::get_ctxt() const { return *ctxt; } AppIdModule appid_mod; -AppIdConfig appid_config; -AppIdContext appid_ctxt(appid_config); -THREAD_LOCAL OdpContext* pkt_thread_odp_ctxt = nullptr; AppIdInspector dummy_appid_inspector( appid_mod ); - -AppIdInspector::AppIdInspector(AppIdModule& ) -{ - ctxt = &appid_ctxt; - appid_config.app_detector_dir = "/path/to/appid/detectors/"; - config = &appid_config; -} +THREAD_LOCAL OdpContext* pkt_thread_odp_ctxt = nullptr; #endif diff --git a/src/network_inspectors/appid/test/appid_mock_session.h b/src/network_inspectors/appid/test/appid_mock_session.h index 61d6ec10b..f65bd5bc6 100644 --- a/src/network_inspectors/appid/test/appid_mock_session.h +++ b/src/network_inspectors/appid/test/appid_mock_session.h @@ -77,7 +77,6 @@ OdpContext::OdpContext(const AppIdConfig&, snort::SnortConfig*) { } void FlowHAState::add(uint8_t) { } static AppIdConfig stub_config; -static AppIdContext stub_ctxt(stub_config); static OdpContext stub_odp_ctxt(stub_config, nullptr); OdpContext* AppIdContext::odp_ctxt = &stub_odp_ctxt; AppIdSession::AppIdSession(IpProtocol proto, const SfIp* ip, uint16_t, AppIdInspector& inspector,