From: Ron Dempster (rdempste) Date: Fri, 15 Apr 2022 15:26:44 +0000 (+0000) Subject: Pull request #3371: Fix most of the perf drop from multi-tenant code X-Git-Tag: 3.1.28.0~15 X-Git-Url: http://git.ipfire.org/cgi-bin/gitweb.cgi?a=commitdiff_plain;h=36112e18b21b56299081cd3e3aaa211e7dfb76d1;p=thirdparty%2Fsnort3.git Pull request #3371: Fix most of the perf drop from multi-tenant code Merge in SNORT/snort3 from ~RDEMPSTE/snort3:perf to master Squashed commit of the following: commit c14d36a3e41f083d4a80199b22b40b601166419f Author: Ron Dempster (rdempste) Date: Mon Apr 11 09:58:36 2022 -0400 flow: only select policies when deleting flow data if there is a policy selector commit c38b0b61f1a9b8a7e359ff81a5468a59567a5260 Author: Ron Dempster (rdempste) Date: Sun Apr 10 16:26:12 2022 -0400 flow, snort_config: change service back to a pointer and add a method to return a non-volatile pointer for service commit a9b120ee80a12c64e59f475f56db4477ffc88c08 Author: Ron Dempster (rdempste) Date: Thu Apr 7 11:14:26 2022 -0400 flow: use a flag instead off shared pointer use count for has service check commit 429fa43a6346f6e67e2ddb98238e2fc1f340aaa3 Author: Ron Dempster (rdempste) Date: Fri Apr 1 12:32:23 2022 -0400 flow, managers, binder: only publish flow state reloaded event from internal execute commit 4f2429b5140895ea377a49029e387f5b509de5ca Author: Ron Dempster (rdempste) Date: Thu Mar 31 14:09:29 2022 -0400 main: check policy exists instead of index when setting network policy by id --- diff --git a/src/flow/flow.cc b/src/flow/flow.cc index 916a5b0d9..0b945ab83 100644 --- a/src/flow/flow.cc +++ b/src/flow/flow.cc @@ -120,7 +120,7 @@ void Flow::term() stash = nullptr; } - service.reset(); + service = nullptr; } inline void Flow::clean() @@ -214,7 +214,6 @@ void Flow::reset(bool do_cleanup) stash->reset(); deferred_trust.clear(); - service.reset(); constexpr size_t offset = offsetof(Flow, context_chain); // FIXIT-L need a struct to zero here to make future proof @@ -330,23 +329,32 @@ void Flow::free_flow_data(uint32_t proto) void Flow::free_flow_data() { - NetworkPolicy* np = get_network_policy(); - InspectionPolicy* ip = get_inspection_policy(); - IpsPolicy* ipsp = get_ips_policy(); - - unsigned t_reload_id = SnortConfig::get_thread_reload_id(); - if (reload_id == t_reload_id) + const SnortConfig* sc = SnortConfig::get_conf(); + PolicySelector* ps = sc->policy_map->get_policy_selector(); + NetworkPolicy* np = nullptr; + InspectionPolicy* ip = nullptr; + IpsPolicy* ipsp = nullptr; + if (ps) { - ::set_network_policy(network_policy_id); - ::set_inspection_policy(inspection_policy_id); - ::set_ips_policy(SnortConfig::get_conf(), ips_policy_id); - } - else - { - _daq_pkt_hdr pkthdr = {}; - pkthdr.address_space_id = key->addressSpaceId; - select_default_policy(pkthdr, SnortConfig::get_conf()); + np = get_network_policy(); + ip = get_inspection_policy(); + ipsp = get_ips_policy(); + + unsigned t_reload_id = SnortConfig::get_thread_reload_id(); + if (reload_id == t_reload_id) + { + ::set_network_policy(network_policy_id); + ::set_inspection_policy(inspection_policy_id); + ::set_ips_policy(sc, ips_policy_id); + } + else + { + _daq_pkt_hdr pkthdr = {}; + pkthdr.address_space_id = key->addressSpaceId; + select_default_policy(pkthdr, sc); + } } + while (flow_data) { FlowData* tmp = flow_data; @@ -354,9 +362,12 @@ void Flow::free_flow_data() delete tmp; } - set_network_policy(np); - set_inspection_policy(ip); - set_ips_policy(ipsp); + if (ps) + { + set_network_policy(np); + set_inspection_policy(ip); + set_ips_policy(ipsp); + } } void Flow::call_handlers(Packet* p, bool eof) @@ -571,21 +582,12 @@ bool Flow::is_direction_aborted(bool from_client) const return (session_flags & SSNFLAG_ABORT_CLIENT); } -void Flow::set_service(Packet* pkt, std::shared_ptr new_service) +void Flow::set_service(Packet* pkt, const char* new_service) { - if (!new_service.use_count()) - return clear_service(pkt); - service = new_service; DataBus::publish(FLOW_SERVICE_CHANGE_EVENT, pkt); } -void Flow::clear_service(Packet* pkt) -{ - service.reset(); - DataBus::publish(FLOW_SERVICE_CHANGE_EVENT, pkt); -} - void Flow::swap_roles() { std::swap(flowstats.client_pkts, flowstats.server_pkts); diff --git a/src/flow/flow.h b/src/flow/flow.h index 9b13f8d75..710ffe28e 100644 --- a/src/flow/flow.h +++ b/src/flow/flow.h @@ -27,7 +27,6 @@ // state. Inspector state is stored in FlowData, and Flow manages a list // of FlowData items. -#include #include #include @@ -200,10 +199,7 @@ public: void set_mpls_layer_per_dir(Packet*); Layer get_mpls_layer_per_dir(bool); void swap_roles(); - void set_service(Packet*, std::shared_ptr new_service); - void clear_service(Packet*); - bool has_service() const - { return 0 != service.use_count(); } + void set_service(Packet*, const char* new_service); bool get_attr(const std::string& key, int32_t& val); bool get_attr(const std::string& key, std::string& val); void set_attr(const std::string& key, const int32_t& val); @@ -406,7 +402,6 @@ public: // FIXIT-M privatize if possible // void space and allow for memset of tail end of struct DeferredTrust deferred_trust; - std::shared_ptr service; // Anything before this comment is not zeroed during construction const FlowKey* key; @@ -441,6 +436,7 @@ public: // FIXIT-M privatize if possible Inspector* gadget; // service handler Inspector* assistant_gadget; Inspector* data; + const char* service; uint64_t expire_time; diff --git a/src/flow/flow_control.cc b/src/flow/flow_control.cc index 6324ca993..ddbc45bbd 100644 --- a/src/flow/flow_control.cc +++ b/src/flow/flow_control.cc @@ -437,11 +437,7 @@ unsigned FlowControl::process(Flow* flow, Packet* p) { unsigned reload_id = SnortConfig::get_thread_reload_id(); if (flow->reload_id != reload_id) - { flow->network_policy_id = get_network_policy()->policy_id; - if (flow->flow_state == Flow::FlowState::INSPECT) - DataBus::publish(FLOW_STATE_RELOADED_EVENT, p, flow); - } else { set_inspection_policy(flow->inspection_policy_id); diff --git a/src/loggers/alert_csv.cc b/src/loggers/alert_csv.cc index e786c5b5e..118ea3454 100644 --- a/src/loggers/alert_csv.cc +++ b/src/loggers/alert_csv.cc @@ -343,8 +343,8 @@ static void ff_server_pkts(const Args& a) static void ff_service(const Args& a) { const char* svc = "unknown"; - if ( a.pkt->flow and a.pkt->flow->has_service() ) - svc = a.pkt->flow->service->c_str(); + if ( a.pkt->flow and a.pkt->flow->service ) + svc = a.pkt->flow->service; TextLog_Puts(csv_log, svc); } diff --git a/src/loggers/alert_json.cc b/src/loggers/alert_json.cc index bb1d13599..5bc6422f3 100644 --- a/src/loggers/alert_json.cc +++ b/src/loggers/alert_json.cc @@ -473,8 +473,8 @@ static bool ff_service(const Args& a) { const char* svc = "unknown"; - if ( a.pkt->flow and a.pkt->flow->has_service() ) - svc = a.pkt->flow->service->c_str(); + if ( a.pkt->flow and a.pkt->flow->service ) + svc = a.pkt->flow->service; print_label(a, "service"); TextLog_Quote(json_log, svc); diff --git a/src/main/policy.cc b/src/main/policy.cc index ee19cd08d..2e62e3d8b 100644 --- a/src/main/policy.cc +++ b/src/main/policy.cc @@ -459,9 +459,9 @@ IpsPolicy* get_empty_ips_policy(const SnortConfig* sc) void set_network_policy(unsigned i) { PolicyMap* pm = SnortConfig::get_conf()->policy_map; - - if ( i < pm->network_policy_count() ) - set_network_policy(pm->get_network_policy(i)); + NetworkPolicy* np = pm->get_network_policy(i); + if ( np ) + set_network_policy(np); } void set_inspection_policy(unsigned i) diff --git a/src/main/snort_config.cc b/src/main/snort_config.cc index d2c7c7a8c..ce4d2c7b0 100644 --- a/src/main/snort_config.cc +++ b/src/main/snort_config.cc @@ -24,8 +24,10 @@ #include "snort_config.h" #include +#include #include #include +#include #include "actions/ips_actions.h" #include "detection/detect.h" @@ -1056,3 +1058,15 @@ void SnortConfig::cleanup_fatal_error() #endif } +std::mutex SnortConfig::static_names_mutex; +std::unordered_map SnortConfig::static_names; + +const char* SnortConfig::get_static_name(const char* name) +{ + std::lock_guard static_name_lock(static_names_mutex); + auto entry = static_names.find(name); + if ( entry != static_names.end() ) + return entry->second.c_str(); + static_names.emplace(name, name); + return static_names[name].c_str(); +} diff --git a/src/main/snort_config.h b/src/main/snort_config.h index 441165707..9e89e10a9 100644 --- a/src/main/snort_config.h +++ b/src/main/snort_config.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -409,6 +410,8 @@ public: private: std::list reload_tuners; unsigned reload_id = 0; + static std::mutex static_names_mutex; + static std::unordered_map static_names; public: //------------------------------------------------------ @@ -712,6 +715,8 @@ public: static bool log_show_plugins() { return logging_flags & LOGGING_FLAG__SHOW_PLUGINS; } + + SO_PUBLIC static const char* get_static_name(const char* name); }; } diff --git a/src/managers/inspector_manager.cc b/src/managers/inspector_manager.cc index 4c2ae796f..7f440f446 100644 --- a/src/managers/inspector_manager.cc +++ b/src/managers/inspector_manager.cc @@ -2050,7 +2050,7 @@ void InspectorManager::full_inspection(Packet* p) { Flow* flow = p->flow; - if ( flow->has_service() and flow->searching_for_service() + if ( flow->service and flow->searching_for_service() and (!(p->is_cooked()) or p->is_defrag()) ) bumble(p); @@ -2118,7 +2118,13 @@ void InspectorManager::internal_execute(Packet* p) if ( p->disable_inspect ) return; - if (!p->flow) + unsigned reload_id = SnortConfig::get_thread_reload_id(); + if ( p->flow ) + { + if ( p->flow->reload_id != reload_id ) + DataBus::publish(FLOW_STATE_RELOADED_EVENT, p, p->flow); + } + else DataBus::publish(PKT_WITHOUT_FLOW_EVENT, p); FrameworkPolicy* fp = get_inspection_policy()->framework_policy; @@ -2162,7 +2168,6 @@ void InspectorManager::internal_execute(Packet* p) if ( !p->has_paf_payload() and p->flow->flow_state == Flow::FlowState::INSPECT ) p->flow->session->process(p); - unsigned reload_id = SnortConfig::get_thread_reload_id(); if ( p->flow->reload_id != reload_id ) { ::execute(p, tp->first.vec, tp->first.num); @@ -2172,7 +2177,7 @@ void InspectorManager::internal_execute(Packet* p) return; } - if ( !p->flow->has_service() ) + if ( !p->flow->service ) ::execute(p, fp->network.vec, fp->network.num); if ( p->disable_inspect ) diff --git a/src/network_inspectors/binder/binder.cc b/src/network_inspectors/binder/binder.cc index 518464ec6..15f1b6a94 100644 --- a/src/network_inspectors/binder/binder.cc +++ b/src/network_inspectors/binder/binder.cc @@ -578,9 +578,13 @@ public: void handle(DataEvent&, Flow* flow) override { - Binder* binder = InspectorManager::get_binder(); - if (binder && flow) - binder->handle_flow_after_reload(*flow); + // If reload_id is zero, this is a new flow and is bound by FLOW_STATE_SETUP_EVENT + if (flow && flow->reload_id && Flow::FlowState::INSPECT == flow->flow_state) + { + Binder* binder = InspectorManager::get_binder(); + if (binder) + binder->handle_flow_after_reload(*flow); + } } }; @@ -729,7 +733,7 @@ void Binder::handle_flow_setup(Flow& flow, bool standby) if (flow.ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID) { const SnortConfig* sc = SnortConfig::get_conf(); - flow.service = sc->proto_ref->get_shared_name(flow.ssn_state.snort_protocol_id); + flow.set_service(nullptr, sc->proto_ref->get_name(flow.ssn_state.snort_protocol_id)); } } @@ -761,7 +765,7 @@ void Binder::handle_flow_service_change(Flow& flow) Inspector* ins = nullptr; Inspector* data = nullptr; - if (flow.has_service()) + if (flow.service) { ins = find_gadget(flow, data); if (flow.gadget != ins) @@ -808,8 +812,8 @@ void Binder::handle_flow_service_change(Flow& flow) // If there is no inspector bound to this flow after the service change, see if there's at least // an associated protocol ID. - if (!ins && flow.has_service()) - flow.ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow.service->c_str()); + if (!ins && flow.service) + flow.ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow.service); if (flow.is_stream()) { @@ -992,7 +996,7 @@ void Binder::get_bindings(Packet* p, Stuff& stuff) Inspector* Binder::find_gadget(Flow& flow, Inspector*& data) { Stuff stuff; - get_bindings(flow, stuff, flow.has_service() ? flow.service->c_str() : nullptr); + get_bindings(flow, stuff, flow.service); data = stuff.data; return stuff.gadget; } diff --git a/src/network_inspectors/binder/binding.cc b/src/network_inspectors/binder/binding.cc index 4b66064f1..282bde171 100644 --- a/src/network_inspectors/binder/binding.cc +++ b/src/network_inspectors/binder/binding.cc @@ -579,10 +579,10 @@ inline bool Binding::check_service(const Flow& flow) const if (!when.has_criteria(BindWhen::Criteria::BWC_SVC)) return true; - if (!flow.has_service()) + if (!flow.service) return false; - return when.svc == flow.service->c_str(); + return when.svc == flow.service; } inline bool Binding::check_service(const char* service) const diff --git a/src/pub_sub/opportunistic_tls_event.h b/src/pub_sub/opportunistic_tls_event.h index 9aafeebd3..4d7b046ad 100644 --- a/src/pub_sub/opportunistic_tls_event.h +++ b/src/pub_sub/opportunistic_tls_event.h @@ -34,18 +34,18 @@ namespace snort class SO_PUBLIC OpportunisticTlsEvent : public snort::DataEvent { public: - OpportunisticTlsEvent(const snort::Packet* p, std::shared_ptr service) : + OpportunisticTlsEvent(const snort::Packet* p, const char* service) : pkt(p), next_service(service) { } const snort::Packet* get_packet() override { return pkt; } - std::shared_ptr get_next_service() + const char* get_next_service() { return next_service; } private: const snort::Packet* pkt; - std::shared_ptr next_service; + const char* next_service; }; } diff --git a/src/service_inspectors/dce_rpc/dce_common.cc b/src/service_inspectors/dce_rpc/dce_common.cc index 1f520fab2..cee6e3be0 100644 --- a/src/service_inspectors/dce_rpc/dce_common.cc +++ b/src/service_inspectors/dce_rpc/dce_common.cc @@ -43,9 +43,6 @@ using namespace snort; THREAD_LOCAL int dce2_detected = 0; static THREAD_LOCAL bool using_rpkt = false; -std::shared_ptr dce_rpc_service_name = - std::make_shared(DCE_RPC_SERVICE_NAME); - static const char* dce2_get_policy_name(DCE2_Policy policy) { const char* policyStr = nullptr; diff --git a/src/service_inspectors/dce_rpc/dce_common.h b/src/service_inspectors/dce_rpc/dce_common.h index 18396850c..481c21db9 100644 --- a/src/service_inspectors/dce_rpc/dce_common.h +++ b/src/service_inspectors/dce_rpc/dce_common.h @@ -43,7 +43,6 @@ extern THREAD_LOCAL int dce2_detected; #define GID_DCE2 133 #define DCE_RPC_SERVICE_NAME "dcerpc" -extern std::shared_ptr dce_rpc_service_name; enum DCE2_Policy { diff --git a/src/service_inspectors/dce_rpc/dce_http_proxy.cc b/src/service_inspectors/dce_rpc/dce_http_proxy.cc index 7da11138e..2cdaad9e4 100644 --- a/src/service_inspectors/dce_rpc/dce_http_proxy.cc +++ b/src/service_inspectors/dce_rpc/dce_http_proxy.cc @@ -68,7 +68,7 @@ void DceHttpProxy::clear(Packet* p) if ( c2s_splitter->cutover_inspector() && s2c_splitter->cutover_inspector() ) { dce_http_proxy_stats.http_proxy_sessions++; - flow->set_service(p, dce_rpc_service_name); + flow->set_service(p, DCE_RPC_SERVICE_NAME); } else dce_http_proxy_stats.http_proxy_session_failures++; diff --git a/src/service_inspectors/dce_rpc/dce_http_server.cc b/src/service_inspectors/dce_rpc/dce_http_server.cc index 557b035da..acf0126e7 100644 --- a/src/service_inspectors/dce_rpc/dce_http_server.cc +++ b/src/service_inspectors/dce_rpc/dce_http_server.cc @@ -64,7 +64,7 @@ void DceHttpServer::clear(Packet* p) if ( splitter->cutover_inspector()) { dce_http_server_stats.http_server_sessions++; - flow->set_service(p, dce_rpc_service_name); + flow->set_service(p, DCE_RPC_SERVICE_NAME); } else dce_http_server_stats.http_server_session_failures++; diff --git a/src/service_inspectors/ftp_telnet/ftp_data.cc b/src/service_inspectors/ftp_telnet/ftp_data.cc index 435d79507..f8a530172 100644 --- a/src/service_inspectors/ftp_telnet/ftp_data.cc +++ b/src/service_inspectors/ftp_telnet/ftp_data.cc @@ -49,8 +49,6 @@ using namespace snort; "FTP data channel handler" static const char* const fd_svc_name = "ftp-data"; -static std::shared_ptr shared_fd_svc_name = - std::make_shared(fd_svc_name); static THREAD_LOCAL ProfileStats ftpdataPerfStats; static THREAD_LOCAL SimpleStats fdstats; @@ -228,15 +226,15 @@ FtpDataFlowData::~FtpDataFlowData() void FtpDataFlowData::handle_expected(Packet* p) { - if (!p->flow->has_service()) + if (!p->flow->service) { - p->flow->set_service(p, shared_fd_svc_name); + p->flow->set_service(p, fd_svc_name); FtpDataFlowData* fd = (FtpDataFlowData*)p->flow->get_flow_data(FtpDataFlowData::inspector_id); if (fd and fd->in_tls) { - OpportunisticTlsEvent evt(p, shared_fd_svc_name); + OpportunisticTlsEvent evt(p, fd_svc_name); DataBus::publish(OPPORTUNISTIC_TLS_EVENT, evt, p->flow); } else diff --git a/src/service_inspectors/http_inspect/http_inspect.cc b/src/service_inspectors/http_inspect/http_inspect.cc index b408e2719..51c78e091 100755 --- a/src/service_inspectors/http_inspect/http_inspect.cc +++ b/src/service_inspectors/http_inspect/http_inspect.cc @@ -688,7 +688,7 @@ void HttpInspect::clear(Packet* p) if (session_data->cutover_on_clear) { Flow* flow = p->flow; - flow->clear_service(p); + flow->set_service(p, nullptr); flow->free_flow_data(HttpFlowData::inspector_id); } } diff --git a/src/service_inspectors/smtp/smtp.cc b/src/service_inspectors/smtp/smtp.cc index 412dc0e74..f54cc03b5 100644 --- a/src/service_inspectors/smtp/smtp.cc +++ b/src/service_inspectors/smtp/smtp.cc @@ -1116,7 +1116,6 @@ static void SMTP_ProcessServerPacket( /* This is either an initial server response or a STARTTLS response */ if (smtp_ssn->state == STATE_CONNECT) smtp_ssn->state = STATE_COMMAND; - break; case RESP_250: diff --git a/src/service_inspectors/ssl/ssl_inspector.cc b/src/service_inspectors/ssl/ssl_inspector.cc index e4701db51..d11b05413 100644 --- a/src/service_inspectors/ssl/ssl_inspector.cc +++ b/src/service_inspectors/ssl/ssl_inspector.cc @@ -403,7 +403,6 @@ static void snort_ssl(SSL_PROTO_CONF* config, Packet* p) // class stuff //------------------------------------------------------------------------- static const char* s_name = "ssl"; -static std::shared_ptr shared_s_name = std::make_shared(s_name); class Ssl : public Inspector { @@ -452,7 +451,7 @@ public: pkt->flow->flags.trigger_finalize_event = fd->finalize_info.orig_flag; fd->finalize_info.switch_in = false; pkt->flow->set_proxied(); - pkt->flow->set_service(const_cast(pkt), shared_s_name); + pkt->flow->set_service(const_cast(pkt), s_name); } } }; diff --git a/src/service_inspectors/wizard/curses.cc b/src/service_inspectors/wizard/curses.cc index 968934759..b14727cab 100644 --- a/src/service_inspectors/wizard/curses.cc +++ b/src/service_inspectors/wizard/curses.cc @@ -393,11 +393,11 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke // map between service and curse details static vector curse_map { - // name service alg is_tcp - { "dce_udp", make_shared("dcerpc") , dce_udp_curse, false }, - { "dce_tcp", make_shared("dcerpc") , dce_tcp_curse, true }, - { "dce_smb", make_shared("netbios-ssn"), dce_smb_curse, true }, - { "sslv2" , make_shared("ssl") , ssl_v2_curse , true } + // name service alg is_tcp + { "dce_udp", "dcerpc" , dce_udp_curse, false }, + { "dce_tcp", "dcerpc" , dce_tcp_curse, true }, + { "dce_smb", "netbios-ssn", dce_smb_curse, true }, + { "sslv2" , "ssl" , ssl_v2_curse , true } }; bool CurseBook::add_curse(const char* key) diff --git a/src/service_inspectors/wizard/curses.h b/src/service_inspectors/wizard/curses.h index 8b2d9b682..1ffe226a2 100644 --- a/src/service_inspectors/wizard/curses.h +++ b/src/service_inspectors/wizard/curses.h @@ -21,7 +21,6 @@ #define CURSES_H #include -#include #include #include @@ -87,7 +86,7 @@ typedef bool (* curse_alg)(const uint8_t* data, unsigned len, CurseTracker*); struct CurseDetails { std::string name; - std::shared_ptr service; + const char* service; curse_alg alg; bool is_tcp; }; diff --git a/src/service_inspectors/wizard/hexes.cc b/src/service_inspectors/wizard/hexes.cc index 9b2880fb7..3a1cebf33 100644 --- a/src/service_inspectors/wizard/hexes.cc +++ b/src/service_inspectors/wizard/hexes.cc @@ -23,8 +23,11 @@ #include +#include "main/snort_config.h" + #include "magic.h" +using namespace snort; using namespace std; #define WILD 0x100 @@ -91,7 +94,7 @@ void HexBook::add_spell( ++i; } p->key = key; - p->value = make_shared(val); + p->value = SnortConfig::get_static_name(val); } bool HexBook::add_spell(const char* key, const char*& val) @@ -124,7 +127,7 @@ bool HexBook::add_spell(const char* key, const char*& val) } if ( p->key == key ) { - val = p->value->c_str(); + val = p->value; return false; } @@ -158,7 +161,7 @@ const MagicPage* HexBook::find_spell( if ( const MagicPage* q = find_spell(s, n, p->any, i+1) ) return q; } - return p->value.use_count() ? p : nullptr; + return p->value ? p : nullptr; } return p; } diff --git a/src/service_inspectors/wizard/magic.cc b/src/service_inspectors/wizard/magic.cc index a6f2a3f84..dd56d2546 100644 --- a/src/service_inspectors/wizard/magic.cc +++ b/src/service_inspectors/wizard/magic.cc @@ -42,16 +42,12 @@ MagicPage::~MagicPage() delete any; } -std::shared_ptr MagicBook::find_spell(const uint8_t* data, unsigned len, +const char* MagicBook::find_spell(const uint8_t* data, unsigned len, const MagicPage*& p) const { assert(p); - p = find_spell(data, len, p, 0); - if ( p && p->value.use_count() ) - return p->value; - - return nullptr; + return p ? p->value : nullptr; } MagicBook::MagicBook() diff --git a/src/service_inspectors/wizard/magic.h b/src/service_inspectors/wizard/magic.h index f35380a13..3d3ef7613 100644 --- a/src/service_inspectors/wizard/magic.h +++ b/src/service_inspectors/wizard/magic.h @@ -20,7 +20,6 @@ #ifndef MAGIC_H #define MAGIC_H -#include #include #include @@ -29,7 +28,7 @@ class MagicBook; struct MagicPage { std::string key; - std::shared_ptr value; + const char* value = nullptr; MagicPage* next[256]; MagicPage* any; @@ -53,8 +52,7 @@ public: MagicBook& operator=(const MagicBook&) = delete; virtual bool add_spell(const char* key, const char*& val) = 0; - virtual std::shared_ptr find_spell(const uint8_t*, unsigned len, - const MagicPage*&) const; + virtual const char* find_spell(const uint8_t* data, unsigned len, const MagicPage*&) const; const MagicPage* page1() const { return root; } diff --git a/src/service_inspectors/wizard/spells.cc b/src/service_inspectors/wizard/spells.cc index ce6ba6f6d..997876133 100644 --- a/src/service_inspectors/wizard/spells.cc +++ b/src/service_inspectors/wizard/spells.cc @@ -23,8 +23,11 @@ #include +#include "main/snort_config.h" + #include "magic.h" +using namespace snort; using namespace std; #define WILD 0x100 @@ -84,7 +87,7 @@ void SpellBook::add_spell( ++i; } p->key = key; - p->value = make_shared(val); + p->value = SnortConfig::get_static_name(val); } bool SpellBook::add_spell(const char* key, const char*& val) @@ -118,7 +121,7 @@ bool SpellBook::add_spell(const char* key, const char*& val) } if ( p->key == key ) { - val = p->value->c_str(); + val = p->value; return false; } @@ -162,7 +165,7 @@ const MagicPage* SpellBook::find_spell( } // If no match but has glob, continue lookup from glob - if ( !p->value.use_count() && glob ) + if ( !p->value && glob ) { p = glob; glob = nullptr; @@ -170,7 +173,7 @@ const MagicPage* SpellBook::find_spell( return find_spell(s, n, p, i); } - return p->value.use_count() ? p : nullptr; + return p->value ? p : nullptr; } return p; } diff --git a/src/service_inspectors/wizard/wizard.cc b/src/service_inspectors/wizard/wizard.cc index 7d20b9c9d..46de82103 100644 --- a/src/service_inspectors/wizard/wizard.cc +++ b/src/service_inspectors/wizard/wizard.cc @@ -195,7 +195,7 @@ StreamSplitter::Status MagicSplitter::scan( if ( wizard->cast_spell(wand, pkt->flow, data, len, wizard_processed_bytes) ) { trace_logf(wizard_trace, pkt, "%s streaming search found service %s\n", - to_server() ? "c2s" : "s2c", pkt->flow->service->c_str()); + to_server() ? "c2s" : "s2c", pkt->flow->service); count_hit(pkt->flow); wizard_processed_bytes = 0; return STOP; @@ -225,7 +225,7 @@ StreamSplitter::Status MagicSplitter::scan( // delayed. Because AppId depends on wizard only for SSH detection and SSH inspector can be // attached very early, event is raised here after first scan. In the future, wizard should be // enhanced to abort sooner if it can't detect service. - if (!pkt->flow->has_service() && !pkt->flow->flags.svc_event_generated) + if (!pkt->flow->service && !pkt->flow->flags.svc_event_generated) { DataBus::publish(FLOW_NO_SERVICE_EVENT, pkt); pkt->flow->flags.svc_event_generated = true; @@ -309,7 +309,7 @@ void Wizard::eval(Packet* p) if ( cast_spell(wand, p->flow, p->data, p->dsize, udp_processed_bytes) ) { trace_logf(wizard_trace, p, "%s datagram search found service %s\n", - c2s ? "c2s" : "s2c", p->flow->service->c_str()); + c2s ? "c2s" : "s2c", p->flow->service); ++tstats.udp_hits; } else @@ -328,12 +328,8 @@ StreamSplitter* Wizard::get_splitter(bool c2s) bool Wizard::spellbind( const MagicPage*& m, Flow* f, const uint8_t* data, unsigned len) { - std::shared_ptr p_shared = m->book.find_spell(data, len, m); - if (p_shared.use_count()) - f->service = p_shared; - else - f->service.reset(); - return f->has_service(); + f->service = m->book.find_spell(data, len, m); + return f->service != nullptr; } bool Wizard::cursebind(const vector& curse_tracker, Flow* f, @@ -343,11 +339,8 @@ bool Wizard::cursebind(const vector& curse_tracker, Flow* f { if (cst.curse->alg(data, len, cst.tracker)) { - if (cst.curse->service.use_count()) - f->service = cst.curse->service; - else - f->service.reset(); - if ( f->has_service() ) + f->service = cst.curse->service; + if ( f->service ) return true; } } @@ -375,7 +368,7 @@ bool Wizard::cast_spell( // If we reach max value of wizard_processed_bytes, // but not assign any inspector - raise tcp_miss and stop - if ( !f->has_service() && wizard_processed_bytes >= max_search_depth ) + if ( !f->service && wizard_processed_bytes >= max_search_depth ) { w.spell = nullptr; w.hex = nullptr; diff --git a/src/target_based/snort_protocols.cc b/src/target_based/snort_protocols.cc index b37d0a2fa..fb09564fd 100644 --- a/src/target_based/snort_protocols.cc +++ b/src/target_based/snort_protocols.cc @@ -26,8 +26,10 @@ #include "snort_protocols.h" #include +#include #include "log/messages.h" +#include "main/snort_config.h" #include "protocols/packet.h" #include "utils/util.h" #include "utils/util_cstring.h" @@ -39,12 +41,6 @@ SnortProtocolId ProtocolReference::get_count() const { return protocol_number; } const char* ProtocolReference::get_name(SnortProtocolId id) const -{ - std::shared_ptr shared_name = get_shared_name(id); - return shared_name->c_str(); -} - -std::shared_ptr ProtocolReference::get_shared_name(SnortProtocolId id) const { if ( id >= id_map.size() ) id = 0; @@ -55,9 +51,9 @@ std::shared_ptr ProtocolReference::get_shared_name(SnortProtocolId struct Compare { bool operator()(SnortProtocolId a, SnortProtocolId b) - { return map[a]->c_str() < map[b]->c_str(); } + { return 0 > strcmp(map[a], map[b]); } - vector>& map; + vector& map; }; const char* ProtocolReference::get_name_sorted(SnortProtocolId id) @@ -73,7 +69,7 @@ const char* ProtocolReference::get_name_sorted(SnortProtocolId id) if ( id >= ind_map.size() ) return nullptr; - return id_map[ind_map[id]]->c_str(); + return id_map[ind_map[id]]; } SnortProtocolId ProtocolReference::add(const char* protocol) @@ -86,7 +82,8 @@ SnortProtocolId ProtocolReference::add(const char* protocol) return protocol_ref->second; SnortProtocolId snort_protocol_id = protocol_number++; - id_map.emplace_back(make_shared(protocol)); + protocol = SnortConfig::get_static_name(protocol); + id_map.emplace_back(protocol); ref_table[protocol] = snort_protocol_id; return snort_protocol_id; @@ -127,5 +124,9 @@ ProtocolReference::ProtocolReference(ProtocolReference* old_proto_ref) { init(old_proto_ref); } ProtocolReference::~ProtocolReference() -{ ref_table.clear(); } +{ + ref_table.clear(); + id_map.clear(); + ind_map.clear(); +} diff --git a/src/target_based/snort_protocols.h b/src/target_based/snort_protocols.h index 6281cd1b5..b16712b2b 100644 --- a/src/target_based/snort_protocols.h +++ b/src/target_based/snort_protocols.h @@ -22,7 +22,6 @@ #ifndef SNORT_PROTOCOLS_H #define SNORT_PROTOCOLS_H -#include #include #include #include @@ -74,7 +73,6 @@ public: SnortProtocolId get_count() const; const char* get_name(SnortProtocolId id) const; - std::shared_ptr get_shared_name(SnortProtocolId id) const; const char* get_name_sorted(SnortProtocolId id); SnortProtocolId add(const char* protocol); @@ -83,7 +81,7 @@ public: bool operator()(SnortProtocolId a, SnortProtocolId b); private: - std::vector> id_map; + std::vector id_map; std::vector ind_map; std::unordered_map ref_table; @@ -91,6 +89,10 @@ private: void init(const ProtocolReference* old_proto_ref); }; + +void protocol_reference_global_init(); +void protocol_reference_global_term(); + } #endif diff --git a/src/target_based/test/proto_ref_test.cc b/src/target_based/test/proto_ref_test.cc index 198c67c02..a5869171c 100644 --- a/src/target_based/test/proto_ref_test.cc +++ b/src/target_based/test/proto_ref_test.cc @@ -32,8 +32,10 @@ using namespace snort; +const char* SnortConfig::get_static_name(const char* name) { return name; } + TEST_GROUP(protocol_reference) -{}; +{ }; // Service Protocols //