]> git.ipfire.org Git - thirdparty/snort3.git/commitdiff
Pull request #3279: Multi-tenant with reconcile inspectors and reputation with reload...
authorRon Dempster (rdempste) <rdempste@cisco.com>
Tue, 22 Mar 2022 19:06:38 +0000 (19:06 +0000)
committerRon Dempster (rdempste) <rdempste@cisco.com>
Tue, 22 Mar 2022 19:06:38 +0000 (19:06 +0000)
Merge in SNORT/snort3 from ~RDEMPSTE/snort3:reputation to master

Squashed commit of the following:

commit fb9b349ce3fc2612c4f0bdae6f1e03a511bf9cf7
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Tue Mar 22 11:06:13 2022 -0400

    framework: update base API version to 13

commit 877c1e7dcc63499301a8868880831b27ff9bcabe
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Fri Mar 11 07:32:55 2022 -0500

    appid: sum stats at tterm and null the thread local stats pointer after delete

commit d23843bb934a4072c1c15458f9ddf17a95d1d269
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Tue Mar 8 10:16:45 2022 -0500

    main: add the control connection to the analyzer command and a method to log a message to both console and the remote connection

commit aaf890c670f013e8af21c8db345139314084d13e
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Sat Mar 5 13:18:39 2022 -0500

    main: fix and reenable the distill_verdict unit test

commit edc81969f10a390a4a1e6e355906566405778583
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Tue Mar 8 09:37:46 2022 -0500

    managers: add get_inspector unit tests

commit 393507e0e4182033f7f726e710516ffc68e95d1d
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Fri Feb 25 12:22:24 2022 -0500

    policy_selectors: add a method to select policies based on DAQ_FlowStats_t

commit c85bb3a7b2225efda3e0ade20267746a989f7e01
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Mon Feb 14 12:39:59 2022 -0500

    appid: make appid a global inspector

commit 046846e765831debe98886fdf1ce57382db96c75
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Fri Feb 11 10:12:40 2022 -0500

    managers: add a faster get_inspectors method

commit 3470d1cb7dfdee60af067f15bba29694e4646ed3
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Fri Jan 14 10:22:17 2022 -0500

    inspector, main, inspector_manager: add support for thread local data in inspectors and commands updating reload_id

commit 3d9c2556dbb39220ca26d61e4f2e6e2477b55a22
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Tue Dec 7 15:43:49 2021 -0500

    reputation: add a command to reload repuation data

commit c74d98a34b089d0b86db78cac78c6aaa793c2853
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Tue Dec 21 08:22:14 2021 -0500

    flow: make service a shared pointer to handle reload properly

commit 6750746d83d0c82ff3ebe552be43f8d36797c29b
Author: Ron Dempster (rdempste) <rdempste@cisco.com>
Date:   Thu Dec 16 07:59:30 2021 -0500

    managers: move inspection policies into the corresponding network policy

122 files changed:
src/control/control.h
src/flow/flow.cc
src/flow/flow.h
src/flow/flow_control.cc
src/flow/test/flow_cache_test.cc
src/flow/test/flow_control_test.cc
src/flow/test/flow_stash_test.cc
src/flow/test/flow_test.cc
src/framework/base_api.h
src/framework/data_bus.cc
src/framework/data_bus.h
src/framework/inspector.cc
src/framework/inspector.h
src/framework/policy_selector.h
src/framework/test/data_bus_test.cc
src/hash/test/ghash_test.cc
src/hash/test/xhash_test.cc
src/hash/test/zhash_test.cc
src/helpers/test/hyper_search_test.cc
src/host_tracker/host_cache_module.cc
src/host_tracker/host_cache_module.h
src/host_tracker/test/host_cache_module_test.cc
src/ips_options/test/ips_regex_test.cc
src/log/messages.cc
src/log/messages.h
src/loggers/alert_csv.cc
src/loggers/alert_json.cc
src/main.cc
src/main/CMakeLists.txt
src/main/ac_shell_cmd.cc
src/main/ac_shell_cmd.h
src/main/analyzer.cc
src/main/analyzer.h
src/main/analyzer_command.cc
src/main/analyzer_command.h
src/main/modules.cc
src/main/policy.cc
src/main/policy.h
src/main/reload_tuner.h [new file with mode: 0644]
src/main/shell.cc
src/main/snort.cc
src/main/snort_config.cc
src/main/snort_config.h
src/main/test/CMakeLists.txt
src/main/test/distill_verdict_stubs.h [moved from src/main/test/stubs.h with 96% similarity]
src/main/test/distill_verdict_test.cc
src/managers/CMakeLists.txt
src/managers/inspector_manager.cc
src/managers/inspector_manager.h
src/managers/module_manager.cc
src/managers/test/CMakeLists.txt [new file with mode: 0644]
src/managers/test/get_inspector_stubs.h [new file with mode: 0644]
src/managers/test/get_inspector_test.cc [new file with mode: 0644]
src/network_inspectors/appid/appid_config.cc
src/network_inspectors/appid/appid_config.h
src/network_inspectors/appid/appid_ha.cc
src/network_inspectors/appid/appid_inspector.cc
src/network_inspectors/appid/appid_inspector.h
src/network_inspectors/appid/appid_module.cc
src/network_inspectors/appid/appid_module.h
src/network_inspectors/appid/appid_peg_counts.cc
src/network_inspectors/appid/appid_peg_counts.h
src/network_inspectors/appid/host_port_app_cache.cc
src/network_inspectors/appid/host_port_app_cache.h
src/network_inspectors/appid/lua_detector_api.cc
src/network_inspectors/appid/lua_detector_api.h
src/network_inspectors/appid/lua_detector_module.cc
src/network_inspectors/appid/lua_detector_module.h
src/network_inspectors/binder/bind_module.cc
src/network_inspectors/binder/bind_module.h
src/network_inspectors/binder/binder.cc
src/network_inspectors/binder/binding.cc
src/network_inspectors/perf_monitor/perf_module.cc
src/network_inspectors/perf_monitor/perf_reload_tuner.h
src/network_inspectors/port_scan/ps_module.cc
src/network_inspectors/port_scan/ps_module.h
src/network_inspectors/reputation/CMakeLists.txt
src/network_inspectors/reputation/reputation_commands.cc [new file with mode: 0644]
src/network_inspectors/reputation/reputation_commands.h [new file with mode: 0644]
src/network_inspectors/reputation/reputation_config.h
src/network_inspectors/reputation/reputation_inspect.cc
src/network_inspectors/reputation/reputation_inspect.h
src/network_inspectors/reputation/reputation_module.cc
src/network_inspectors/reputation/reputation_module.h
src/network_inspectors/reputation/reputation_parse.cc
src/network_inspectors/reputation/reputation_parse.h
src/network_inspectors/rna/rna_inspector.cc
src/network_inspectors/rna/rna_inspector.h
src/network_inspectors/rna/rna_module.h
src/network_inspectors/rna/test/rna_module_stubs.h
src/network_inspectors/rna/test/rna_module_test.cc
src/payload_injector/test/payload_injector_test.cc
src/policy_selectors/address_space_selector/address_space_selector.cc
src/policy_selectors/address_space_selector/address_space_selector_module.cc
src/policy_selectors/tenant_selector/tenant_selector.cc
src/policy_selectors/tenant_selector/tenant_selector_module.cc
src/pub_sub/opportunistic_tls_event.h
src/search_engines/test/hyperscan_test.cc
src/search_engines/test/search_tool_test.cc
src/service_inspectors/dce_rpc/dce_common.cc
src/service_inspectors/dce_rpc/dce_common.h
src/service_inspectors/dce_rpc/dce_http_proxy.cc
src/service_inspectors/dce_rpc/dce_http_server.cc
src/service_inspectors/ftp_telnet/ftp_data.cc
src/service_inspectors/http_inspect/http_inspect.cc
src/service_inspectors/pop/pop.cc
src/service_inspectors/ssl/ssl_inspector.cc
src/service_inspectors/wizard/curses.cc
src/service_inspectors/wizard/curses.h
src/service_inspectors/wizard/hexes.cc
src/service_inspectors/wizard/magic.cc
src/service_inspectors/wizard/magic.h
src/service_inspectors/wizard/spells.cc
src/service_inspectors/wizard/wizard.cc
src/stream/base/stream_base.cc
src/stream/base/stream_module.cc
src/stream/base/stream_module.h
src/stream/tcp/stream_tcp.cc
src/stream/tcp/tcp_reassembler.cc
src/target_based/host_attributes.cc
src/target_based/snort_protocols.cc
src/target_based/snort_protocols.h

index 6bc60750d217d4a26c9e2b321d690d3f662a8e48..a211b46a5c463b7ca0a816372307813763f240c3 100644 (file)
@@ -61,11 +61,11 @@ public:
     void shutdown();
 
     SO_PUBLIC bool is_local() const { return local; }
+    SO_PUBLIC bool respond(const char* format, va_list& ap);
     SO_PUBLIC bool respond(const char* format, ...) __attribute__((format (printf, 2, 3)));
     SO_PUBLIC static ControlConn* query_from_lua(const lua_State*);
 
 private:
-    bool respond(const char* format, va_list& ap);
     bool show_prompt();
     void touch();
 
index dd30e50522e55848c6677aac332d5d37267d7718..916a5b0d9b497ea492705f4be0573a2245781399 100644 (file)
@@ -24,6 +24,7 @@
 #include "flow.h"
 
 #include "detection/detection_engine.h"
+#include "flow/flow_key.h"
 #include "flow/ha.h"
 #include "flow/session.h"
 #include "framework/data_bus.h"
@@ -118,6 +119,8 @@ void Flow::term()
         delete stash;
         stash = nullptr;
     }
+
+    service.reset();
 }
 
 inline void Flow::clean()
@@ -211,6 +214,7 @@ void Flow::reset(bool do_cleanup)
         stash->reset();
 
     deferred_trust.clear();
+    service.reset();
 
     constexpr size_t offset = offsetof(Flow, context_chain);
     // FIXIT-L need a struct to zero here to make future proof
@@ -326,12 +330,33 @@ void Flow::free_flow_data(uint32_t proto)
 
 void Flow::free_flow_data()
 {
+    NetworkPolicy* np = get_network_policy();
+    InspectionPolicy* ip = get_inspection_policy();
+    IpsPolicy* ipsp = get_ips_policy();
+
+    unsigned t_reload_id = SnortConfig::get_thread_reload_id();
+    if (reload_id == t_reload_id)
+    {
+        ::set_network_policy(network_policy_id);
+        ::set_inspection_policy(inspection_policy_id);
+        ::set_ips_policy(SnortConfig::get_conf(), ips_policy_id);
+    }
+    else
+    {
+        _daq_pkt_hdr pkthdr = {};
+        pkthdr.address_space_id = key->addressSpaceId;
+        select_default_policy(pkthdr, SnortConfig::get_conf());
+    }
     while (flow_data)
     {
         FlowData* tmp = flow_data;
         flow_data = flow_data->next;
         delete tmp;
     }
+
+    set_network_policy(np);
+    set_inspection_policy(ip);
+    set_ips_policy(ipsp);
 }
 
 void Flow::call_handlers(Packet* p, bool eof)
@@ -546,12 +571,21 @@ bool Flow::is_direction_aborted(bool from_client) const
     return (session_flags & SSNFLAG_ABORT_CLIENT);
 }
 
-void Flow::set_service(Packet* pkt, const char* new_service)
+void Flow::set_service(Packet* pkt, std::shared_ptr<std::string> new_service)
 {
+    if (!new_service.use_count())
+        return clear_service(pkt);
+
     service = new_service;
     DataBus::publish(FLOW_SERVICE_CHANGE_EVENT, pkt);
 }
 
+void Flow::clear_service(Packet* pkt)
+{
+    service.reset();
+    DataBus::publish(FLOW_SERVICE_CHANGE_EVENT, pkt);
+}
+
 void Flow::swap_roles()
 {
     std::swap(flowstats.client_pkts, flowstats.server_pkts);
index bce283bb34ebe8f7834af4b9ebe984920d28cc7c..9b13f8d75734bc2f410965702b5816e5ba8adc86 100644 (file)
 // state.  Inspector state is stored in FlowData, and Flow manages a list
 // of FlowData items.
 
-#include <daq_common.h>
+#include <memory>
+#include <string>
 #include <sys/time.h>
 
+#include <daq_common.h>
+
 #include "detection/ips_context_chain.h"
 #include "flow/deferred_trust.h"
 #include "flow/flow_data.h"
@@ -197,7 +200,10 @@ public:
     void set_mpls_layer_per_dir(Packet*);
     Layer get_mpls_layer_per_dir(bool);
     void swap_roles();
-    void set_service(Packet* pkt, const char* new_service);
+    void set_service(Packet*, std::shared_ptr<std::string> new_service);
+    void clear_service(Packet*);
+    bool has_service() const
+    { return 0 != service.use_count(); }
     bool get_attr(const std::string& key, int32_t& val);
     bool get_attr(const std::string& key, std::string& val);
     void set_attr(const std::string& key, const int32_t& val);
@@ -281,26 +287,20 @@ public:
 
     void set_client(Inspector* ins)
     {
+        if (ssn_client)
+            ssn_client->rem_ref();
         ssn_client = ins;
-        ssn_client->add_ref();
-    }
-
-    void clear_client()
-    {
-        ssn_client->rem_ref();
-        ssn_client = nullptr;
+        if (ssn_client)
+            ssn_client->add_ref();
     }
 
     void set_server(Inspector* ins)
     {
+        if (ssn_server)
+            ssn_server->rem_ref();
         ssn_server = ins;
-        ssn_server->add_ref();
-    }
-
-    void clear_server()
-    {
-        ssn_server->rem_ref();
-        ssn_server = nullptr;
+        if (ssn_server)
+            ssn_server->add_ref();
     }
 
     void set_clouseau(Inspector* ins)
@@ -405,8 +405,8 @@ public:  // FIXIT-M privatize if possible
     // fields are organized by initialization and size to minimize
     // void space and allow for memset of tail end of struct
 
-    // these fields are const after initialization
     DeferredTrust deferred_trust;
+    std::shared_ptr<std::string> service;
 
     // Anything before this comment is not zeroed during construction
     const FlowKey* key;
@@ -441,10 +441,10 @@ public:  // FIXIT-M privatize if possible
     Inspector* gadget;    // service handler
     Inspector* assistant_gadget;
     Inspector* data;
-    const char* service;
 
     uint64_t expire_time;
 
+    unsigned network_policy_id;
     unsigned inspection_policy_id;
     unsigned ips_policy_id;
     unsigned reload_id;
index 611f5c9a91691a4e6f08fc1097d72402e735449d..6324ca993bb8d5be5e0905fb03699cf5231c8b43 100644 (file)
@@ -435,14 +435,24 @@ unsigned FlowControl::process(Flow* flow, Packet* p)
 
     if ( flow->flow_state != Flow::FlowState::SETUP )
     {
-        const SnortConfig* sc = SnortConfig::get_conf();
-        set_inspection_policy(sc, flow->inspection_policy_id);
-        set_ips_policy(sc, flow->ips_policy_id);
+        unsigned reload_id = SnortConfig::get_thread_reload_id();
+        if (flow->reload_id != reload_id)
+        {
+            flow->network_policy_id = get_network_policy()->policy_id;
+            if (flow->flow_state == Flow::FlowState::INSPECT)
+                DataBus::publish(FLOW_STATE_RELOADED_EVENT, p, flow);
+        }
+        else
+        {
+            set_inspection_policy(flow->inspection_policy_id);
+            set_ips_policy(p->context->conf, flow->ips_policy_id);
+        }
         p->filtering_state = flow->filtering_state;
     }
 
     else
     {
+        flow->network_policy_id = get_network_policy()->policy_id;
         if (PacketTracer::is_active())
             PacketTracer::log("Session: new snort session\n");
 
index aa9af717b971648fabb23175a7fcd7793ce4e2e3..2be52113936d6b20de13b007ff75ee1e386a8bff 100644 (file)
@@ -27,6 +27,7 @@
 #include "flow/flow_control.h"
 
 #include "detection/detection_engine.h"
+#include "main/policy.h"
 #include "main/snort_config.h"
 #include "main/snort_debug.h"
 #include "managers/inspector_manager.h"
@@ -67,7 +68,12 @@ void Active::set_drop_reason(char const*) { }
 Packet::Packet(bool) { }
 Packet::~Packet() = default;
 uint32_t Packet::get_flow_geneve_vni() const { return 0; }
-Flow::Flow() { memset(this, 0, sizeof(*this)); }
+Flow::Flow()
+{
+    constexpr size_t offset = offsetof(Flow, key);
+    // FIXIT-L need a struct to zero here to make future proof
+    memset((uint8_t*)this+offset, 0, sizeof(*this)-offset);
+}
 Flow::~Flow() = default;
 DetectionEngine::DetectionEngine() = default;
 ExpectCache::~ExpectCache() = default;
@@ -77,14 +83,14 @@ void Flow::term() { }
 void Flow::flush(bool) { }
 void Flow::reset(bool) { }
 void Flow::free_flow_data() { }
-void set_network_policy(const SnortConfig*, unsigned) { }
 void DataBus::publish(const char*, const uint8_t*, unsigned, Flow*) { }
 void DataBus::publish(const char*, Packet*, Flow*) { }
 const SnortConfig* SnortConfig::get_conf() { return nullptr; }
 void Flow::set_client_initiate(Packet*) { }
 void Flow::set_direction(Packet*) { }
-void set_inspection_policy(const SnortConfig*, unsigned) { }
-void set_ips_policy(const SnortConfig*, unsigned) { }
+void set_network_policy(unsigned) { }
+void set_inspection_policy(unsigned) { }
+void set_ips_policy(const snort::SnortConfig*, unsigned) { }
 void Flow::set_mpls_layer_per_dir(Packet*) { }
 void DetectionEngine::disable_all(Packet*) { }
 void Stream::drop_traffic(const Packet*, char) { }
@@ -101,6 +107,11 @@ void snort::TraceApi::filter(const Packet&) {}
 
 namespace snort
 {
+NetworkPolicy* get_network_policy() { return nullptr; }
+InspectionPolicy* get_inspection_policy() { return nullptr; }
+IpsPolicy* get_ips_policy() { return nullptr; }
+unsigned SnortConfig::get_thread_reload_id() { return 0; }
+
 namespace layer
 {
 const vlan::VlanTagHdr* get_vlan_layer(const Packet* const) { return nullptr; }
index 6174ed748e54f4fa02df5938b07670ad2ef4ca26..6512e53ace91a83216499318329535d309ca3be8 100644 (file)
@@ -27,6 +27,7 @@
 #include "flow/flow_control.h"
 
 #include "detection/detection_engine.h"
+#include "main/policy.h"
 #include "main/snort_config.h"
 #include "managers/inspector_manager.h"
 #include "packet_io/active.h"
@@ -81,15 +82,15 @@ bool FlowCache::prune_one(PruneReason, bool) { return true; }
 unsigned FlowCache::delete_flows(unsigned) { return 0; }
 unsigned FlowCache::timeout(unsigned, time_t) { return 1; }
 void Flow::init(PktType) { }
-void set_network_policy(const SnortConfig*, unsigned) { }
 void DataBus::publish(const char*, const uint8_t*, unsigned, Flow*) { }
 void DataBus::publish(const char*, Packet*, Flow*) { }
 const SnortConfig* SnortConfig::get_conf() { return nullptr; }
 void FlowCache::unlink_uni(Flow*) { }
 void Flow::set_client_initiate(Packet*) { }
 void Flow::set_direction(Packet*) { }
-void set_inspection_policy(const SnortConfig*, unsigned) { }
-void set_ips_policy(const SnortConfig*, unsigned) { }
+void set_network_policy(unsigned) { }
+void set_inspection_policy(unsigned) { }
+void set_ips_policy(const snort::SnortConfig*, unsigned) { }
 void Flow::set_mpls_layer_per_dir(Packet*) { }
 void DetectionEngine::disable_all(Packet*) { }
 void Stream::drop_traffic(const Packet*, char) { }
@@ -101,6 +102,11 @@ Flow* HighAvailabilityManager::import(Packet&, FlowKey&) { return nullptr; }
 
 namespace snort
 {
+NetworkPolicy* get_network_policy() { return nullptr; }
+InspectionPolicy* get_inspection_policy() { return nullptr; }
+IpsPolicy* get_ips_policy() { return nullptr; }
+unsigned SnortConfig::get_thread_reload_id() { return 0; }
+
 namespace layer
 {
 const vlan::VlanTagHdr* get_vlan_layer(const Packet* const) { return nullptr; }
index 0bffa3979bec04742fc1e879f7fdeb2ff1143b63..9b7c64a421915767c4f0471c60f45d9d30fb38aa 100644 (file)
@@ -111,7 +111,7 @@ void DataBus::_subscribe(const char* key, DataHandler* h)
 
 void DataBus::_unsubscribe(const char*, DataHandler*) {}
 
-void DataBus::_publish(const char* key, DataEvent& e, Flow* f)
+void DataBus::_publish(const char* key, DataEvent& e, Flow* f) const
 {
     auto v = map.find(key);
 
index eace4f931da195400402208490d2299ea674581e..77fe3c8e560d8ecc3beb161b3d0388e7f87ca033 100644 (file)
@@ -29,6 +29,7 @@
 #include "flow/ha.h"
 #include "framework/inspector.h"
 #include "framework/data_bus.h"
+#include "main/policy.h"
 #include "main/snort_config.h"
 #include "protocols/ip.h"
 #include "protocols/layer.h"
@@ -58,6 +59,21 @@ void FlowStash::reset() {}
 
 void DetectionEngine::onload(Flow*) {}
 
+void set_network_policy(unsigned) { }
+void set_inspection_policy(unsigned) { }
+void set_ips_policy(const snort::SnortConfig*, unsigned) { }
+void select_default_policy(const _daq_pkt_hdr&, const SnortConfig*) { }
+namespace snort
+{
+NetworkPolicy* get_network_policy() { return nullptr; }
+InspectionPolicy* get_inspection_policy() { return nullptr; }
+IpsPolicy* get_ips_policy() { return nullptr; }
+void set_network_policy(NetworkPolicy*) { }
+void set_inspection_policy(InspectionPolicy*) { }
+void set_ips_policy(IpsPolicy*) { }
+unsigned SnortConfig::get_thread_reload_id() { return 0; }
+}
+
 Packet* DetectionEngine::set_next_packet(Packet*, Flow*) { return nullptr; }
 
 IpsContext* DetectionEngine::get_context() { return nullptr; }
index b3d21230f5681b26020ca719b92f66c5287476ab..1b646581e8fd598107c05e812c34d262acca6b1d 100644 (file)
@@ -29,7 +29,7 @@
 
 // this is the current version of the base api
 // must be prefixed to subtype version
-#define BASE_API_VERSION 12
+#define BASE_API_VERSION 13
 
 // set options to API_OPTIONS to ensure compatibility
 #ifndef API_OPTIONS
index 12bea2593b40db586ce29ff0550bb364a95996c2..debd6a8801aa73b221701bc9d2c51899014bd058 100644 (file)
@@ -107,6 +107,12 @@ void DataBus::subscribe_network(const char* key, DataHandler* h)
     get_network_data_bus()._subscribe(key, h);
 }
 
+// for subscribers that need to receive events regardless of active inspection policy
+void DataBus::subscribe_global(const char* key, DataHandler* h, SnortConfig& sc)
+{
+    sc.global_dbus->_subscribe(key, h);
+}
+
 void DataBus::unsubscribe(const char* key, DataHandler* h)
 {
     get_data_bus()._unsubscribe(key, h);
@@ -117,9 +123,16 @@ void DataBus::unsubscribe_network(const char* key, DataHandler* h)
     get_network_data_bus()._unsubscribe(key, h);
 }
 
+void DataBus::unsubscribe_global(const char* key, DataHandler* h, SnortConfig& sc)
+{
+    sc.global_dbus->_unsubscribe(key, h);
+}
+
 // notify subscribers of event
 void DataBus::publish(const char* key, DataEvent& e, Flow* f)
 {
+    SnortConfig::get_conf()->global_dbus->_publish(key, e, f);
+
     NetworkPolicy* ni = get_network_policy();
     ni->dbus._publish(key, e, f);
 
@@ -176,7 +189,7 @@ void DataBus::_unsubscribe(const char* key, DataHandler* h)
 }
 
 // notify subscribers of event
-void DataBus::_publish(const char* key, DataEvent& e, Flow* f)
+void DataBus::_publish(const char* key, DataEvent& e, Flow* f) const
 {
     auto v = map.find(key);
 
index 59d901e30c108c29e6f73598121e1a3bd2604744..3b7f8c711ff32d230a54cd4c723d9d67d4caf5df 100644 (file)
@@ -38,6 +38,7 @@ namespace snort
 {
 class Flow;
 struct Packet;
+struct SnortConfig;
 
 class DataEvent
 {
@@ -101,10 +102,12 @@ public:
     // FIXIT-L ideally these would not be static or would take an inspection policy*
     static void subscribe(const char* key, DataHandler*);
     static void subscribe_network(const char* key, DataHandler*);
+    static void subscribe_global(const char* key, DataHandler*, SnortConfig&);
 
     // FIXIT-L these should be called during cleanup
     static void unsubscribe(const char* key, DataHandler*);
     static void unsubscribe_network(const char* key, DataHandler*);
+    static void unsubscribe_global(const char* key, DataHandler*, SnortConfig&);
 
     // runtime methods
     static void publish(const char* key, DataEvent&, Flow* = nullptr);
@@ -116,7 +119,7 @@ public:
 private:
     void _subscribe(const char* key, DataHandler*);
     void _unsubscribe(const char* key, DataHandler*);
-    void _publish(const char* key, DataEvent&, Flow*);
+    void _publish(const char* key, DataEvent&, Flow*) const;
 
 private:
     DataMap map;
@@ -147,6 +150,9 @@ private:
 // A flow has entered the setup state
 #define FLOW_STATE_SETUP_EVENT "flow.state_setup"
 
+// The policy has changed for the flow
+#define FLOW_STATE_RELOADED_EVENT "flow.reloaded"
+
 // A new flow is created on this packet
 #define STREAM_ICMP_NEW_FLOW_EVENT "stream.icmp_new_flow"
 #define STREAM_IP_NEW_FLOW_EVENT "stream.ip_new_flow"
index 9d103c67ccab1245ee38c5ae5f80415156c5c797..3f905a711ce77956f2075386505cdece4e1cec1a 100644 (file)
 #include "protocols/packet.h"
 #include "stream/stream_splitter.h"
 
+namespace snort
+{
+class ThreadSpecificData
+{
+public:
+    explicit ThreadSpecificData(unsigned max)
+    { data.resize(max); }
+    ~ThreadSpecificData() = default;
+
+    std::vector<void*> data;
+};
+}
+
 using namespace snort;
 
 //-------------------------------------------------------------------------
@@ -34,14 +47,11 @@ using namespace snort;
 //-------------------------------------------------------------------------
 
 unsigned THREAD_LOCAL Inspector::slot = 0;
-unsigned Inspector::max_slots = 1;
 
 Inspector::Inspector()
 {
     unsigned max = ThreadConfig::get_instance_max();
-    assert(slot < max);
     ref_count = new std::atomic_uint[max];
-
     for ( unsigned i = 0; i < max; ++i )
         ref_count[i] = 0;
 }
@@ -125,6 +135,25 @@ void Inspector::add_global_ref()
 void Inspector::rem_global_ref()
 { --ref_count[0]; }
 
+void Inspector::allocate_thread_storage()
+{
+    if (!thread_specific_data.use_count())
+        thread_specific_data = std::make_shared<ThreadSpecificData>(ThreadConfig::get_instance_max());
+}
+
+void Inspector::copy_thread_storage(Inspector* ins)
+{
+    assert(!thread_specific_data.use_count());
+    if (ins->thread_specific_data.use_count())
+        thread_specific_data = ins->thread_specific_data;
+}
+
+void Inspector::set_thread_specific_data(void* tsd)
+{ thread_specific_data->data[slot] = tsd; }
+
+void* Inspector::get_thread_specific_data() const
+{ return thread_specific_data->data[slot]; }
+
 static const char* InspectorTypeNames[IT_MAX] =
 {
     "passive",
index 96d453c9062eb9edc763031df28dc7e1e2f1a4f3..f598ecaf5b5cbd5449ba5b9f643a9a6dd86482cc 100644 (file)
@@ -25,6 +25,9 @@
 // in different ways.  These correspond to Snort 2X preprocessors.
 
 #include <atomic>
+#include <cstring>
+#include <memory>
+#include <vector>
 
 #include "framework/base_api.h"
 #include "main/thread.h"
@@ -59,6 +62,8 @@ struct InspectApi;
 // api for class
 //-------------------------------------------------------------------------
 
+class ThreadSpecificData;
+
 class SO_PUBLIC Inspector
 {
 public:
@@ -164,8 +169,15 @@ public:
     virtual bool can_start_tls() const
     { return false; }
 
+    void allocate_thread_storage();
+    void set_thread_specific_data(void*);
+    void* get_thread_specific_data() const;
+    void copy_thread_storage(Inspector*);
+
+    virtual void install_reload_handler(SnortConfig*)
+    { }
+
 public:
-    static unsigned max_slots;
     static THREAD_LOCAL unsigned slot;
 
 protected:
@@ -174,6 +186,7 @@ protected:
 
 private:
     const InspectApi* api = nullptr;
+    std::shared_ptr<ThreadSpecificData> thread_specific_data;
     std::atomic_uint* ref_count;
     SnortProtocolId snort_protocol_id = 0;
     // FIXIT-E Use std::string to avoid storing a pointer to external std::string buffers
index ab0e5e46f6736408f529bba028b5edb3938ce8cf..b3c2a1d53c7955ecb959582a0f03ac6d2308fc90 100644 (file)
@@ -29,6 +29,7 @@
 #include "framework/counts.h"
 #include "main/snort_types.h"
 
+struct _daq_flow_stats;
 struct _daq_pkt_hdr;
 
 namespace snort
@@ -81,7 +82,8 @@ public:
     }
     const PolicySelectorApi* get_api() const
     { return api; }
-    virtual bool select_default_policies(const _daq_pkt_hdr*, const SnortConfig*) = 0;
+    virtual bool select_default_policies(const _daq_pkt_hdr&, const SnortConfig*) = 0;
+    virtual bool select_default_policies(const _daq_flow_stats&, const SnortConfig*) = 0;
     virtual void show() const = 0;
 
 protected:
index b9f1a9444a84a3ab2a2dd868d6d260b194e7d5f8..cd409cef355c835c39ddb47189022852e6e5e6fc 100644 (file)
@@ -41,7 +41,7 @@ NetworkPolicy::~NetworkPolicy() = default;
 namespace snort
 {
 SnortConfig::SnortConfig(snort::SnortConfig const*, const char*)
-{ }
+{ global_dbus = new DataBus(); }
 
 const SnortConfig* SnortConfig::get_conf()
 {
@@ -58,7 +58,7 @@ SnortConfig* SnortConfig::get_main_conf()
 }
 
 SnortConfig::~SnortConfig()
-{ }
+{ delete global_dbus; }
 
 NetworkPolicy* get_network_policy()
 {
@@ -135,6 +135,26 @@ TEST_GROUP(data_bus)
     }
 };
 
+TEST(data_bus, subscribe_global)
+{
+    UTestHandler h;
+    DataBus::subscribe_global(DB_UTEST_EVENT, &h, snort_conf);
+
+    UTestEvent event(100);
+    DataBus::publish(DB_UTEST_EVENT, event);
+    CHECK(100 == h.evt_msg);
+
+    UTestEvent event1(200);
+    DataBus::publish(DB_UTEST_EVENT, event1);
+    CHECK(200 == h.evt_msg);
+
+    DataBus::unsubscribe_global(DB_UTEST_EVENT, &h, snort_conf);
+
+    UTestEvent event2(300);
+    DataBus::publish(DB_UTEST_EVENT, event2);
+    CHECK(200 == h.evt_msg); // unsubscribed!
+}
+
 TEST(data_bus, subscribe_network)
 {
     UTestHandler* h = new UTestHandler();
index 808d0b82c0e031c1ca03249b1a265536e3a33c41..b207dd576a05d1fe4797c0db0d1eeb0dc0006ad0 100644 (file)
@@ -38,6 +38,9 @@ using namespace snort;
 static SnortConfig my_config;
 THREAD_LOCAL SnortConfig* snort_conf = &my_config;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash()
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 { snort_conf->run_flags = 0;}
index b4c4987c2791e03e4fae0c73193bd5549eb1bbb6..350e84afb9aea40b25769f18bf4bcbf3fd40ac84 100644 (file)
@@ -38,6 +38,9 @@ using namespace snort;
 static SnortConfig my_config;
 THREAD_LOCAL SnortConfig* snort_conf = &my_config;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash()
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 { snort_conf->run_flags = 0;}
index 8af9b44d946037d3faa3fb912eb549ff422bf42c..0629283a85cb8f95f87c921899555a9da4f1c92a 100644 (file)
@@ -63,6 +63,9 @@ bool FlowHashKeyOps::key_compare(const void* k1, const void* k2, size_t len)
 static SnortConfig my_config;
 THREAD_LOCAL SnortConfig *snort_conf = &my_config;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 // run_flags is used indirectly from HashFnc class by calling SnortConfig::static_hash()
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 { snort_conf->run_flags = 0;}
index 9634ba188601b781cee9d25e7da66c34ae11bd9b..1fd57b154f9f906d54c186db1e3779590e7f5653 100644 (file)
@@ -50,6 +50,9 @@ static ScratchAllocator* scratcher = nullptr;
 
 static unsigned s_parse_errors = 0;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 {
     state = &s_state;
index 44d4a62c3597d9132b1925e7cbf2025ee0991733..cca05b71d8172f031b8f3460a89323d64fe6a304 100644 (file)
@@ -371,7 +371,7 @@ bool HostCacheModule::end(const char* fqn, int, SnortConfig* sc)
     if ( memcap and !strcmp(fqn, HOST_CACHE_NAME) )
     {
         if ( Snort::is_reloading() )
-            sc->register_reload_resource_tuner(new HostCacheReloadTuner(memcap));
+            sc->register_reload_handler(new HostCacheReloadTuner(memcap));
         else
             host_cache.set_max_size(memcap);
     }
index 5144c9561bad1abd49f5d43b54b6f857e4770d1b..7fdc182aa8e159d32e12d9c179a6f86e482ba4ec 100644 (file)
@@ -25,7 +25,7 @@
 
 #include "framework/module.h"
 #include "main/snort.h"
-#include "main/snort_config.h"
+#include "main/reload_tuner.h"
 
 #include "host_cache.h"
 
index 05ed3740be75ed5a6631dbebdd76623d82069f1c..3e46d2da66841aa5af695a3e3bf69af8bfcef4bf 100644 (file)
@@ -72,7 +72,7 @@ void LogMessage(const char* format,...)
 }
 time_t packet_time() { return 0; }
 bool Snort::is_reloading() { return false; }
-void SnortConfig::register_reload_resource_tuner(ReloadResourceTuner* rrt) { delete rrt; }
+void SnortConfig::register_reload_handler(ReloadResourceTuner* rrt) { delete rrt; }
 } // end of namespace snort
 
 void show_stats(PegCount*, const PegInfo*, unsigned, const char*) { }
index e84e187fc316afb83fde641208f90369c6f1c7cc..6f18b8128b142ebc3138063af25c80efd8a4c12f 100644 (file)
@@ -56,6 +56,9 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
 static std::vector<void *> s_state;
 static ScratchAllocator* scratcher = nullptr;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 {
     state = &s_state;
index b8c224d2295abf7bec2205a200a5a50a369b17a3..e621259fdb6e9a93e7749dd164c78a364cc9135e 100644 (file)
@@ -26,7 +26,6 @@
 #include <syslog.h>
 
 #include <cassert>
-#include <cstdarg>
 #include <cstring>
 
 #include "main/snort_config.h"
@@ -189,6 +188,13 @@ static void WriteLogMessage(FILE* fh, bool prefer_fh, const char* format, va_lis
 }
 
 // print an info message to stdout or syslog
+void LogMessage(const char* format, va_list& ap)
+{
+    if ( SnortConfig::log_quiet() )
+        return;
+    WriteLogMessage(stdout, false, format, ap);
+}
+
 void LogMessage(const char* format,...)
 {
     if ( SnortConfig::log_quiet() )
index e2bdaf89bf0d680b9d0e507c71f25e0c8dfb9398..9d8241ef4f22e86d766469bb445b97323be7caf0 100644 (file)
@@ -23,6 +23,7 @@
 
 #include <arpa/inet.h>
 #include <cstdio>
+#include <cstdarg>
 #include <string>
 #include <ctime>
 
@@ -61,6 +62,7 @@ SO_PUBLIC void ParseError(const char*, ...) __attribute__((format (printf, 1, 2)
 SO_PUBLIC void ReloadError(const char*, ...) __attribute__((format (printf, 1, 2)));
 [[noreturn]] SO_PUBLIC void ParseAbort(const char*, ...) __attribute__((format (printf, 1, 2)));
 
+SO_PUBLIC void LogMessage(const char*, va_list& ap);
 SO_PUBLIC void LogMessage(const char*, ...) __attribute__((format (printf, 1, 2)));
 SO_PUBLIC void LogMessage(FILE*, const char*, ...) __attribute__((format (printf, 2, 3)));
 SO_PUBLIC void WarningMessage(const char*, ...) __attribute__((format (printf, 1, 2)));
index 118ea34545f3afa495915c980c4d49040417431f..e786c5b5ec8cc2332a9d1034410fea3fd740620e 100644 (file)
@@ -343,8 +343,8 @@ static void ff_server_pkts(const Args& a)
 static void ff_service(const Args& a)
 {
     const char* svc = "unknown";
-    if ( a.pkt->flow and a.pkt->flow->service )
-        svc = a.pkt->flow->service;
+    if ( a.pkt->flow and a.pkt->flow->has_service() )
+        svc = a.pkt->flow->service->c_str();
     TextLog_Puts(csv_log, svc);
 }
 
index 5bc6422f382ed1ab90bc8bd5346f4b679b424381..bb1d1359915bb81a614b205a26735f736aaacef1 100644 (file)
@@ -473,8 +473,8 @@ static bool ff_service(const Args& a)
 {
     const char* svc = "unknown";
 
-    if ( a.pkt->flow and a.pkt->flow->service )
-        svc = a.pkt->flow->service;
+    if ( a.pkt->flow and a.pkt->flow->has_service() )
+        svc = a.pkt->flow->service->c_str();
 
     print_label(a, "service");
     TextLog_Quote(json_log, svc);
index 2bd7ecdb48ff8af520e610bb19847220e18bd132..561212dd336c0ed25c2e39e036dcf071f77f4cb5 100644 (file)
@@ -310,6 +310,8 @@ void snort::main_broadcast_command(AnalyzerCommand* ac, ControlConn* ctrlcon)
 
     ac = get_command(ac, ctrlcon);
     debug_logf(snort_trace, TRACE_MAIN, nullptr, "Broadcasting %s command\n", ac->stringify());
+    if (ac->need_update_reload_id())
+        SnortConfig::get_main_conf()->update_reload_id();
 
     for (unsigned idx = 0; idx < max_pigs; ++idx)
     {
@@ -422,7 +424,6 @@ int main_reload_config(lua_State* L)
     }
 
     PluginManager::reload_so_plugins_cleanup(sc);
-    sc->update_reload_id();
     SnortConfig::set_conf(sc);
     TraceApi::thread_reinit(sc->trace_config);
     proc_stats.conf_reloads++;
@@ -469,7 +470,6 @@ int main_reload_policy(lua_State* L)
         send_response(ctrlcon, "== reload failed\n");
         return 0;
     }
-    sc->update_reload_id();
     SnortConfig::set_conf(sc);
     proc_stats.policy_reloads++;
 
@@ -515,7 +515,6 @@ int main_reload_module(lua_State* L)
         send_response(ctrlcon, "== reload failed\n");
         return 0;
     }
-    sc->update_reload_id();
     SnortConfig::set_conf(sc);
     proc_stats.policy_reloads++;
 
index 97746692dc1a1873325e80de6719f8e6b287957c..ebcd871db3a486005bf92528f39981c9bdfb0fd4 100644 (file)
@@ -4,6 +4,7 @@ set (INCLUDES
     analyzer_command.h
     policy.h
     reload_tracker.h
+    reload_tuner.h
     snort.h
     snort_config.h
     snort_debug.h
@@ -22,6 +23,8 @@ if ( ENABLE_SHELL )
     set ( SHELL_SOURCES ac_shell_cmd.h ac_shell_cmd.cc)
 endif ( ENABLE_SHELL )
 
+add_subdirectory(test)
+
 add_library (main OBJECT
     analyzer.cc
     analyzer_command.cc
@@ -67,3 +70,4 @@ include_directories (${CMAKE_CURRENT_BINARY_DIR})
 install (FILES ${INCLUDES}
     DESTINATION "${INCLUDE_INSTALL_PATH}/main"
 )
+
index 7b44a58befc12a43dcc665233c457298c9a122b4..89b2ce1115af75014f22690516314220558ed178 100644 (file)
@@ -27,7 +27,7 @@
 
 #include "control/control.h"
 
-ACShellCmd::ACShellCmd(ControlConn* ctrlcon, AnalyzerCommand* ac) : ctrlcon(ctrlcon), ac(ac)
+ACShellCmd::ACShellCmd(ControlConn* conn, AnalyzerCommand* ac) : AnalyzerCommand(conn), ac(ac)
 {
     assert(ac);
 
index 47cf77a9cee1d1ae7a586c795e1f167ec73472a1..4a4a6166cc81b312da2d776c64af7d888d9281b7 100644 (file)
@@ -34,11 +34,12 @@ public:
     ACShellCmd() = delete;
     ACShellCmd(ControlConn*, snort::AnalyzerCommand*);
     bool execute(Analyzer&, void**) override;
+    bool need_update_reload_id() const override
+    { return ac->need_update_reload_id(); }
     const char* stringify() override { return ac->stringify(); }
     ~ACShellCmd() override;
 
 private:
-    ControlConn* ctrlcon;
     snort::AnalyzerCommand* ac;
 };
 
index 3feca6d4a6998db18bf2091cb60588c9690b803a..46cb06fc0817cfd999ff941b84a4fed5d055dff9 100644 (file)
@@ -169,6 +169,7 @@ static void process_daq_sof_eof_msg(DAQ_Msg_h msg, DAQ_Verdict& verdict)
     const DAQ_FlowStats_t *stats = (const DAQ_FlowStats_t*) daq_msg_get_hdr(msg);
     const char* key;
 
+    select_default_policy(*stats, SnortConfig::get_conf());
     if (daq_msg_get_type(msg) == DAQ_MSG_TYPE_EOF)
     {
         packet_time_update(&stats->eof_timestamp);
@@ -390,7 +391,7 @@ void Analyzer::process_daq_pkt_msg(DAQ_Msg_h msg, bool retry)
     Packet* p = switcher->get_context()->packet;
     p->context->wire_packet = p;
     p->context->packet_number = get_packet_number();
-    select_default_policy(pkthdr, p->context->conf);
+    select_default_policy(*pkthdr, p->context->conf);
 
     DetectionEngine::reset();
     sfthreshold_reset();
@@ -652,8 +653,8 @@ void Analyzer::term()
 
     DetectionEngine::idle();
     InspectorManager::thread_stop(sc);
-    ModuleManager::accumulate();
     InspectorManager::thread_term();
+    ModuleManager::accumulate();
     ActionManager::thread_term();
 
     IpsManager::clear_options(sc);
@@ -776,6 +777,9 @@ bool Analyzer::handle_command()
         return false;
 
     void* ac_state = nullptr;
+    if ( ac->need_update_reload_id() )
+        SnortConfig::update_thread_reload_id();
+
     if ( ac->execute(*this, &ac_state) )
         add_command_to_completed_queue(ac);
     else
index bd6c5b0eff61c246bfcd0d6e4e2a470b4b71d14b..2c2e354daed5116970db9e4a3751962dd9d5155d 100644 (file)
@@ -43,7 +43,6 @@ class Swapper;
 namespace snort
 {
 class AnalyzerCommand;
-class ReloadResourceTuner;
 class SFDAQInstance;
 struct Packet;
 struct SnortConfig;
index 726e4f8a3619a018c4b6119bc43d393bffec1de1..2b9c48241e530bd0918a823df5afe07450410237 100644 (file)
 
 #include "analyzer.h"
 #include "reload_tracker.h"
+#include "reload_tuner.h"
 #include "snort.h"
 #include "snort_config.h"
 #include "swapper.h"
 
 using namespace snort;
 
+void AnalyzerCommand::log_message(ControlConn* ctrlcon, const char* format, va_list& ap)
+{
+    LogMessage(format, ap);
+    if (ctrlcon && !ctrlcon->is_local())
+        ctrlcon->respond(format, ap);
+}
+
+void AnalyzerCommand::log_message(ControlConn* ctrlcon, const char* format, ...)
+{
+    va_list args;
+    va_start(args, format);
+    log_message(ctrlcon, format, args);
+    va_end(args);
+}
+
+void AnalyzerCommand::log_message(const char* format, ...)
+{
+    va_list args;
+    va_start(args, format);
+    log_message(ctrlcon, format, args);
+    va_end(args);
+}
+
 bool ACStart::execute(Analyzer& analyzer, void**)
 {
     analyzer.start();
@@ -108,9 +132,6 @@ bool ACResetStats::execute(Analyzer&, void**)
 ACResetStats::ACResetStats(clear_counter_type_t requested_type_l) : requested_type(
         requested_type_l) { }
 
-ACSwap::ACSwap(Swapper* ps, ControlConn *ctrlcon) : ps(ps), ctrlcon(ctrlcon)
-{ }
-
 bool ACSwap::execute(Analyzer& analyzer, void** ac_state)
 {
     if (ps)
@@ -180,15 +201,9 @@ ACSwap::~ACSwap()
     HostAttributesManager::swap_cleanup();
 
     ReloadTracker::end(ctrlcon);
-    LogMessage("== reload complete\n");
-    if (ctrlcon && !ctrlcon->is_local())
-        ctrlcon->respond("== reload complete\n");
+    log_message("== reload complete\n");
 }
 
-ACHostAttributesSwap::ACHostAttributesSwap(ControlConn *ctrlcon)
-    : ctrlcon(ctrlcon)
-{ }
-
 bool ACHostAttributesSwap::execute(Analyzer&, void**)
 {
     HostAttributesManager::initialize();
@@ -199,9 +214,7 @@ ACHostAttributesSwap::~ACHostAttributesSwap()
 {
     HostAttributesManager::swap_cleanup();
     ReloadTracker::end(ctrlcon);
-    LogMessage("== reload host attributes complete\n");
-    if (ctrlcon && !ctrlcon->is_local())
-        ctrlcon->respond("== reload host attributes complete\n");
+    log_message("== reload host attributes complete\n");
 }
 
 bool ACDAQSwap::execute(Analyzer& analyzer, void**)
index f2e4fa08c12482091556d7d13bcc0fc68f4745df..41abd5dcf2e7d5aaff85f9ef6c8f4d61abe9fea2 100644 (file)
@@ -20,6 +20,7 @@
 #ifndef ANALYZER_COMMANDS_H
 #define ANALYZER_COMMANDS_H
 
+#include <cstdarg>
 #include <vector>
 
 #include "main/snort_types.h"
@@ -35,14 +36,25 @@ class SFDAQInstance;
 class AnalyzerCommand
 {
 public:
+    AnalyzerCommand() : AnalyzerCommand(nullptr)
+    { }
+    explicit AnalyzerCommand(ControlConn* conn) : ctrlcon(conn)
+    { }
     virtual ~AnalyzerCommand() = default;
     virtual bool execute(Analyzer&, void**) = 0;
+    virtual bool need_update_reload_id() const
+    { return false; }
     virtual const char* stringify() = 0;
     unsigned get() { return ++ref_count; }
     unsigned put() { return --ref_count; }
+    SO_PUBLIC void log_message(const char* format, ...) __attribute__((format (printf, 2, 3)));
+    SO_PUBLIC static void log_message(ControlConn*, const char* format, ...) __attribute__((format (printf, 2, 3)));
     SO_PUBLIC static snort::SFDAQInstance* get_daq_instance(Analyzer& analyzer);
 
+    ControlConn* ctrlcon;
+
 private:
+    static void log_message(ControlConn*, const char* format, va_list& ap);
     unsigned ref_count = 0;
 };
 }
@@ -50,12 +62,11 @@ private:
 class ACGetStats : public snort::AnalyzerCommand
 {
 public:
-    ACGetStats(ControlConn* conn) : ctrlcon(conn) {}
+    ACGetStats(ControlConn* conn) : AnalyzerCommand(conn)
+    { }
     bool execute(Analyzer&, void**) override;
     const char* stringify() override { return "GET_STATS"; }
     ~ACGetStats() override;
-private:
-    ControlConn* ctrlcon;
 };
 
 typedef enum clear_counter_type
@@ -69,7 +80,7 @@ typedef enum clear_counter_type
     TYPE_HA
 } clear_counter_type_t;
 
-// FIXIT-M Will replace this vector with an unordered map of 
+// FIXIT-M Will replace this vector with an unordered map of
 // <clear_counter_type, clear_counter_type_string_map> when
 // will come up with more granular form of clearing module stats.
 static std::vector<const char*> clear_counter_type_string_map
@@ -145,25 +156,25 @@ class ACSwap : public snort::AnalyzerCommand
 {
 public:
     ACSwap() = delete;
-    ACSwap(Swapper* ps, ControlConn* ctrlcon);
+    ACSwap(Swapper* ps, ControlConn* conn) : AnalyzerCommand(conn), ps(ps)
+    { }
     bool execute(Analyzer&, void**) override;
+    bool need_update_reload_id() const override
+    { return true; }
     const char* stringify() override { return "SWAP"; }
     ~ACSwap() override;
 private:
     Swapper *ps;
-    ControlConn* ctrlcon;
 };
 
 class ACHostAttributesSwap : public snort::AnalyzerCommand
 {
 public:
-    ACHostAttributesSwap(ControlConn* ctrlcon);
+    ACHostAttributesSwap(ControlConn* conn) : AnalyzerCommand(conn)
+    { }
     bool execute(Analyzer&, void**) override;
     const char* stringify() override { return "HOST_ATTRIBUTES_SWAP"; }
     ~ACHostAttributesSwap() override;
-
-private:
-    ControlConn* ctrlcon;
 };
 
 class ACDAQSwap : public snort::AnalyzerCommand
@@ -178,9 +189,9 @@ namespace snort
 {
 // from main.cc
 #ifdef REG_TEST
-void main_unicast_command(AnalyzerCommand* ac, unsigned target, ControlConn* ctrlcon = nullptr);
+void main_unicast_command(AnalyzerCommand*, unsigned target, ControlConn* = nullptr);
 #endif
-SO_PUBLIC void main_broadcast_command(snort::AnalyzerCommand* ac, ControlConn* ctrlcon = nullptr);
+SO_PUBLIC void main_broadcast_command(snort::AnalyzerCommand*, ControlConn* = nullptr);
 }
 
 #endif
index 33ae970782f8f967b4dd50dd0fa68013c25c5012..3fb387d60615535299b575a3754824adfb42f47c 100644 (file)
@@ -1053,7 +1053,6 @@ class NetworkModule : public Module
 public:
     NetworkModule() : Module("network", network_help, network_params) { }
     bool set(const char*, Value&, SnortConfig*) override;
-    bool end(const char*, int, SnortConfig*) override;
 
     Usage get_usage() const override
     { return CONTEXT; }
@@ -1090,16 +1089,6 @@ bool NetworkModule::set(const char*, Value& v, SnortConfig* sc)
     return true;
 }
 
-bool NetworkModule::end(const char*, int idx, SnortConfig* sc)
-{
-    if (!idx)
-    {
-        NetworkPolicy* p = get_network_policy();
-        sc->policy_map->set_user_network(p);
-    }
-    return true;
-}
-
 //-------------------------------------------------------------------------
 // inspection policy module
 //-------------------------------------------------------------------------
@@ -1177,10 +1166,12 @@ bool InspectionModule::set(const char*, Value& v, SnortConfig* sc)
     return true;
 }
 
-bool InspectionModule::end(const char*, int, SnortConfig* sc)
+bool InspectionModule::end(const char*, int, SnortConfig*)
 {
     InspectionPolicy* p = get_inspection_policy();
-    sc->policy_map->set_user_inspection(p);
+    NetworkPolicy* np = get_network_parse_policy();
+    assert(np);
+    np->set_user_inspection(p);
     return true;
 }
 
index 440e0974643beaf34f643127b91719eaea16f8da..ee19cd08d62bcdccde516869ff5e8e2be609ef95 100644 (file)
@@ -30,6 +30,7 @@
 #include "framework/file_policy.h"
 #include "framework/policy_selector.h"
 #include "log/messages.h"
+#include "main/thread_config.h"
 #include "managers/inspector_manager.h"
 #include "parser/parse_conf.h"
 #include "parser/vars.h"
@@ -45,8 +46,8 @@ using namespace snort;
 // traffic policy
 //-------------------------------------------------------------------------
 
-NetworkPolicy::NetworkPolicy(PolicyId id, PolicyId default_inspection_id)
-    : policy_id(id), default_inspection_policy_id(default_inspection_id)
+NetworkPolicy::NetworkPolicy(PolicyId id, PolicyId default_ips_id)
+    : policy_id(id), default_ips_policy_id(default_ips_id)
 { init(nullptr, nullptr); }
 
 NetworkPolicy::NetworkPolicy(NetworkPolicy* other_network_policy, const char* exclude_name)
@@ -55,7 +56,24 @@ NetworkPolicy::NetworkPolicy(NetworkPolicy* other_network_policy, const char* ex
 NetworkPolicy::~NetworkPolicy()
 {
     FilePolicyBase::delete_file_policy(file_policy);
+    if (cloned)
+    {
+        if ( !inspection_policy.empty() )
+        {
+            InspectionPolicy* default_policy = inspection_policy[0];
+            default_policy->cloned = true;
+            delete default_policy;
+        }
+    }
+    else
+    {
+        for ( auto p : inspection_policy )
+            delete p;
+    }
+
     InspectorManager::delete_policy(this, cloned);
+
+    inspection_policy.clear();
 }
 
 void NetworkPolicy::init(NetworkPolicy* other_network_policy, const char* exclude_name)
@@ -63,10 +81,27 @@ void NetworkPolicy::init(NetworkPolicy* other_network_policy, const char* exclud
     file_policy = new FilePolicy;
     if (other_network_policy)
     {
+        for ( unsigned i = 0; i < (other_network_policy->inspection_policy.size()); i++)
+        {
+            if ( i == 0 )
+                inspection_policy.emplace_back(
+                    new InspectionPolicy(other_network_policy->inspection_policy[i]));
+            else
+                inspection_policy.emplace_back(other_network_policy->inspection_policy[i]);
+        }
+        user_inspection = other_network_policy->user_inspection;
+        // Fix references to inspection_policy[0]
+        for ( auto p : other_network_policy->user_inspection )
+        {
+            if ( p.second == other_network_policy->inspection_policy[0] )
+                user_inspection[p.first] = inspection_policy[0];
+        }
+
+
         dbus.clone(other_network_policy->dbus, exclude_name);
         policy_id = other_network_policy->policy_id;
-        default_inspection_policy_id = other_network_policy->default_inspection_policy_id;
         user_policy_id = other_network_policy->user_policy_id;
+        default_ips_policy_id = other_network_policy->default_ips_policy_id;
 
         min_ttl = other_network_policy->min_ttl;
         new_ttl = other_network_policy->new_ttl;
@@ -86,6 +121,12 @@ FilePolicy* NetworkPolicy::get_file_policy() const
 void NetworkPolicy::add_file_policy_rule(FileRule& file_rule)
 { file_policy->insert_file_rule(file_rule); }
 
+InspectionPolicy* NetworkPolicy::get_user_inspection_policy(unsigned user_id)
+{
+    auto it = user_inspection.find(user_id);
+    return it == user_inspection.end() ? nullptr : it->second;
+}
+
 //-------------------------------------------------------------------------
 // inspection policy
 //-------------------------------------------------------------------------
@@ -176,6 +217,8 @@ IpsPolicy::~IpsPolicy()
 
 PolicyMap::PolicyMap(PolicyMap* other_map, const char* exclude_name)
 {
+    unsigned max = ThreadConfig::get_instance_max();
+    inspector_tinit_complete = new bool[max]{};
     if ( other_map )
         clone(other_map, exclude_name);
     else
@@ -183,13 +226,14 @@ PolicyMap::PolicyMap(PolicyMap* other_map, const char* exclude_name)
         file_id = InspectorManager::create_single_instance_inspector_policy();
         flow_tracking = InspectorManager::create_single_instance_inspector_policy();
         global_inspector_policy = InspectorManager::create_global_inspector_policy();
-        add_shell(new Shell(nullptr, true), true);
+        add_shell(new Shell(nullptr, true), nullptr);
         empty_ips_policy = new IpsPolicy(ips_policy.size());
         ips_policy.push_back(empty_ips_policy);
     }
 
     set_network_policy(network_policy[0]);
-    set_inspection_policy(inspection_policy[0]);
+    set_network_parse_policy(network_policy[0]);
+    set_inspection_policy(network_policy[0]->get_inspection_policy(0));
     set_ips_policy(ips_policy[0]);
 }
 
@@ -197,17 +241,10 @@ PolicyMap::~PolicyMap()
 {
     if ( cloned )
     {
-        if ( !inspection_policy.empty() )
+        for (auto np: network_policy)
         {
-            InspectionPolicy* default_policy = inspection_policy[0];
-            default_policy->cloned = true;
-            delete default_policy;
-        }
-        if ( !network_policy.empty() )
-        {
-            NetworkPolicy* default_policy = network_policy[0];
-            default_policy->cloned = true;
-            delete default_policy;
+            np->cloned = true;
+            delete np;
         }
     }
     else
@@ -215,9 +252,6 @@ PolicyMap::~PolicyMap()
         for ( auto p : shells )
             delete p;
 
-        for ( auto p : inspection_policy )
-            delete p;
-
         for ( auto p : ips_policy )
             delete p;
 
@@ -232,10 +266,20 @@ PolicyMap::~PolicyMap()
     InspectorManager::destroy_global_inspector_policy(global_inspector_policy, cloned);
 
     shells.clear();
-    inspection_policy.clear();
     ips_policy.clear();
     network_policy.clear();
     shell_map.clear();
+    delete[] inspector_tinit_complete;
+}
+
+bool PolicyMap::setup_network_policies()
+{
+    for (auto* np : network_policy)
+    {
+        if (!set_user_network(np))
+            return false;
+    }
+    return true;
 }
 
 void PolicyMap::clone(PolicyMap *other_map, const char* exclude_name)
@@ -248,63 +292,45 @@ void PolicyMap::clone(PolicyMap *other_map, const char* exclude_name)
     ips_policy = other_map->ips_policy;
     empty_ips_policy = other_map->empty_ips_policy;
 
-    for ( unsigned i = 0; i < (other_map->network_policy.size()); i++)
-    {
-        if ( i == 0 )
-            network_policy.emplace_back(new NetworkPolicy(other_map->network_policy[i],
-                exclude_name));
-        else
-            network_policy.emplace_back(other_map->network_policy[i]);
-    }
-
-    for ( unsigned i = 0; i < (other_map->inspection_policy.size()); i++)
-    {
-        if ( i == 0 )
-            inspection_policy.emplace_back(new InspectionPolicy(other_map->inspection_policy[i]));
-        else
-            inspection_policy.emplace_back(other_map->inspection_policy[i]);
-    }
+    for ( unsigned i = 0; i < other_map->network_policy.size(); i++)
+        network_policy.emplace_back(new NetworkPolicy(other_map->network_policy[i],
+            i ? nullptr : exclude_name));
 
     shell_map = other_map->shell_map;
     // Fix references to network_policy[0] and inspection_policy[0]
     for ( auto p : other_map->shell_map )
     {
-        if ( p.second->network == other_map->network_policy[0] )
-            shell_map[p.first]->network = network_policy[0];
-        if ( p.second->inspection == other_map->inspection_policy[0] )
-            shell_map[p.first] = std::make_shared<PolicyTuple>(inspection_policy[0], p.second->ips,
-                p.second->network);
-    }
-
-    user_network = other_map->user_network;
-    // Fix references to network_policy[0]
-    for ( auto p : other_map->user_network )
-    {
-        if ( p.second == other_map->network_policy[0] )
-            user_network[p.first] = network_policy[0];
-    }
-
-    user_inspection = other_map->user_inspection;
-    // Fix references to inspection_policy[0]
-    for ( auto p : other_map->user_inspection )
-    {
-        if ( p.second == other_map->inspection_policy[0] )
-            user_inspection[p.first] = inspection_policy[0];
+        for ( unsigned idx = 0; idx < other_map->network_policy.size(); ++idx)
+        {
+            if ( p.second->network == other_map->network_policy[idx] )
+            {
+                shell_map[p.first]->network = network_policy[idx];
+                shell_map[p.first]->network_parse = network_policy[idx];
+            }
+            if ( p.second->inspection == other_map->network_policy[idx]->inspection_policy[0] )
+                shell_map[p.first] =
+                    std::make_shared<PolicyTuple>(
+                        other_map->network_policy[idx]->inspection_policy[0], p.second->ips,
+                        p.second->network, p.second->network);
+        }
     }
 
+    //user_network = other_map->user_network;
     user_ips = other_map->user_ips;
 }
 
 InspectionPolicy* PolicyMap::add_inspection_shell(Shell* sh)
 {
-    unsigned idx = inspection_policy.size();
-    InspectionPolicy* p = new InspectionPolicy(idx);
+    NetworkPolicy* np = get_network_parse_policy();
+    assert(np);
+    unsigned idx = np->inspection_policy_count();
+    InspectionPolicy* ip = new InspectionPolicy(idx);
 
     shells.push_back(sh);
-    inspection_policy.push_back(p);
-    shell_map[sh] = std::make_shared<PolicyTuple>(p, nullptr, nullptr);
+    np->inspection_policy.push_back(ip);
+    shell_map[sh] = std::make_shared<PolicyTuple>(ip, nullptr, nullptr, np);
 
-    return p;
+    return ip;
 }
 
 IpsPolicy* PolicyMap::add_ips_shell(Shell* sh)
@@ -314,25 +340,31 @@ IpsPolicy* PolicyMap::add_ips_shell(Shell* sh)
 
     shells.push_back(sh);
     ips_policy.push_back(p);
-    shell_map[sh] = std::make_shared<PolicyTuple>(nullptr, p, nullptr);
+    shell_map[sh] = std::make_shared<PolicyTuple>(nullptr, p, nullptr, get_network_parse_policy());
 
     return p;
 }
 
-std::shared_ptr<PolicyTuple> PolicyMap::add_shell(Shell* sh, bool include_network)
+std::shared_ptr<PolicyTuple> PolicyMap::add_shell(Shell* sh, NetworkPolicy* np_in)
 {
     shells.push_back(sh);
-    inspection_policy.push_back(new InspectionPolicy(inspection_policy.size()));
-    InspectionPolicy* ip = inspection_policy.back();
-    NetworkPolicy* new_network_policy = nullptr;
-    if (include_network)
+    IpsPolicy* ips = new IpsPolicy(ips_policy.size());
+    ips_policy.push_back(ips);
+    NetworkPolicy* np;
+    if (!np_in)
+    {
+        np_in = np = new NetworkPolicy(network_policy.size(), ips->policy_id);
+        network_policy.push_back(np);
+    }
+    else
     {
-        new_network_policy = new NetworkPolicy(network_policy.size(), ip->policy_id);
-        network_policy.push_back(new_network_policy);
+        np = np_in;
+        np_in = nullptr;
     }
-    ips_policy.push_back(new IpsPolicy(ips_policy.size()));
-    return shell_map[sh] = std::make_shared<PolicyTuple>(ip,
-        ips_policy.back(), new_network_policy);
+    InspectionPolicy* ip = new InspectionPolicy(np->inspection_policy_count());
+    np->inspection_policy.push_back(ip);
+    return shell_map[sh] =
+        std::make_shared<PolicyTuple>(ip, ips, np_in, np);
 }
 
 std::shared_ptr<PolicyTuple> PolicyMap::get_policies(Shell* sh)
@@ -342,18 +374,39 @@ std::shared_ptr<PolicyTuple> PolicyMap::get_policies(Shell* sh)
     return pt == shell_map.end() ? nullptr : pt->second;
 }
 
+NetworkPolicy* PolicyMap::get_user_network(unsigned user_id)
+{
+    auto it = user_network.find(user_id);
+    NetworkPolicy* np = (it == user_network.end()) ? nullptr : it->second;
+    return np;
+}
+
+bool PolicyMap::set_user_network(NetworkPolicy* p)
+{
+    NetworkPolicy* current_np = get_user_network(p->user_policy_id);
+    if (current_np && p != current_np)
+        return false;
+    user_network[p->user_policy_id] = p;
+    return true;
+}
+
+
 //-------------------------------------------------------------------------
 // policy nav
 //-------------------------------------------------------------------------
 
-static THREAD_LOCAL NetworkPolicy* s_traffic_policy = nullptr;
+static THREAD_LOCAL NetworkPolicy* s_network_policy = nullptr;
+static THREAD_LOCAL NetworkPolicy* s_network_parse_policy = nullptr;
 static THREAD_LOCAL InspectionPolicy* s_inspection_policy = nullptr;
 static THREAD_LOCAL IpsPolicy* s_detection_policy = nullptr;
 
 namespace snort
 {
 NetworkPolicy* get_network_policy()
-{ return s_traffic_policy; }
+{ return s_network_policy; }
+
+NetworkPolicy* get_network_parse_policy()
+{ return s_network_parse_policy; }
 
 InspectionPolicy* get_inspection_policy()
 { return s_inspection_policy; }
@@ -361,8 +414,11 @@ InspectionPolicy* get_inspection_policy()
 IpsPolicy* get_ips_policy()
 { return s_detection_policy; }
 
+void set_network_parse_policy(NetworkPolicy* p)
+{ s_network_parse_policy = p; }
+
 void set_network_policy(NetworkPolicy* p)
-{ s_traffic_policy = p; }
+{ s_network_policy = p; }
 
 void set_inspection_policy(InspectionPolicy* p)
 { s_inspection_policy = p; }
@@ -370,54 +426,56 @@ void set_inspection_policy(InspectionPolicy* p)
 void set_ips_policy(IpsPolicy* p)
 { s_detection_policy = p; }
 
-InspectionPolicy* get_user_inspection_policy(const SnortConfig* sc, unsigned policy_id)
+InspectionPolicy* get_user_inspection_policy(unsigned policy_id)
 {
-    return sc->policy_map->get_user_inspection(policy_id);
+    NetworkPolicy* np = get_network_policy();
+    assert(np);
+    return np->get_user_inspection_policy(policy_id);
 }
 
 NetworkPolicy* get_default_network_policy(const SnortConfig* sc)
 { return sc->policy_map->get_network_policy(0); }
 
-InspectionPolicy* get_default_inspection_policy(const SnortConfig* sc)
-{
-    return
-        sc->policy_map->get_inspection_policy(get_network_policy()->default_inspection_policy_id);
-}
-
 IpsPolicy* get_ips_policy(const SnortConfig* sc, unsigned i)
 {
     return sc && i < sc->policy_map->ips_policy_count() ?
         sc->policy_map->get_ips_policy(i) : nullptr;
 }
 
-IpsPolicy* get_user_ips_policy(const SnortConfig* sc, unsigned policy_id)
+IpsPolicy* get_default_ips_policy(const snort::SnortConfig* sc)
 {
-    return sc->policy_map->get_user_ips(policy_id);
+    NetworkPolicy* np = get_network_policy();
+    assert(np);
+    return np->get_default_ips_policy(sc);
 }
 
+IpsPolicy* get_user_ips_policy(const SnortConfig* sc, unsigned policy_id)
+{ return sc->policy_map->get_user_ips(policy_id); }
+
 IpsPolicy* get_empty_ips_policy(const SnortConfig* sc)
-{
-    return sc->policy_map->get_empty_ips();
-}
+{ return sc->policy_map->get_empty_ips(); }
 } // namespace snort
 
-void set_network_policy(const SnortConfig* sc, unsigned i)
+void set_network_policy(unsigned i)
 {
-    PolicyMap* pm = sc->policy_map;
+    PolicyMap* pm = SnortConfig::get_conf()->policy_map;
 
     if ( i < pm->network_policy_count() )
         set_network_policy(pm->get_network_policy(i));
 }
 
-void set_inspection_policy(const SnortConfig* sc, unsigned i)
+void set_inspection_policy(unsigned i)
 {
-    PolicyMap* pm = sc->policy_map;
-
-    if ( i < pm->inspection_policy_count() )
-        set_inspection_policy(pm->get_inspection_policy(i));
+    NetworkPolicy* np = get_network_policy();
+    if (np)
+    {
+        InspectionPolicy* ip = np->get_inspection_policy(i);
+        if (ip)
+            set_inspection_policy(ip);
+    }
 }
 
-void set_ips_policy(const SnortConfig* sc, unsigned i)
+void set_ips_policy(const snort::SnortConfig* sc, unsigned i)
 {
     PolicyMap* pm = sc->policy_map;
 
@@ -441,20 +499,24 @@ void set_policies(const SnortConfig* sc, Shell* sh)
 
 void set_default_policy(const SnortConfig* sc)
 {
-    set_network_policy(sc->policy_map->get_network_policy(0));
-    set_inspection_policy(sc->policy_map->get_inspection_policy(0));
-    set_ips_policy(sc->policy_map->get_ips_policy(0));
+    NetworkPolicy* np = get_default_network_policy(sc);
+    set_network_policy(np);
+    set_inspection_policy(np->get_inspection_policy(0));
+    set_ips_policy(get_ips_policy(sc, 0));
 }
 
-void select_default_policy(const _daq_pkt_hdr* pkthdr, const SnortConfig* sc)
+void select_default_policy(const _daq_pkt_hdr& pkthdr, const SnortConfig* sc)
 {
     PolicySelector* ps = sc->policy_map->get_policy_selector();
     if (!ps || !ps->select_default_policies(pkthdr, sc))
-    {
-        set_network_policy(sc->policy_map->get_network_policy(0));
-        set_inspection_policy(sc->policy_map->get_inspection_policy(0));
-        set_ips_policy(sc->policy_map->get_ips_policy(0));
-    }
+        set_default_policy(sc);
+}
+
+void select_default_policy(const _daq_flow_stats& stats, const snort::SnortConfig* sc)
+{
+    PolicySelector* ps = sc->policy_map->get_policy_selector();
+    if (!ps || !ps->select_default_policies(stats, sc))
+        set_default_policy(sc);
 }
 
 bool only_inspection_policy()
index ebdf7ed7bef2e0512fdc358bdc6bba263c0ec66b..5956e2ff7d8d3614e334a8ceedc9512d78f0404e 100644 (file)
@@ -49,6 +49,7 @@ class PolicySelector;
 struct SnortConfig;
 }
 
+struct _daq_flow_stats;
 struct _daq_pkt_hdr;
 struct PortTable;
 struct vartable_t;
@@ -67,6 +68,50 @@ enum PolicyMode
 
 // FIXIT-L split into separate headers
 
+//-------------------------------------------------------------------------
+// navigator stuff
+//-------------------------------------------------------------------------
+
+struct InspectionPolicy;
+struct IpsPolicy;
+struct NetworkPolicy;
+class Shell;
+
+namespace snort
+{
+SO_PUBLIC NetworkPolicy* get_network_policy();
+NetworkPolicy* get_network_parse_policy();
+SO_PUBLIC InspectionPolicy* get_inspection_policy();
+SO_PUBLIC IpsPolicy* get_ips_policy();
+
+SO_PUBLIC void set_network_policy(NetworkPolicy*);
+void set_network_parse_policy(NetworkPolicy*);
+SO_PUBLIC void set_inspection_policy(InspectionPolicy*);
+SO_PUBLIC void set_ips_policy(IpsPolicy*);
+
+SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*);
+// Based on currently set network policy
+SO_PUBLIC InspectionPolicy* get_user_inspection_policy(unsigned policy_id);
+
+SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0);
+// Based on currently set network policy
+SO_PUBLIC IpsPolicy* get_default_ips_policy(const snort::SnortConfig*);
+SO_PUBLIC IpsPolicy* get_user_ips_policy(const snort::SnortConfig*, unsigned policy_id);
+SO_PUBLIC IpsPolicy* get_empty_ips_policy(const snort::SnortConfig*);
+}
+
+void set_network_policy(unsigned = 0);
+void set_inspection_policy(unsigned = 0);
+void set_ips_policy(const snort::SnortConfig*, unsigned = 0);
+
+void set_policies(const snort::SnortConfig*, Shell*);
+void set_default_policy(const snort::SnortConfig*);
+void select_default_policy(const _daq_pkt_hdr&, const snort::SnortConfig*);
+void select_default_policy(const _daq_flow_stats&, const snort::SnortConfig*);
+
+bool only_inspection_policy();
+bool only_ips_policy();
+
 //-------------------------------------------------------------------------
 // traffic stuff
 //-------------------------------------------------------------------------
@@ -86,19 +131,59 @@ enum DecodeEventFlag
     DECODE_EVENT_FLAG__DEFAULT = 0x00000001
 };
 
-// Snort ac-split creates the nap (network analysis policy)
-// Snort++ breaks the nap into network and inspection
+//-------------------------------------------------------------------------
+// inspection stuff
+//-------------------------------------------------------------------------
+
+struct InspectionPolicy
+{
+public:
+    InspectionPolicy(PolicyId = 0);
+    InspectionPolicy(InspectionPolicy* old_inspection_policy);
+    ~InspectionPolicy();
+
+    void configure();
+
+public:
+    PolicyId policy_id = 0;
+    PolicyMode policy_mode = POLICY_MODE__MAX;
+    uint32_t user_policy_id = 0;
+    uuid_t uuid{};
+
+    struct FrameworkPolicy* framework_policy;
+    snort::DataBus dbus;
+    bool cloned;
+
+private:
+    void init(InspectionPolicy* old_inspection_policy);
+};
+
+//-------------------------------------------------------------------------
+// Network stuff
+//-------------------------------------------------------------------------
 
 class FilePolicy;
 class FileRule;
+struct IpsPolicy;
 
 struct NetworkPolicy
 {
 public:
-    NetworkPolicy(PolicyId = 0, PolicyId default_inspection_id = 0);
+    NetworkPolicy(PolicyId = 0, PolicyId default_ips_id = 0);
     NetworkPolicy(NetworkPolicy*, const char*);
     ~NetworkPolicy();
 
+    InspectionPolicy* get_inspection_policy(unsigned i = 0)
+    { return i < inspection_policy.size() ? inspection_policy[i] : nullptr; }
+    unsigned inspection_policy_count()
+    { return inspection_policy.size(); }
+    InspectionPolicy* get_user_inspection_policy(unsigned user_id);
+    void set_user_inspection(InspectionPolicy* p)
+    { user_inspection[p->user_policy_id] = p; }
+
+    IpsPolicy* get_default_ips_policy(const snort::SnortConfig* sc)
+    { return snort::get_ips_policy(sc, default_ips_policy_id); }
+
     void add_file_policy_rule(FileRule& file_rule);
     snort::FilePolicyBase* get_base_file_policy() const;
     FilePolicy* get_file_policy() const;
@@ -125,9 +210,12 @@ public:
     struct TrafficPolicy* traffic_policy;
     snort::DataBus dbus;
 
+    std::vector<InspectionPolicy*> inspection_policy;
+    std::unordered_map<unsigned, InspectionPolicy*> user_inspection;
+
     PolicyId policy_id = 0;
     uint32_t user_policy_id = 0;
-    PolicyId default_inspection_policy_id = 0;
+    PolicyId default_ips_policy_id = 0;
 
     // minimum possible (allows all but errors to pass by default)
     uint8_t min_ttl = 1;
@@ -142,33 +230,6 @@ private:
     void init(NetworkPolicy*, const char*);
 };
 
-//-------------------------------------------------------------------------
-// inspection stuff
-//-------------------------------------------------------------------------
-
-struct InspectionPolicy
-{
-public:
-    InspectionPolicy(PolicyId = 0);
-    InspectionPolicy(InspectionPolicy* old_inspection_policy);
-    ~InspectionPolicy();
-
-    void configure();
-
-public:
-    PolicyId policy_id = 0;
-    PolicyMode policy_mode = POLICY_MODE__MAX;
-    uint32_t user_policy_id = 0;
-    uuid_t uuid{};
-
-    struct FrameworkPolicy* framework_policy;
-    snort::DataBus dbus;
-    bool cloned;
-
-private:
-    void init(InspectionPolicy* old_inspection_policy);
-};
-
 //-------------------------------------------------------------------------
 // detection stuff
 //-------------------------------------------------------------------------
@@ -221,16 +282,17 @@ public:
 // binding stuff
 //-------------------------------------------------------------------------
 
-class Shell;
-
 struct PolicyTuple
 {
     InspectionPolicy* inspection = nullptr;
     IpsPolicy* ips = nullptr;
     NetworkPolicy* network = nullptr;
+    NetworkPolicy* network_parse = nullptr;
 
-    PolicyTuple(InspectionPolicy* ins_pol, IpsPolicy* ips_pol, NetworkPolicy* net_pol) :
-        inspection(ins_pol), ips(ips_pol), network(net_pol) { }
+    PolicyTuple(InspectionPolicy* ins_pol, IpsPolicy* ips_pol, NetworkPolicy* net_pol,
+        NetworkPolicy* net_parse) :
+        inspection(ins_pol), ips(ips_pol), network(net_pol), network_parse(net_parse)
+    { }
 };
 
 struct GlobalInspectorPolicy;
@@ -244,33 +306,18 @@ public:
 
     InspectionPolicy* add_inspection_shell(Shell*);
     IpsPolicy* add_ips_shell(Shell*);
-    std::shared_ptr<PolicyTuple> add_shell(Shell*, bool include_network);
+    std::shared_ptr<PolicyTuple> add_shell(Shell*, NetworkPolicy*);
     std::shared_ptr<PolicyTuple> get_policies(Shell* sh);
-    void clone(PolicyMap *old_map, const char* exclude_name);
 
     Shell* get_shell(unsigned i = 0)
     { return i < shells.size() ? shells[i] : nullptr; }
 
-    void set_user_network(NetworkPolicy* p)
-    { user_network[p->user_policy_id] = p; }
-
-    void set_user_inspection(InspectionPolicy* p)
-    { user_inspection[p->user_policy_id] = p; }
+    bool setup_network_policies();
 
     void set_user_ips(IpsPolicy* p)
     { user_ips[p->user_policy_id] = p; }
 
-    NetworkPolicy* get_user_network(unsigned user_id)
-    {
-        auto it = user_network.find(user_id);
-        return it == user_network.end() ? nullptr : it->second;
-    }
-
-    InspectionPolicy* get_user_inspection(unsigned user_id)
-    {
-        auto it = user_inspection.find(user_id);
-        return it == user_inspection.end() ? nullptr : it->second;
-    }
+    NetworkPolicy* get_user_network(unsigned user_id);
 
     IpsPolicy* get_user_ips(unsigned user_id)
     {
@@ -280,24 +327,15 @@ public:
 
     NetworkPolicy* get_network_policy(unsigned i = 0)
     { return i < network_policy.size() ? network_policy[i] : nullptr; }
-
-    InspectionPolicy* get_inspection_policy(unsigned i = 0)
-    { return i < inspection_policy.size() ? inspection_policy[i] : nullptr; }
-
-    IpsPolicy* get_ips_policy(unsigned i = 0)
-    { return i < ips_policy.size() ? ips_policy[i] : nullptr; }
-
-    IpsPolicy* get_empty_ips()
-    { return empty_ips_policy; }
-
     unsigned network_policy_count()
     { return network_policy.size(); }
 
-    unsigned inspection_policy_count()
-    { return inspection_policy.size(); }
-
+    IpsPolicy* get_ips_policy(unsigned i = 0)
+    { return i < ips_policy.size() ? ips_policy[i] : nullptr; }
     unsigned ips_policy_count()
     { return ips_policy.size(); }
+    IpsPolicy* get_empty_ips()
+    { return empty_ips_policy; }
 
     unsigned shells_count()
     { return shells.size(); }
@@ -328,17 +366,23 @@ public:
         return (it == std::end(shell_map)) ? nullptr : it->first;
     }
 
+    bool get_inspector_tinit_complete(unsigned instance_id) const
+    { return inspector_tinit_complete[instance_id]; }
+
+    void set_inspector_tinit_complete(unsigned instance_id, bool val)
+    { inspector_tinit_complete[instance_id] = val; }
+
 private:
+    void clone(PolicyMap *old_map, const char* exclude_name);
+    bool set_user_network(NetworkPolicy* p);
+
     std::vector<Shell*> shells;
-    std::vector<InspectionPolicy*> inspection_policy;
-    std::vector<IpsPolicy*> ips_policy;
     std::vector<NetworkPolicy*> network_policy;
-
+    std::vector<IpsPolicy*> ips_policy;
     IpsPolicy* empty_ips_policy;
 
     std::unordered_map<Shell*, std::shared_ptr<PolicyTuple>> shell_map;
     std::unordered_map<unsigned, NetworkPolicy*> user_network;
-    std::unordered_map<unsigned, InspectionPolicy*> user_inspection;
     std::unordered_map<unsigned, IpsPolicy*> user_ips;
 
     snort::PolicySelector* selector = nullptr;
@@ -346,44 +390,9 @@ private:
     SingleInstanceInspectorPolicy* flow_tracking;
     GlobalInspectorPolicy* global_inspector_policy;
 
+    bool* inspector_tinit_complete;
     bool cloned = false;
 };
 
-//-------------------------------------------------------------------------
-// navigator stuff
-//-------------------------------------------------------------------------
-
-// FIXIT-L may be inlined at some point; on lockdown for now
-// FIXIT-L SO_PUBLIC required because SnortConfig::inline_mode(), etc. uses the function
-namespace snort
-{
-SO_PUBLIC NetworkPolicy* get_network_policy();
-SO_PUBLIC InspectionPolicy* get_inspection_policy();
-SO_PUBLIC IpsPolicy* get_ips_policy();
-
-SO_PUBLIC void set_network_policy(NetworkPolicy*);
-SO_PUBLIC void set_inspection_policy(InspectionPolicy*);
-SO_PUBLIC void set_ips_policy(IpsPolicy*);
-
-SO_PUBLIC NetworkPolicy* get_default_network_policy(const snort::SnortConfig*);
-SO_PUBLIC InspectionPolicy* get_user_inspection_policy(const snort::SnortConfig*, unsigned policy_id);
-SO_PUBLIC InspectionPolicy* get_default_inspection_policy(const snort::SnortConfig*);
-
-SO_PUBLIC IpsPolicy* get_ips_policy(const snort::SnortConfig*, unsigned i = 0);
-SO_PUBLIC IpsPolicy* get_user_ips_policy(const snort::SnortConfig*, unsigned policy_id);
-SO_PUBLIC IpsPolicy* get_empty_ips_policy(const snort::SnortConfig*);
-}
-
-void set_network_policy(const snort::SnortConfig*, unsigned = 0);
-void set_inspection_policy(const snort::SnortConfig*, unsigned = 0);
-void set_ips_policy(const snort::SnortConfig*, unsigned = 0);
-
-void set_policies(const snort::SnortConfig*, Shell*);
-void set_default_policy(const snort::SnortConfig*);
-void select_default_policy(const _daq_pkt_hdr*, const snort::SnortConfig*);
-
-bool only_inspection_policy();
-bool only_ips_policy();
-
 #endif
 
diff --git a/src/main/reload_tuner.h b/src/main/reload_tuner.h
new file mode 100644 (file)
index 0000000..163547b
--- /dev/null
@@ -0,0 +1,73 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2022-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+
+#ifndef RELOAD_TUNER_H
+#define RELOAD_TUNER_H
+
+namespace snort
+{
+
+class ReloadResourceTuner
+{
+public:
+    static const unsigned RELOAD_MAX_WORK_PER_PACKET = 3;
+    // be aggressive when idle as analyzer gets chance once in every second only due to daq timeout
+    static const unsigned RELOAD_MAX_WORK_WHEN_IDLE = 32767;
+
+    virtual ~ReloadResourceTuner() = default;
+
+    // returns true if resource tuning required, false otherwise
+    virtual bool tinit() = 0;
+
+    // each of these returns true if resource tuning is complete, false otherwise
+    virtual bool tune_packet_context() = 0;
+    virtual bool tune_idle_context() = 0;
+
+protected:
+    ReloadResourceTuner() = default;
+
+    unsigned max_work = RELOAD_MAX_WORK_PER_PACKET;
+    unsigned max_work_idle = RELOAD_MAX_WORK_WHEN_IDLE;
+};
+
+class ReloadSwapper : public ReloadResourceTuner
+{
+public:
+    virtual ~ReloadSwapper() override = default;
+
+    // each of these returns true if resource tuning is complete, false otherwise
+    bool tune_packet_context() override
+    { return true; }
+    bool tune_idle_context() override
+    { return true; }
+
+    bool tinit() override
+    {
+        tswap();
+        return false;
+    }
+
+    virtual void tswap() = 0;
+
+protected:
+    ReloadSwapper() = default;
+};
+
+}
+
+#endif
index 483912eee00436daad45aebce0ce8be2f7e18511..4807f5b85a0c990606e6d1acd4308c1aefbeeadd 100644 (file)
@@ -479,9 +479,10 @@ bool Shell::configure(SnortConfig* sc, bool is_root)
         set_default_policy(sc);
     else
     {
+        set_network_policy(pt->network);
+        set_network_parse_policy(pt->network_parse);
         set_inspection_policy(pt->inspection);
         set_ips_policy(pt->ips);
-        set_network_policy(pt->network);
     }
 
     if (!sc->tweaks.empty())
@@ -544,6 +545,7 @@ bool Shell::configure(SnortConfig* sc, bool is_root)
 
     current_shells.pop();
 
+    set_network_parse_policy(nullptr);
     set_default_policy(sc);
     ModuleManager::set_config(nullptr);
     loaded = true;
@@ -563,6 +565,7 @@ void Shell::install(const char* name, const luaL_Reg* reg)
 
 void Shell::execute(const char* cmd, string& rsp)
 {
+    set_default_policy(SnortConfig::get_conf());
     int err = 0;
     Lua::ManageStack ms(lua);
 
index 8f3b4791c5a8c2ee7726f6f1722668e67ffc9347..3d4b2a51c0be7954125b4e252251a7191848fbe3 100644 (file)
@@ -159,6 +159,9 @@ void Snort::init(int argc, char** argv)
     /* Set the global snort_conf that will be used during run time */
     SnortConfig::set_conf(sc);
 
+    if (!sc->policy_map->setup_network_policies())
+        ParseError("Network policy user ids must be unique\n");
+
     // This call must be immediately after "SnortConfig::set_conf(sc)"
     // since the first trace call may happen somewhere after this point
     TraceApi::thread_init(sc->trace_config);
@@ -206,6 +209,8 @@ void Snort::init(int argc, char** argv)
     else if ( SnortConfig::log_verbose() )
         InspectorManager::print_config(sc);
 
+    InspectorManager::global_init();
+    InspectorManager::prepare_inspectors(sc);
     InspectorManager::prepare_controls(sc);
 
     // Must be after InspectorManager::configure()
@@ -481,7 +486,8 @@ SnortConfig* Snort::get_reload_config(const char* fname, const char* plugin_path
         return nullptr;
     }
 
-    InspectorManager::tear_down_removed_inspectors(old, sc);
+    InspectorManager::reconcile_inspectors(old, sc);
+    InspectorManager::prepare_inspectors(sc);
     InspectorManager::prepare_controls(sc);
 
     FileService::verify_reload(sc);
@@ -551,6 +557,7 @@ SnortConfig* Snort::get_updated_policy(
     reset_parse_errors();
 
     SnortConfig* sc = new SnortConfig(other_conf, iname);
+    sc->global_dbus->clone(*other_conf->global_dbus, iname);
 
     if ( fname )
     {
@@ -601,6 +608,8 @@ SnortConfig* Snort::get_updated_policy(
         return nullptr;
     }
 
+    InspectorManager::reconcile_inspectors(other_conf, sc, true);
+    InspectorManager::prepare_inspectors(sc);
     InspectorManager::prepare_controls(sc);
 
     other_conf->cloned = true;
@@ -614,6 +623,7 @@ SnortConfig* Snort::get_updated_module(SnortConfig* other_conf, const char* name
     reloading = true;
 
     SnortConfig* sc = new SnortConfig(other_conf, name);
+    sc->global_dbus->clone(*other_conf->global_dbus, name);
 
     if ( name )
     {
@@ -641,6 +651,8 @@ SnortConfig* Snort::get_updated_module(SnortConfig* other_conf, const char* name
         return nullptr;
     }
 
+    InspectorManager::reconcile_inspectors(other_conf, sc, true);
+    InspectorManager::prepare_inspectors(sc);
     InspectorManager::prepare_controls(sc);
 
     other_conf->cloned = true;
index 99f2436c5c04e271cdf120362a50e2b4bbe397b7..d2c7c7a8c6b942e8a618ed3935f14adb81b22592 100644 (file)
@@ -72,6 +72,7 @@
 #include "utils/util_cstring.h"
 
 #include "analyzer.h"
+#include "reload_tuner.h"
 #include "thread_config.h"
 
 using namespace snort;
@@ -90,7 +91,12 @@ using namespace snort;
 #define OUTPUT_U2   "unified2"
 #define OUTPUT_FAST "alert_fast"
 
-static THREAD_LOCAL const SnortConfig* snort_conf = nullptr;
+struct ThreadSnortConfig
+{
+    const SnortConfig* snort_conf;
+    unsigned reload_id;
+};
+static THREAD_LOCAL ThreadSnortConfig thread_snort_config = {};
 
 uint32_t SnortConfig::warning_flags = 0;
 uint32_t SnortConfig::logging_flags = 0;
@@ -136,20 +142,23 @@ static PolicyMode init_policy_mode(const SnortConfig* sc, PolicyMode mode)
 
 static void init_policies(SnortConfig* sc)
 {
-    IpsPolicy* ips_policy = nullptr;
-    InspectionPolicy* inspection_policy = nullptr;
+    for ( unsigned nidx = 0; nidx <  sc->policy_map->network_policy_count(); ++nidx )
+    {
+        NetworkPolicy* network_policy = sc->policy_map->get_network_policy(nidx);
+
+        for ( unsigned idx = 0; idx < network_policy->inspection_policy_count(); ++idx )
+        {
+            InspectionPolicy* inspection_policy = network_policy->get_inspection_policy(idx);
+            inspection_policy->policy_mode = init_policy_mode(sc, inspection_policy->policy_mode);
+        }
+    }
 
     for ( unsigned idx = 0; idx <  sc->policy_map->ips_policy_count(); ++idx )
     {
-        ips_policy = sc->policy_map->get_ips_policy(idx);
+        IpsPolicy* ips_policy = get_ips_policy(sc, idx);
         ips_policy->policy_mode = init_policy_mode(sc, ips_policy->policy_mode);
     }
 
-    for ( unsigned idx = 0; idx < sc->policy_map->inspection_policy_count(); ++idx )
-    {
-        inspection_policy = sc->policy_map->get_inspection_policy(idx);
-        inspection_policy->policy_mode = init_policy_mode(sc, inspection_policy->policy_mode);
-    }
 }
 
 void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* protocol_reference,
@@ -179,6 +188,7 @@ void SnortConfig::init(const SnortConfig* const other_conf, ProtocolReference* p
         memory = new MemoryConfig();
         policy_map = new PolicyMap;
         thread_config = new ThreadConfig();
+        global_dbus = new DataBus();
 
         proto_ref = new ProtocolReference(protocol_reference);
         so_rules = new SoRules;
@@ -214,6 +224,7 @@ SnortConfig::~SnortConfig()
 {
     if ( cloned )
     {
+        delete global_dbus;
         policy_map->set_cloned(true);
         delete policy_map;
         return;
@@ -247,7 +258,7 @@ SnortConfig::~SnortConfig()
         snort_free(eth_dst);
 
     if ( fast_pattern_config &&
-        (!snort_conf || this == snort_conf ||
+        (!thread_snort_config.snort_conf || this == thread_snort_config.snort_conf ||
         (fast_pattern_config->get_search_api() !=
         get_conf()->fast_pattern_config->get_search_api())) )
     {
@@ -264,6 +275,7 @@ SnortConfig::~SnortConfig()
     delete trace_config;
     delete overlay_trace_config;
     delete ha_config;
+        delete global_dbus;
 
     delete profiler;
     delete latency;
@@ -331,6 +343,7 @@ void SnortConfig::post_setup()
 void SnortConfig::clone(const SnortConfig* const conf)
 {
     *this = *conf;
+    global_dbus = new DataBus();
     if (conf->homenet.get_family() != 0)
         memcpy(&homenet, &conf->homenet, sizeof(homenet));
 
@@ -467,6 +480,9 @@ bool SnortConfig::verify() const
     bool config_ok = false;
     const SnortConfig* sc = get_conf();
 
+    if (!policy_map->setup_network_policies())
+        ReloadError("Network policy user ids must be unique\n");
+
     if ( sc->asn1_mem != asn1_mem )
         ReloadError("Changing detection.asn1_mem requires a restart.\n");
 
@@ -978,14 +994,20 @@ void SnortConfig::release_scratch(int id)
 }
 
 SnortConfig* SnortConfig::get_main_conf()
-{ return const_cast<SnortConfig*>(snort_conf); }
+{ return const_cast<SnortConfig*>(thread_snort_config.snort_conf); }
 
 const SnortConfig* SnortConfig::get_conf()
-{ return snort_conf; }
+{ return thread_snort_config.snort_conf; }
+
+unsigned SnortConfig::get_thread_reload_id()
+{ return thread_snort_config.reload_id; }
+
+void SnortConfig::update_thread_reload_id()
+{ thread_snort_config.reload_id = thread_snort_config.snort_conf->reload_id; }
 
 void SnortConfig::set_conf(const SnortConfig* sc)
 {
-    snort_conf = sc;
+    thread_snort_config.snort_conf = sc;
 
     if ( sc )
     {
@@ -995,7 +1017,7 @@ void SnortConfig::set_conf(const SnortConfig* sc)
     }
 }
 
-void SnortConfig::register_reload_resource_tuner(ReloadResourceTuner* rrt)
+void SnortConfig::register_reload_handler(ReloadResourceTuner* rrt)
 {
     if (Snort::is_reloading())
         reload_tuners.push_back(rrt);
index f1eb85ecc1e5103ce55869f0d09093901364c105..f2da7cab8c3647b35ef778c88e385d136dde2fdd 100644 (file)
@@ -163,33 +163,11 @@ namespace snort
 {
 class GHash;
 class ProtocolReference;
+class ReloadResourceTuner;
 class ThreadConfig;
 class XHash;
 struct ProfilerConfig;
 
-class ReloadResourceTuner
-{
-public:
-    static const unsigned RELOAD_MAX_WORK_PER_PACKET = 3;
-    // be aggressive when idle as analyzer gets chance once in every second only due to daq timeout
-    static const unsigned RELOAD_MAX_WORK_WHEN_IDLE = 32767;
-
-    virtual ~ReloadResourceTuner() = default;
-
-    // returns true if resource tuning required, false otherwise
-    virtual bool tinit() = 0;
-
-    // each of these returns true if resource tuning is complete, false otherwise
-    virtual bool tune_packet_context() = 0;
-    virtual bool tune_idle_context() = 0;
-
-protected:
-    ReloadResourceTuner() = default;
-
-    unsigned max_work = RELOAD_MAX_WORK_PER_PACKET;
-    unsigned max_work_idle = RELOAD_MAX_WORK_WHEN_IDLE;
-};
-
 struct SnortConfig
 {
 private:
@@ -392,6 +370,8 @@ public:
     PolicyMap* policy_map = nullptr;
     std::string tweaks;
 
+    DataBus* global_dbus = nullptr;
+
     uint16_t tunnel_mask = 0;
 
     int16_t max_aux_ip = 16;
@@ -428,11 +408,11 @@ public:
     bool cloned = false;
     Plugins* plugins = nullptr;
     SoRules* so_rules = nullptr;
-    unsigned reload_id = 0;
 
     DumpConfigType dump_config_type = DUMP_CONFIG_NONE;
 private:
     std::list<ReloadResourceTuner*> reload_tuners;
+    unsigned reload_id = 0;
 
 public:
     //------------------------------------------------------
@@ -697,13 +677,16 @@ public:
     // runtime access to const config - especially for packet threads
     // prefer access via packet->context->conf
     SO_PUBLIC static const SnortConfig* get_conf();
+    // Thread local copy of the reload_id needed for commands that cause reevaluation
+    SO_PUBLIC static unsigned get_thread_reload_id();
+    SO_PUBLIC static void update_thread_reload_id();
 
     // runtime access to mutable config - main thread only, and only special cases
     SO_PUBLIC static SnortConfig* get_main_conf();
 
     static void set_conf(const SnortConfig*);
 
-    SO_PUBLIC void register_reload_resource_tuner(ReloadResourceTuner*);
+    SO_PUBLIC void register_reload_handler(ReloadResourceTuner*);
 
     static void cleanup_fatal_error();
 
index e26ac46649098ff6e1922160c61d8e9911edfeb1..7338bc281948fcb8b355723d0d8f9c073c2f4533 100644 (file)
@@ -1,7 +1,7 @@
 if ( ENABLE_SHELL )
     add_cpputest(distill_verdict_test
         SOURCES
-            stubs.h
+            distill_verdict_stubs.h
             ../analyzer.cc
             ../../packet_io/active.cc
     )
similarity index 96%
rename from src/main/test/stubs.h
rename to src/main/test/distill_verdict_stubs.h
index 41416d9472a3938ebcb32ce58106fee177e5f1fe..45c8c297e0e938096c0da685c68f018f11a38c08 100644 (file)
@@ -116,6 +116,9 @@ eth_t* eth_close(eth_t*) { return nullptr; }
 ssize_t eth_send(eth_t*, const void*, size_t) { return -1; }
 void HostAttributesManager::initialize() { }
 
+void select_default_policy(const _daq_pkt_hdr&, const snort::SnortConfig*) { }
+void select_default_policy(const _daq_flow_stats&, const snort::SnortConfig*) { }
+
 namespace snort
 {
 static struct timeval s_packet_time = { 0, 0 };
@@ -167,6 +170,7 @@ void DetectionEngine::clear_replacement() { }
 void DetectionEngine::disable_all(Packet*) { }
 unsigned get_instance_id() { return 0; }
 const SnortConfig* SnortConfig::get_conf() { return nullptr; }
+void SnortConfig::update_thread_reload_id() { }
 void PacketTracer::thread_init() { }
 void PacketTracer::thread_term() { }
 void PacketTracer::log(const char*, ...) { }
@@ -203,6 +207,7 @@ void InspectorManager::thread_stop(const SnortConfig*) { }
 void InspectorManager::thread_reinit(const SnortConfig*) { }
 void InspectorManager::thread_stop_removed(const SnortConfig*) { }
 void ModuleManager::accumulate() { }
+void ModuleManager::accumulate_module(const char*) { }
 void Stream::handle_timeouts(bool) { }
 void Stream::purge_flows() { }
 bool Stream::set_packet_action_to_hold(Packet*) { return false; }
@@ -216,3 +221,8 @@ Flow::Flow() = default;
 Flow::~Flow() = default;
 void ThreadConfig::implement_thread_affinity(SThreadType, unsigned) { }
 }
+
+namespace memory
+{
+void MemoryCap::free_space() { }
+}
index 9c426cbb59327a97c2baaf2f95a342ba9d34a44b..115b5efb7812bd623c59a207760d28bd57d14fff 100644 (file)
@@ -23,7 +23,7 @@
 
 #include <unistd.h>
 
-#include "stubs.h"
+#include "distill_verdict_stubs.h"
 
 #include "main/analyzer.h"
 #include "memory/memory_cap.h"
index 3d502019e1dfdd74558c57ba826a73ba06f5c27f..1752574565220358512d9f2287cd1be555ed17af 100644 (file)
@@ -13,6 +13,8 @@ set( MANAGERS_INCLUDES
     inspector_manager.h
 )
 
+add_subdirectory(test)
+
 add_library( managers OBJECT
     ${LUA_INCLUDES}
     ${MANAGERS_INCLUDES}
index 7d167dfa0ffa50b01ce0d15acb1167a23934c642..9c0307242894c9b78f43dd481e62d16e971cadeb 100644 (file)
@@ -39,8 +39,8 @@
 #include "main/snort_debug.h"
 #include "main/snort_module.h"
 #include "main/thread_config.h"
-#include "search_engines/search_tool.h"
 #include "protocols/packet.h"
+#include "search_engines/search_tool.h"
 #include "target_based/snort_protocols.h"
 #include "time/clock_defs.h"
 #include "time/stopwatch.h"
@@ -67,18 +67,18 @@ using namespace std;
 // this distinction should be more precise when policy foo is ripped out of
 // the instances.
 
-struct PHGlobal
+struct PHObject
 {
     const InspectApi& api;
     bool initialized = false;   // In the context of the main thread, this means that api.pinit()
                                 // has been called.  In the packet thread, it means that
                                 // api.tinit() has been called.
-    bool instance_initialized = false;  // In the packet thread, at least one instance has had
+    bool instance_initialized = false;  //In the packet thread, at least one instance has had
                                         // tinit() called.
 
-    PHGlobal(const InspectApi& p) : api(p) { }
+    PHObject(const InspectApi& p) : api(p) { }
 
-    static bool comp(const PHGlobal* a, const PHGlobal* b)
+    static bool comp(const PHObject* a, const PHObject* b)
     { return ( a->api.type < b->api.type ); }
 };
 
@@ -106,6 +106,17 @@ enum ReloadType
     RELOAD_TYPE_MAX
 };
 
+typedef vector<PHObject> PHObjectList;
+typedef vector<PHObjectList*> PHTSObjectLists;
+struct ThreadSpecificHandlers
+{
+    explicit ThreadSpecificHandlers(unsigned max)
+    { olists.resize(max); }
+    ~ThreadSpecificHandlers() = default;
+    PHTSObjectLists olists;
+    unsigned ref_count = 1;
+};
+
 struct PHInstance
 {
     PHClass& pp_class;
@@ -138,8 +149,8 @@ struct PHInstance
     ReloadType get_reload_type()
     { return reload_type; }
 
-    void tinit();
-    void tterm();
+    void tinit(PHObjectList* handlers);
+    void tterm(PHObjectList* handlers);
 };
 
 PHInstance::PHInstance(PHClass& p, SnortConfig* sc, Module* mod) : pp_class(p)
@@ -163,9 +174,18 @@ PHInstance::~PHInstance()
         handler->rem_global_ref();
 }
 
-typedef vector<PHGlobal*> PHGlobalList;
+typedef vector<PHObject*> PHGlobalList;
 typedef vector<PHClass*> PHClassList;
 typedef vector<PHInstance*> PHInstanceList;
+struct PHRemovedInstance
+{
+    PHRemovedInstance(PHInstance* i, PHTSObjectLists& handlers)
+        : instance(i), handlers(handlers)
+    { }
+    PHInstance* instance;
+    PHTSObjectLists& handlers;
+};
+typedef vector<PHRemovedInstance> PHRemovedInstanceList;
 typedef list<Inspector*> PHList;
 
 static PHGlobalList s_handlers;
@@ -173,7 +193,13 @@ static PHList s_trash;
 static PHList s_trash2;
 static bool s_sorted = false;
 
-static THREAD_LOCAL vector<PHGlobal>* s_tl_handlers = nullptr;
+static PHTSObjectLists s_tl_handlers;
+
+void InspectorManager::global_init()
+{
+    if (s_tl_handlers.size() != ThreadConfig::get_instance_max())
+        s_tl_handlers.resize(ThreadConfig::get_instance_max(), nullptr);
+}
 
 struct FrameworkConfig
 {
@@ -182,11 +208,11 @@ struct FrameworkConfig
 
 struct PHVector
 {
-    PHInstance** vec;
-    unsigned num;
+    PHInstance** vec = nullptr;
+    unsigned num = 0;
+    unsigned total_num = 0;
 
-    PHVector()
-    { vec = nullptr; num = 0; }
+    PHVector() = default;
 
     ~PHVector()
     { if ( vec ) delete[] vec; }
@@ -195,7 +221,10 @@ struct PHVector
     { vec = new PHInstance*[max]; }
 
     void add(PHInstance* p)
-    { vec[num++] = p; }
+    {
+        vec[num++] = p;
+        total_num = num;
+    }
 
     void add_control(PHInstance*);
 };
@@ -210,7 +239,6 @@ void PHVector::add_control(PHInstance* p)
 
     if ( strcmp(name, app_id) or !num )
         add(p);
-
     else
     {
         add(vec[0]);
@@ -220,75 +248,176 @@ void PHVector::add_control(PHInstance* p)
 
 struct InspectorList
 {
-    virtual ~InspectorList() = default;
+    virtual ~InspectorList();
 
     PHInstanceList ilist;   // List of inspector module instances
-    PHInstanceList removed_ilist;   // List of removed inspector module instances
+    PHRemovedInstanceList removed_ilist;    // List of removed inspector module instances
 
-    virtual void handle_new_reenabled(SnortConfig*, bool, bool) = 0;
+    virtual void handle_new_reenabled(SnortConfig*, bool, bool)
+    { }
     virtual void vectorize(SnortConfig*) = 0;
 
-    void tinit();
-    void tterm();
+    void tinit(PHObjectList* handlers);
+    void tterm(PHObjectList* handlers);
     void tterm_removed();
 
-    void populate_removed(SnortConfig*, InspectorList*);
+    void populate_removed(SnortConfig*, InspectorList* new_il, PHTSObjectLists& handlers);
+    void populate_removed(SnortConfig*, InspectorList* new_il, InspectorList* def_il,
+        PHTSObjectLists& handlers);
+    void populate_all_removed(SnortConfig* sc, InspectorList* def_il,
+        PHTSObjectLists& handlers);
     void clear_removed();
+    void reconcile_inspectors(SnortConfig*, InspectorList* old_list, bool cloned);
+    void allocate_thread_storage();
 };
 
-void InspectorList::tinit()
+InspectorList::~InspectorList()
+{  clear_removed(); }
+
+void InspectorList::tinit(PHObjectList* handlers)
 {
     for ( auto* p : ilist )
-        p->tinit();
+        p->tinit(handlers);
 }
 
-void InspectorList::tterm()
+void InspectorList::tterm(PHObjectList* handlers)
 {
     for ( auto* p : ilist )
-        p->tterm();
+        p->tterm(handlers);
 }
 
 void InspectorList::tterm_removed()
 {
-    for ( auto* p : removed_ilist )
-        p->tterm();
+    for ( auto& ri : removed_ilist )
+        ri.instance->tterm(ri.handlers[Inspector::slot]);
 }
 
 static PHInstance* get_instance(InspectorList* il, const char* keyword);
 
-void InspectorList::populate_removed(SnortConfig* sc, InspectorList* old)
+void InspectorList::populate_removed(SnortConfig* sc, InspectorList* new_il,
+    PHTSObjectLists& handlers)
 {
-    for (auto it = old->ilist.begin(); it != old->ilist.end(); ++it)
+    assert(new_il);
+    for (auto* p : ilist)
     {
-        PHInstance* instance = get_instance(this, (*it)->name.c_str());
+        PHInstance* instance = get_instance(new_il, p->name.c_str());
         if (!instance)
         {
-            removed_ilist.emplace_back(*it);
-            (*it)->handler->add_global_ref();
-            (*it)->handler->tear_down(sc);
+            new_il->removed_ilist.emplace_back(p, handlers);
+            p->handler->add_global_ref();
+            p->handler->tear_down(sc);
         }
     }
 }
 
+void InspectorList::populate_removed(SnortConfig* sc, InspectorList* new_il,
+    InspectorList* def_il, PHTSObjectLists& handlers)
+{
+    assert(def_il);
+    for (auto* p : ilist)
+    {
+        PHInstance* instance = new_il ? get_instance(new_il, p->name.c_str()) : nullptr;
+        if (!instance)
+        {
+            def_il->removed_ilist.emplace_back(p, handlers);
+            p->handler->add_global_ref();
+            p->handler->tear_down(sc);
+        }
+    }
+}
+
+void InspectorList::populate_all_removed(SnortConfig* sc, InspectorList* def_il,
+    PHTSObjectLists& handlers)
+{
+    assert(def_il);
+    for (auto* p : ilist)
+    {
+        def_il->removed_ilist.emplace_back(p, handlers);
+        p->handler->add_global_ref();
+        p->handler->tear_down(sc);
+    }
+}
+
 void InspectorList::clear_removed()
 {
-    for ( auto* p : removed_ilist )
-        p->handler->rem_global_ref();
+    for ( auto& ri : removed_ilist )
+        ri.instance->handler->rem_global_ref();
     removed_ilist.clear();
 }
 
+void InspectorList::reconcile_inspectors(SnortConfig* sc, InspectorList* old_list, bool cloned)
+{
+    if (old_list)
+    {
+        for (auto* p : ilist)
+        {
+            for (auto* old_p : old_list->ilist)
+            {
+                if (old_p->name == p->name)
+                {
+                    ReloadType reload_type = p->get_reload_type();
+                    if (!cloned || RELOAD_TYPE_NEW == reload_type
+                        || RELOAD_TYPE_REENABLED == reload_type)
+                    {
+                        p->handler->copy_thread_storage(old_p->handler);
+                        p->handler->install_reload_handler(sc);
+                    }
+                    break;
+                }
+            }
+        }
+    }
+}
+
+void InspectorList::allocate_thread_storage()
+{
+    for (auto* p : ilist)
+        p->handler->allocate_thread_storage();
+}
+
+static PHInstance* get_instance_from_vector(const char* key, PHInstance** vec, unsigned num)
+{
+    for (unsigned i = 0; i < num; ++i)
+    {
+        PHInstance* ph = vec[i];
+        if (ph->name == key)
+            return ph;
+    }
+    return nullptr;
+}
+
 struct TrafficPolicy : public InspectorList
 {
+    TrafficPolicy() = default;
+    ~TrafficPolicy() override;
     PHVector passive;
     PHVector packet;
     PHVector first;
     PHVector control;
 
-    void handle_new_reenabled(SnortConfig*, bool, bool) override
-    { }
+    ThreadSpecificHandlers* ts_handlers = nullptr;
+
     void vectorize(SnortConfig*) override;
+    PHInstance* get_instance_by_type(const char* key, InspectorType);
+
+    PHObjectList* get_specific_handlers();
 };
 
+TrafficPolicy::~TrafficPolicy()
+{
+    if (ts_handlers)
+    {
+        assert(ts_handlers->ref_count);
+        --ts_handlers->ref_count;
+        if (!ts_handlers->ref_count)
+        {
+            for (auto* h : ts_handlers->olists)
+                delete h;
+            delete ts_handlers;
+        }
+    }
+}
+
 void TrafficPolicy::vectorize(SnortConfig*)
 {
     passive.alloc(ilist.size());
@@ -313,7 +442,7 @@ void TrafficPolicy::vectorize(SnortConfig*)
             break;
 
         case IT_CONTROL:
-            control.add_control(p);
+            control.add(p);
             break;
 
         default:
@@ -325,6 +454,41 @@ void TrafficPolicy::vectorize(SnortConfig*)
     }
 }
 
+PHObjectList* TrafficPolicy::get_specific_handlers()
+{
+    assert(ts_handlers);
+    PHObjectList* handlers = ts_handlers->olists[Inspector::slot];
+    if (!handlers)
+    {
+        handlers = new PHObjectList;
+        ts_handlers->olists[Inspector::slot] = handlers;
+    }
+    return handlers;
+}
+
+PHInstance* TrafficPolicy::get_instance_by_type(const char* key, InspectorType type)
+{
+    switch (type)
+    {
+    case IT_PASSIVE:
+        return get_instance_from_vector(key, passive.vec, passive.total_num);
+
+    case IT_PACKET:
+        return get_instance_from_vector(key, packet.vec, packet.total_num);
+
+    case IT_FIRST:
+        return get_instance_from_vector(key, first.vec, first.total_num);
+
+    case IT_CONTROL:
+        return get_instance_from_vector(key, control.vec, control.total_num);
+
+    default:
+        assert(false);
+        break;
+    }
+    return nullptr;
+}
+
 class SingleInstanceInspectorPolicy
 {
 public:
@@ -332,13 +496,16 @@ public:
     ~SingleInstanceInspectorPolicy();
 
     bool get_new(SnortConfig*, Module*, PHClass&, PHInstance*&);
-    void populate_removed(SnortConfig*, SingleInstanceInspectorPolicy*);
+    void populate_removed(SnortConfig*, SingleInstanceInspectorPolicy* new_instance);
     void clear_removed();
     void configure(SnortConfig*);
-    void tinit();
-    void tterm();
+    void reconcile_inspector(SnortConfig*, SingleInstanceInspectorPolicy* old_instance,
+        bool cloned);
+    void tinit(PHObjectList* handlers);
+    void tterm(PHObjectList* handlers);
     void tterm_removed();
     void print_config(SnortConfig*, const char* title);
+    void allocate_thread_storage();
 
     PHInstance* instance = nullptr;
     PHInstance* removed_instance = nullptr;
@@ -355,7 +522,7 @@ SingleInstanceInspectorPolicy::~SingleInstanceInspectorPolicy()
 
         delete instance;
     }
-    assert(nullptr == removed_instance);
+    clear_removed();
 }
 
 bool SingleInstanceInspectorPolicy::get_new(SnortConfig*sc, Module* mod, PHClass& pc,
@@ -375,13 +542,13 @@ bool SingleInstanceInspectorPolicy::get_new(SnortConfig*sc, Module* mod, PHClass
 }
 
 void SingleInstanceInspectorPolicy::populate_removed(SnortConfig* sc,
-    SingleInstanceInspectorPolicy* old)
+    SingleInstanceInspectorPolicy* new_instance)
 {
-    if (old->instance && !instance)
+    if (instance && !new_instance->instance)
     {
-        removed_instance = old->instance;
-        removed_instance->handler->add_global_ref();
-        removed_instance->handler->tear_down(sc);
+        new_instance->removed_instance = instance;
+        instance->handler->add_global_ref();
+        instance->handler->tear_down(sc);
     }
 }
 
@@ -400,22 +567,37 @@ void SingleInstanceInspectorPolicy::configure(SnortConfig* sc)
         instance->handler->configure(sc);
 }
 
-void SingleInstanceInspectorPolicy::tinit()
+void SingleInstanceInspectorPolicy::reconcile_inspector(SnortConfig* sc,
+    SingleInstanceInspectorPolicy* old_instance, bool cloned)
+{
+    if (instance && old_instance && old_instance->instance)
+    {
+        ReloadType reload_type = instance->get_reload_type();
+        if (!cloned || RELOAD_TYPE_NEW == reload_type
+            || RELOAD_TYPE_REENABLED == reload_type)
+        {
+            instance->handler->copy_thread_storage(old_instance->instance->handler);
+            instance->handler->install_reload_handler(sc);
+        }
+    }
+}
+
+void SingleInstanceInspectorPolicy::tinit(PHObjectList* handlers)
 {
     if (instance)
-        instance->tinit();
+        instance->tinit(handlers);
 }
 
-void SingleInstanceInspectorPolicy::tterm()
+void SingleInstanceInspectorPolicy::tterm(PHObjectList* handlers)
 {
     if (instance)
-        instance->tterm();
+        instance->tterm(handlers);
 }
 
 void SingleInstanceInspectorPolicy::tterm_removed()
 {
     if (removed_instance)
-        removed_instance->tterm();
+        removed_instance->tterm(s_tl_handlers[Inspector::slot]);
 }
 
 void SingleInstanceInspectorPolicy::print_config(SnortConfig* sc, const char* title)
@@ -429,20 +611,27 @@ void SingleInstanceInspectorPolicy::print_config(SnortConfig* sc, const char* ti
     }
 }
 
+void SingleInstanceInspectorPolicy::allocate_thread_storage()
+{
+    if (instance)
+        instance->handler->allocate_thread_storage();
+}
+
 struct GlobalInspectorPolicy : public InspectorList
 {
     PHVector passive;
     PHVector probe;
+    PHVector control;
 
-    void handle_new_reenabled(SnortConfig*, bool, bool) override
-    { }
     void vectorize(SnortConfig*) override;
+    PHInstance* get_instance_by_type(const char* key, InspectorType);
 };
 
 void GlobalInspectorPolicy::vectorize(SnortConfig*)
 {
     passive.alloc(ilist.size());
     probe.alloc(ilist.size());
+    control.alloc(ilist.size());
     for ( auto* p : ilist )
     {
         switch ( p->pp_class.api.type )
@@ -455,6 +644,10 @@ void GlobalInspectorPolicy::vectorize(SnortConfig*)
             probe.add(p);
             break;
 
+        case IT_CONTROL:
+            control.add_control(p);
+            break;
+
         default:
             ParseError(
                 "Global inspector policy (global usage) does not handle inspector %s with type %s\n",
@@ -464,6 +657,26 @@ void GlobalInspectorPolicy::vectorize(SnortConfig*)
     }
 }
 
+PHInstance* GlobalInspectorPolicy::get_instance_by_type(const char* key, InspectorType type)
+{
+    switch (type)
+    {
+    case IT_PASSIVE:
+        return get_instance_from_vector(key, passive.vec, passive.total_num);
+
+    case IT_PROBE:
+        return get_instance_from_vector(key, probe.vec, probe.total_num);
+
+    case IT_CONTROL:
+        return get_instance_from_vector(key, control.vec, control.total_num);
+
+    default:
+        assert(false);
+        break;
+    }
+    return nullptr;
+}
+
 struct FrameworkPolicy : public InspectorList
 {
     PHVector passive;
@@ -482,8 +695,8 @@ struct FrameworkPolicy : public InspectorList
     void handle_new_reenabled(SnortConfig*, bool, bool) override;
     void vectorize(SnortConfig*) override;
     void add_inspector_to_cache(PHInstance*, SnortConfig*);
-    void remove_inspector_from_cache(Inspector*);
     bool delete_inspector(SnortConfig*, const char* iname);
+    PHInstance* get_instance_by_type(const char* key, InspectorType);
 };
 
 void FrameworkPolicy::add_inspector_to_cache(PHInstance* p, SnortConfig* sc)
@@ -497,49 +710,27 @@ void FrameworkPolicy::add_inspector_to_cache(PHInstance* p, SnortConfig* sc)
     }
 }
 
-void FrameworkPolicy::remove_inspector_from_cache(Inspector* ins)
-{
-    if (!ins)
-        return;
-
-    for(auto i = inspector_cache_by_id.begin(); i != inspector_cache_by_id.end(); i++)
-    {
-        if (ins == i->second)
-        {
-            inspector_cache_by_id.erase(i);
-            break;
-        }
-    }
-    for(auto i = inspector_cache_by_service.begin(); i != inspector_cache_by_service.end(); i++)
-    {
-        if (ins == i->second)
-        {
-            inspector_cache_by_service.erase(i);
-            break;
-        }
-    }
-}
-
 static bool get_instance(InspectorList*, const char*,
     std::vector<PHInstance*>::iterator&);
 
 void FrameworkPolicy::handle_new_reenabled(SnortConfig* sc, bool new_ins, bool reenabled_ins)
 {
-    std::vector<PHInstance*>::iterator old_binder;
-    if ( get_instance(this, bind_id, old_binder) )
+    if ( new_ins or reenabled_ins )
     {
-        if ( new_ins and default_binder )
+        std::vector<PHInstance*>::iterator old_binder;
+        if ( get_instance(this, bind_id, old_binder) )
         {
-            if ( !((*old_binder)->is_reloaded()) )
+            if ( new_ins and default_binder )
             {
-                (*old_binder)->set_reloaded(RELOAD_TYPE_REENABLED);
-                ilist.erase(old_binder);
+                if ( !((*old_binder)->is_reloaded()) )
+                {
+                    (*old_binder)->set_reloaded(RELOAD_TYPE_REENABLED);
+                    ilist.erase(old_binder);
+                }
+                default_binder = false;
             }
-            default_binder = false;
-        }
-        else if ( reenabled_ins and !((*old_binder)->is_reloaded()) )
-        {
-            (*old_binder)->handler->configure(sc);
+            else if ( reenabled_ins and !((*old_binder)->is_reloaded()) )
+                (*old_binder)->handler->configure(sc);
         }
     }
 }
@@ -606,7 +797,6 @@ bool FrameworkPolicy::delete_inspector(SnortConfig* sc, const char* iname)
     if ( get_instance(this, iname, old_it) )
     {
         (*old_it)->set_reloaded(RELOAD_TYPE_DELETED);
-        remove_inspector_from_cache((*old_it)->handler);
         ilist.erase(old_it);
         std::vector<PHInstance*>::iterator bind_it;
         if ( get_instance(this, bind_id, bind_it) )
@@ -616,6 +806,35 @@ bool FrameworkPolicy::delete_inspector(SnortConfig* sc, const char* iname)
     return false;
 }
 
+PHInstance* FrameworkPolicy::get_instance_by_type(const char* key, InspectorType type)
+{
+    switch (type)
+    {
+    case IT_PASSIVE:
+        return get_instance_from_vector(key, passive.vec, passive.total_num);
+
+    case IT_PACKET:
+        return get_instance_from_vector(key, packet.vec, packet.total_num);
+
+    case IT_NETWORK:
+        return get_instance_from_vector(key, network.vec, network.total_num);
+
+    case IT_STREAM:
+        return get_instance(this, key);
+
+    case IT_SERVICE:
+        return get_instance_from_vector(key, service.vec, service.total_num);
+
+    case IT_WIZARD:
+        return get_instance(this, key);
+
+    default:
+        assert(false);
+        break;
+    }
+    return nullptr;
+}
+
 //-------------------------------------------------------------------------
 // global stuff
 //-------------------------------------------------------------------------
@@ -641,7 +860,7 @@ const char* InspectorManager::get_inspector_type(const char* name)
 
 void InspectorManager::add_plugin(const InspectApi* api)
 {
-    PHGlobal* g = new PHGlobal(*api);
+    PHObject* g = new PHObject(*api);
     s_handlers.emplace_back(g);
 }
 
@@ -834,7 +1053,6 @@ void InspectorManager::new_policy(InspectionPolicy* pi, InspectionPolicy* other_
 void InspectorManager::new_policy(NetworkPolicy* pi, NetworkPolicy* other_pi)
 {
     pi->traffic_policy = new TrafficPolicy;
-
     if ( other_pi )
         pi->traffic_policy->ilist = other_pi->traffic_policy->ilist;
 }
@@ -843,7 +1061,7 @@ void InspectorManager::delete_policy(InspectionPolicy* pi, bool cloned)
 {
     for ( auto* p : pi->framework_policy->ilist )
     {
-        if ( cloned and !(p->is_reloaded()) )
+        if ( cloned and !p->is_reloaded() )
             continue;
 
         if ( p->handler->get_api()->type == IT_PASSIVE )
@@ -880,21 +1098,25 @@ void InspectorManager::update_policy(SnortConfig* sc)
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
     for ( auto* p : pp->ilist )
         p->set_reloaded(RELOAD_TYPE_NONE);
-    InspectionPolicy* ip = sc->policy_map->get_inspection_policy();
-    for ( auto* p : ip->framework_policy->ilist )
-        p->set_reloaded(RELOAD_TYPE_NONE);
-    NetworkPolicy* np = sc->policy_map->get_network_policy();
-    for ( auto* p : np->traffic_policy->ilist )
-        p->set_reloaded(RELOAD_TYPE_NONE);
+    for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx)
+    {
+        NetworkPolicy* np = sc->policy_map->get_network_policy(idx);
+        for ( auto* p : np->traffic_policy->ilist )
+            p->set_reloaded(RELOAD_TYPE_NONE);
+        InspectionPolicy* ip = np->get_inspection_policy();
+        for ( auto* p : ip->framework_policy->ilist )
+            p->set_reloaded(RELOAD_TYPE_NONE);
+    }
 }
 
 Binder* InspectorManager::get_binder()
 {
     InspectionPolicy* pi = get_inspection_policy();
 
-    if ( !pi || !pi->framework_policy )
+    if ( !pi )
         return nullptr;
 
+    assert(pi->framework_policy);
     return (Binder*)pi->framework_policy->binder;
 }
 
@@ -906,33 +1128,75 @@ void InspectorManager::clear_removed_inspectors(SnortConfig* sc)
     ft->clear_removed();
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
     pp->clear_removed();
-    TrafficPolicy* tp = sc->policy_map->get_network_policy()->traffic_policy;
-    tp->clear_removed();
-    FrameworkPolicy* fp = sc->policy_map->get_inspection_policy()->framework_policy;
-    fp->clear_removed();
+    for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx)
+    {
+        NetworkPolicy* np = sc->policy_map->get_network_policy(idx);
+        np->traffic_policy->clear_removed();
+        FrameworkPolicy* fp = np->get_inspection_policy()->framework_policy;
+        fp->clear_removed();
+    }
 }
 
-void InspectorManager::tear_down_removed_inspectors(const SnortConfig* old, SnortConfig* sc)
+void InspectorManager::reconcile_inspectors(const SnortConfig* old, SnortConfig* sc, bool cloned)
 {
     SingleInstanceInspectorPolicy* old_fid = old->policy_map->get_file_id();
     SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
-    fid->populate_removed(sc, old_fid);
+    old_fid->populate_removed(sc, fid);
+    fid->reconcile_inspector(sc, old_fid, cloned);
 
     SingleInstanceInspectorPolicy* old_ft = old->policy_map->get_flow_tracking();
     SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
-    ft->populate_removed(sc, old_ft);
+    old_ft->populate_removed(sc, ft);
+    ft->reconcile_inspector(sc, old_ft, cloned);
 
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
     GlobalInspectorPolicy* old_pp = old->policy_map->get_global_inspector_policy();
-    pp->populate_removed(sc, old_pp);
+    old_pp->populate_removed(sc, pp, s_tl_handlers);
+    pp->reconcile_inspectors(sc, old_pp, cloned);
 
-    TrafficPolicy* tp = get_default_network_policy(sc)->traffic_policy;
-    TrafficPolicy* old_tp = get_default_network_policy(old)->traffic_policy;
-    tp->populate_removed(sc, old_tp);
+    // Put all removed instances in the default traffic policy
+    TrafficPolicy* default_tp = sc->policy_map->get_network_policy(0)->traffic_policy;
+    for (unsigned idx = 0; idx < old->policy_map->network_policy_count(); ++idx)
+    {
+        NetworkPolicy* old_np = old->policy_map->get_network_policy(idx);
+        NetworkPolicy* np = sc->policy_map->get_user_network(old_np->user_policy_id);
+        if (np)
+        {
+            TrafficPolicy* tp = np->traffic_policy;
+            TrafficPolicy* old_tp = old_np->traffic_policy;
+            tp->ts_handlers = old_tp->ts_handlers;
+            ++tp->ts_handlers->ref_count;
+
+            PHTSObjectLists& handlers = tp->ts_handlers->olists;
+            old_tp->populate_removed(sc, tp, default_tp, handlers);
+
+            FrameworkPolicy* old_fp = old_np->get_inspection_policy(0)->framework_policy;
+            FrameworkPolicy* fp = np->get_inspection_policy(0)->framework_policy;
+            old_fp->populate_removed(sc, fp, default_tp, handlers);
+        }
+        else
+        {
+            TrafficPolicy* old_tp = old_np->traffic_policy;
+            old_tp->populate_all_removed(sc, default_tp, old_tp->ts_handlers->olists);
 
-    FrameworkPolicy* fp = get_default_inspection_policy(sc)->framework_policy;
-    FrameworkPolicy* old_fp = get_default_inspection_policy(old)->framework_policy;
-    fp->populate_removed(sc, old_fp);
+            FrameworkPolicy* old_fp = old_np->get_inspection_policy(0)->framework_policy;
+            old_fp->populate_all_removed(sc, default_tp, old_tp->ts_handlers->olists);
+        }
+    }
+
+    for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx)
+    {
+        NetworkPolicy* np = sc->policy_map->get_network_policy(idx);
+        NetworkPolicy* old_np = old->policy_map->get_user_network(np->user_policy_id);
+        if (old_np)
+        {
+            np->traffic_policy->reconcile_inspectors(sc, old_np->traffic_policy, cloned);
+
+            FrameworkPolicy* fp = np->get_inspection_policy(0)->framework_policy;
+            FrameworkPolicy* old_fp = old_np->get_inspection_policy(0)->framework_policy;
+            fp->reconcile_inspectors(sc, old_fp, cloned);
+        }
+    }
 }
 
 Inspector* InspectorManager::get_file_inspector(const SnortConfig* sc)
@@ -953,8 +1217,8 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons
         sc = SnortConfig::get_conf();
     if ( dflt_only )
     {
-        pi = get_default_inspection_policy(sc);
         ni = get_default_network_policy(sc);
+        pi = ni->get_inspection_policy(0);
     }
     else
     {
@@ -962,7 +1226,7 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons
         ni = get_network_policy();
     }
 
-    if ( pi && pi->framework_policy )
+    if ( pi )
     {
         PHInstance* p = get_instance(pi->framework_policy, key);
         if ( p )
@@ -992,13 +1256,58 @@ Inspector* InspectorManager::get_inspector(const char* key, bool dflt_only, cons
     return nullptr;
 }
 
+Inspector* InspectorManager::get_inspector(const char* key, Module::Usage usage,
+    InspectorType type, const SnortConfig* sc)
+{
+    if ( !sc )
+        sc = SnortConfig::get_conf();
+
+    if (Module::GLOBAL == usage && IT_FILE == type)
+    {
+        SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
+        assert(fid);
+        return (fid->instance && fid->instance->name == key) ? fid->instance->handler : nullptr;
+    }
+    else if (Module::GLOBAL == usage && IT_STREAM == type)
+    {
+        SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
+        assert(ft);
+        return (ft->instance && ft->instance->name == key) ? ft->instance->handler : nullptr;
+    }
+    else
+    {
+        if (Module::GLOBAL == usage && IT_SERVICE != type)
+        {
+            GlobalInspectorPolicy* il = sc->policy_map->get_global_inspector_policy();
+            assert(il);
+            PHInstance* p = il->get_instance_by_type(key, type);
+            return p ? p->handler : nullptr;
+        }
+        else if (Module::CONTEXT == usage)
+        {
+            TrafficPolicy* il = get_network_policy()->traffic_policy;
+            assert(il);
+            PHInstance* p = il->get_instance_by_type(key, type);
+            return p ? p->handler : nullptr;
+        }
+        else
+        {
+            FrameworkPolicy* il = get_inspection_policy()->framework_policy;
+            assert(il);
+            PHInstance* p = il->get_instance_by_type(key, type);
+            return p ? p->handler : nullptr;
+        }
+    }
+}
+
 Inspector* InspectorManager::get_service_inspector_by_service(const char* key)
 {
     InspectionPolicy* pi = get_inspection_policy();
 
-    if ( !pi || !pi->framework_policy )
+    if ( !pi )
         return nullptr;
 
+    assert(pi->framework_policy);
     auto g = pi->framework_policy->inspector_cache_by_service.find(key);
     return (g != pi->framework_policy->inspector_cache_by_service.end()) ? g->second : nullptr;
 }
@@ -1007,16 +1316,18 @@ Inspector* InspectorManager::get_service_inspector_by_id(const SnortProtocolId p
 {
     InspectionPolicy* pi = get_inspection_policy();
 
-    if ( !pi || !pi->framework_policy )
+    if ( !pi )
         return nullptr;
 
+    assert(pi->framework_policy);
     auto g = pi->framework_policy->inspector_cache_by_id.find(protocol_id);
     return (g != pi->framework_policy->inspector_cache_by_id.end()) ? g->second : nullptr;
 }
 
 bool InspectorManager::delete_inspector(SnortConfig* sc, const char* iname)
 {
-    FrameworkPolicy* fp = sc->policy_map->get_inspection_policy()->framework_policy;
+    FrameworkPolicy* fp =
+        sc->policy_map->get_network_policy(0)->get_inspection_policy()->framework_policy;
     return fp->delete_inspector(sc, iname);
 }
 
@@ -1078,32 +1389,33 @@ static PHClass* get_class(const char* keyword, FrameworkConfig* fc)
     return nullptr;
 }
 
-static PHGlobal& get_thread_local_plugin(const InspectApi& api)
+static PHObject& get_thread_local_plugin(const InspectApi& api, PHObjectList* handlers)
 {
-    assert(s_tl_handlers != nullptr);
+    assert(handlers);
 
-    for ( PHGlobal& phg : *s_tl_handlers )
+    for ( PHObject& phg : *handlers )
     {
         if ( &phg.api == &api )
             return phg;
     }
-    s_tl_handlers->emplace_back(api);
-    return s_tl_handlers->back();
+    handlers->emplace_back(api);
+    return handlers->back();
 }
 
-void PHInstance::tinit()
+void PHInstance::tinit(PHObjectList* handlers)
 {
-    PHGlobal& phg = get_thread_local_plugin(pp_class.api);
+    PHObject& phg = get_thread_local_plugin(pp_class.api, handlers);
     if ( !phg.instance_initialized )
     {
-        handler->tinit();
         phg.instance_initialized = true;
+        handler->tinit();
     }
 }
 
-void PHInstance::tterm()
+void PHInstance::tterm(PHObjectList* handlers)
 {
-    PHGlobal& phg = get_thread_local_plugin(pp_class.api);
+    assert(handlers);
+    PHObject& phg = get_thread_local_plugin(pp_class.api, handlers);
     if ( phg.instance_initialized )
     {
         handler->tterm();
@@ -1113,73 +1425,90 @@ void PHInstance::tterm()
 
 void InspectorManager::thread_init(const SnortConfig* sc)
 {
+    SnortConfig::update_thread_reload_id();
     Inspector::slot = get_instance_id();
 
     // Initial build out of this thread's configured plugin registry
-    s_tl_handlers = new vector<PHGlobal>;
+    PHObjectList* g_handlers = new PHObjectList;
+    s_tl_handlers[Inspector::slot] = g_handlers;
     for ( auto* p : sc->framework_config->clist )
     {
-        PHGlobal& phg = get_thread_local_plugin(p->api);
+        PHObject& phg = get_thread_local_plugin(p->api, g_handlers);
         if (phg.api.tinit)
             phg.api.tinit();
         phg.initialized = true;
     }
 
-    // pin->tinit() only called for default policy
-    set_default_policy(sc);
     SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
-    fid->tinit();
+    fid->tinit(g_handlers);
     SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
-    ft->tinit();
+    ft->tinit(g_handlers);
 
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
-    pp->tinit();
-
-    InspectionPolicy* pi = get_inspection_policy();
-    if ( pi && pi->framework_policy )
-        pi->framework_policy->tinit();
+    pp->tinit(g_handlers);
 
     for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++)
     {
         NetworkPolicy* npi = sc->policy_map->get_network_policy(i);
+        PHObjectList* handlers = npi->traffic_policy->get_specific_handlers();
         set_network_policy(npi);
-        npi->traffic_policy->tinit();
+        npi->traffic_policy->tinit(handlers);
+
+        InspectionPolicy* pi = npi->get_inspection_policy(0);
+        if ( pi )
+        {
+            set_inspection_policy(pi);
+            assert(pi->framework_policy);
+            pi->framework_policy->tinit(handlers);
+        }
     }
 }
 
 void InspectorManager::thread_reinit(const SnortConfig* sc)
 {
-    // Update this thread's configured plugin registry with any newly configured inspectors
-    for ( auto* p : sc->framework_config->clist )
+    SnortConfig::update_thread_reload_id();
+    unsigned instance_id = get_instance_id();
+    if (!sc->policy_map->get_inspector_tinit_complete(instance_id))
     {
-        PHGlobal& phg = get_thread_local_plugin(p->api);
-        if (!phg.initialized)
+        sc->policy_map->set_inspector_tinit_complete(instance_id, true);
+
+        // Update this thread's configured plugin registry with any newly configured inspectors
+        PHObjectList* g_handlers = s_tl_handlers[Inspector::slot];
+        for ( auto* p : sc->framework_config->clist )
         {
-            if (phg.api.tinit)
-                phg.api.tinit();
-            phg.initialized = true;
+            PHObject& phg = get_thread_local_plugin(p->api, g_handlers);
+            if (!phg.initialized)
+            {
+                if (phg.api.tinit)
+                    phg.api.tinit();
+                phg.initialized = true;
+            }
         }
-    }
-
-    set_default_policy(sc);
-    SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
-    fid->tinit();
-    SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
-    ft->tinit();
 
-    GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
-    pp->tinit();
+        SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
+        fid->tinit(g_handlers);
+        SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
+        ft->tinit(g_handlers);
 
-    for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++)
-    {
-        NetworkPolicy* npi = sc->policy_map->get_network_policy(i);
-        set_network_policy(npi);
-        npi->traffic_policy->tinit();
+        GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
+        pp->tinit(g_handlers);
 
-        // pin->tinit() only called for default policy
-        InspectionPolicy* pi = get_default_inspection_policy(sc);
-        if ( pi && pi->framework_policy )
-            pi->framework_policy->tinit();
+        for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++)
+        {
+            NetworkPolicy* npi = sc->policy_map->get_network_policy(i);
+            PHObjectList* handlers = npi->traffic_policy->get_specific_handlers();
+            set_network_policy(npi);
+            npi->traffic_policy->tinit(handlers);
+
+            // pin->tinit() only called for default policy
+            InspectionPolicy* pi = npi->get_inspection_policy(0);
+            if ( pi )
+            {
+                set_inspection_policy(pi);
+                assert(pi->framework_policy);
+                pi->framework_policy->tinit(handlers);
+            }
+        }
     }
 }
 
@@ -1194,65 +1523,70 @@ void InspectorManager::thread_stop_removed(const SnortConfig* sc)
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
     pp->tterm_removed();
 
-    // pin->tinit() only called for default policy
     NetworkPolicy* npi = get_default_network_policy(sc);
     if ( npi && npi->traffic_policy )
     {
         // Call pin->tterm() for anything that has been initialized and removed
         npi->traffic_policy->tterm_removed();
-    }
 
-    // pin->tinit() only called for default policy
-    InspectionPolicy* pi = get_default_inspection_policy(sc);
-
-    if ( pi && pi->framework_policy )
-    {
-        // Call pin->tterm() for anything that has been initialized and removed
-        pi->framework_policy->tterm_removed();
+        // pin->tinit() only called for default policy
+        InspectionPolicy* pi = npi->get_inspection_policy(0);
+        if ( pi )
+        {
+            assert(pi->framework_policy);
+            // Call pin->tterm() for anything that has been initialized and removed
+            pi->framework_policy->tterm_removed();
+        }
     }
 }
 
 void InspectorManager::thread_stop(const SnortConfig* sc)
 {
     // If thread_init() was never called, we have nothing to do.
-    if ( !s_tl_handlers )
+    PHObjectList* g_handlers = s_tl_handlers[Inspector::slot];
+    if ( !g_handlers )
         return;
 
     set_default_policy(sc);
     SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
-    fid->tterm();
+    fid->tterm(g_handlers);
     SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
-    ft->tterm();
+    ft->tterm(g_handlers);
 
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
-    pp->tterm();
-
-    InspectionPolicy* pi = get_inspection_policy();
-    if ( pi && pi->framework_policy )
-        pi->framework_policy->tterm();
+    pp->tterm(g_handlers);
 
     for ( unsigned i = 0; i < sc->policy_map->network_policy_count(); i++)
     {
         NetworkPolicy* npi = sc->policy_map->get_network_policy(i);
+        PHObjectList* handlers = npi->traffic_policy->get_specific_handlers();
         set_network_policy(npi);
-        npi->traffic_policy->tterm();
+        npi->traffic_policy->tterm(handlers);
+
+        InspectionPolicy* pi = npi->get_inspection_policy(0);
+        if ( pi )
+        {
+            assert(pi->framework_policy);
+            pi->framework_policy->tterm(handlers);
+        }
     }
 }
 
 void InspectorManager::thread_term()
 {
     // If thread_init() was never called, we have nothing to do.
-    if ( !s_tl_handlers )
+    PHObjectList* handlers = s_tl_handlers[Inspector::slot];
+    if ( !handlers )
         return;
 
     // Call tterm for every inspector plugin ever configured during the lifetime of this thread
-    for ( PHGlobal& phg : *s_tl_handlers )
+    for ( PHObject& phg : *handlers )
     {
         if ( phg.api.tterm && phg.initialized )
             phg.api.tterm();
     }
-    delete s_tl_handlers;
-    s_tl_handlers = nullptr;
+    delete handlers;
+    s_tl_handlers[Inspector::slot] = nullptr;
 }
 
 //-------------------------------------------------------------------------
@@ -1308,13 +1642,17 @@ void InspectorManager::instantiate(
             }
             else if (Module::CONTEXT == mod->get_usage())
             {
-                TrafficPolicy* il = get_network_policy()->traffic_policy;
+                NetworkPolicy* np = get_network_policy();
+                assert(np);
+                TrafficPolicy* il = np->traffic_policy;
                 assert(il);
                 ppi = get_new(ppc, il, keyword, mod, sc);
             }
             else
             {
-                FrameworkPolicy* il = get_inspection_policy()->framework_policy;
+                InspectionPolicy* ip = get_inspection_policy();
+                assert(ip);
+                FrameworkPolicy* il = ip->framework_policy;
                 assert(il);
                 ppi = get_new(ppc, il, keyword, mod, sc);
             }
@@ -1338,7 +1676,9 @@ Inspector* InspectorManager::instantiate(
     if ( !ppc )
         return nullptr;
 
-    auto fp = get_inspection_policy()->framework_policy;
+    InspectionPolicy* ip = get_inspection_policy();
+    assert(ip);
+    auto fp = ip->framework_policy;
     auto ppi = get_new(ppc, fp, name, mod, sc);
 
     if ( !ppi )
@@ -1381,7 +1721,9 @@ static void instantiate_default_binder(SnortConfig* sc, FrameworkPolicy* fp)
 
     const InspectApi* api = get_plugin(bind_id);
     InspectorManager::instantiate(api, m, sc);
-    fp->binder = get_instance(fp, bind_id)->handler;
+    PHInstance* instance = get_instance(fp, bind_id);
+    assert(instance);
+    fp->binder = instance->handler;
     fp->binder->configure(sc);
     fp->default_binder = true;
 }
@@ -1393,10 +1735,9 @@ static bool configure(SnortConfig* sc, InspectorList* il, bool cloned, bool& new
 
     for ( auto* p : il->ilist )
     {
-        ReloadType reload_type = p->get_reload_type();
-
         if ( cloned )
         {
+            ReloadType reload_type = p->get_reload_type();
             if ( reload_type == RELOAD_TYPE_NEW )
                 new_ins = true;
             else if ( reload_type == RELOAD_TYPE_REENABLED )
@@ -1406,9 +1747,7 @@ static bool configure(SnortConfig* sc, InspectorList* il, bool cloned, bool& new
         }
         ok = p->handler->configure(sc) && ok;
     }
-
-    if ( new_ins or reenabled_ins )
-        il->handle_new_reenabled(sc, new_ins, reenabled_ins);
+    il->handle_new_reenabled(sc, new_ins, reenabled_ins);
 
     sort(il->ilist.begin(), il->ilist.end(), PHInstance::comp);
     il->vectorize(sc);
@@ -1428,18 +1767,6 @@ Inspector* InspectorManager::acquire_file_inspector()
     return pi;
 }
 
-Inspector* InspectorManager::acquire(const char* key, bool dflt_only)
-{
-    Inspector* pi = get_inspector(key, dflt_only);
-
-    if ( !pi )
-        FatalError("unconfigured inspector: '%s'.\n", key);
-    else
-        pi->add_global_ref();
-
-    return pi;
-}
-
 void InspectorManager::release(Inspector* pi)
 {
     assert(pi);
@@ -1450,66 +1777,106 @@ bool InspectorManager::configure(SnortConfig* sc, bool cloned)
 {
     if ( !s_sorted )
     {
-        sort(s_handlers.begin(), s_handlers.end(), PHGlobal::comp);
+        sort(s_handlers.begin(), s_handlers.end(), PHObject::comp);
         s_sorted = true;
     }
     bool ok = true;
 
     SearchTool::set_conf(sc);
 
-    bool new_ins = false;
-    bool reenabled_ins = false;
-    for ( unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx )
-    {
-        if ( cloned and idx )
-            break;
-
-        ::set_network_policy(sc, idx);
-        NetworkPolicy* p = sc->policy_map->get_network_policy(idx);
-        ok = ::configure(sc, p->traffic_policy, cloned, new_ins, reenabled_ins) && ok;
-    }
-    ::set_network_policy(sc);
-
     SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
     fid->configure(sc);
 
     SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
     ft->configure(sc);
 
+    bool new_ins = false;
+    bool reenabled_ins = false;
+
     GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
     ok = ::configure(sc, pp, cloned, new_ins, reenabled_ins) && ok;
 
-    ::set_inspection_policy(sc);
-    for ( unsigned idx = 0; idx < sc->policy_map->inspection_policy_count(); ++idx )
+    for ( unsigned nidx = 0; nidx < sc->policy_map->network_policy_count(); ++nidx )
     {
-        if ( cloned and idx )
-            break;
+        NetworkPolicy* np = sc->policy_map->get_network_policy(nidx);
+        assert(np);
+        set_network_policy(np);
+        ok = ::configure(sc, np->traffic_policy, cloned, new_ins, reenabled_ins) && ok;
 
-        ::set_inspection_policy(sc, idx);
-        InspectionPolicy* p = sc->policy_map->get_inspection_policy(idx);
-        p->configure();
-        ok = ::configure(sc, p->framework_policy, cloned, new_ins, reenabled_ins) && ok;
+        for ( unsigned idx = 0; idx < np->inspection_policy_count(); ++idx )
+        {
+            if ( cloned and idx )
+                break;
+
+            InspectionPolicy* p = np->get_inspection_policy(idx);
+            assert(p);
+            set_inspection_policy(p);
+            p->configure();
+            ok = ::configure(sc, p->framework_policy, cloned, new_ins, reenabled_ins) && ok;
+        }
     }
 
-    ::set_inspection_policy(sc);
+    NetworkPolicy* np = sc->policy_map->get_network_policy();
+    assert(np);
+    set_network_policy(np);
+    set_inspection_policy(np->get_inspection_policy());
     SearchTool::set_conf(nullptr);
 
     return ok;
 }
 
+void InspectorManager::prepare_inspectors(SnortConfig* sc)
+{
+    SingleInstanceInspectorPolicy* fid = sc->policy_map->get_file_id();
+    fid->allocate_thread_storage();
+
+    SingleInstanceInspectorPolicy* ft = sc->policy_map->get_flow_tracking();
+    ft->allocate_thread_storage();
+
+    GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
+    pp->allocate_thread_storage();
+
+    for (unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx)
+    {
+        NetworkPolicy* np = sc->policy_map->get_network_policy(idx);
+        TrafficPolicy* tp = np->traffic_policy;
+        if (!tp->ts_handlers)
+            tp->ts_handlers = new ThreadSpecificHandlers(ThreadConfig::get_instance_max());
+        tp->allocate_thread_storage();
+    }
+}
+
 // remove any disabled controls while retaining order
 void InspectorManager::prepare_controls(SnortConfig* sc)
 {
+    GlobalInspectorPolicy* gp = sc->policy_map->get_global_inspector_policy();
+    unsigned g_c = 0;
+    std::vector<PHInstance*> g_disabled;
+    for ( unsigned i = 0; i < gp->control.num; ++i )
+    {
+        if ( !gp->control.vec[i]->handler->disable(sc) )
+            gp->control.vec[g_c++] = gp->control.vec[i];
+        else
+            g_disabled.emplace_back(gp->control.vec[i]);
+    }
+    gp->control.num = g_c;
+    for (auto* ph : g_disabled)
+        gp->control.vec[g_c++] = ph;
     for ( unsigned idx = 0; idx < sc->policy_map->network_policy_count(); ++idx )
     {
         TrafficPolicy* tp = sc->policy_map->get_network_policy(idx)->traffic_policy;
         unsigned c = 0;
+        std::vector<PHInstance*> disabled;
         for ( unsigned i = 0; i < tp->control.num; ++i )
         {
             if ( !tp->control.vec[i]->handler->disable(sc) )
                 tp->control.vec[c++] = tp->control.vec[i];
+            else
+                disabled.emplace_back(tp->control.vec[i]);
         }
         tp->control.num = c;
+        for (auto* ph : disabled)
+            tp->control.vec[c++] = ph;
     }
 }
 
@@ -1577,8 +1944,9 @@ void InspectorManager::print_config(SnortConfig* sc)
         }
 
         const auto inspection = policies->inspection;
-        if ( inspection and inspection->framework_policy )
+        if ( inspection )
         {
+            assert(inspection->framework_policy);
             const std::string label = "Inspection Policy : policy id " +
                 std::to_string(inspection->user_policy_id) + " : " +
                 shell->get_file();
@@ -1676,7 +2044,7 @@ void InspectorManager::full_inspection(Packet* p)
 {
     Flow* flow = p->flow;
 
-    if ( flow->service and flow->searching_for_service()
+    if ( flow->has_service() and flow->searching_for_service()
          and (!(p->is_cooked()) or p->is_defrag()) )
         bumble(p);
 
@@ -1765,6 +2133,9 @@ void InspectorManager::internal_execute(Packet* p)
     if ( p->disable_inspect )
         return;
 
+    GlobalInspectorPolicy* pp = sc->policy_map->get_global_inspector_policy();
+    assert(pp);
+
     if ( !p->flow )
     {
         ::execute<T>(p, tp->first.vec, tp->first.num);
@@ -1777,6 +2148,7 @@ void InspectorManager::internal_execute(Packet* p)
         if ( p->disable_inspect )
             return;
 
+        ::execute<T>(p, pp->control.vec, pp->control.num);
         ::execute<T>(p, tp->control.vec, tp->control.num);
     }
     else
@@ -1784,16 +2156,17 @@ void InspectorManager::internal_execute(Packet* p)
         if ( !p->has_paf_payload() and p->flow->flow_state == Flow::FlowState::INSPECT )
             p->flow->session->process(p);
 
-        if ( p->flow->reload_id != sc->reload_id )
+        unsigned reload_id = SnortConfig::get_thread_reload_id();
+        if ( p->flow->reload_id != reload_id )
         {
             ::execute<T>(p, tp->first.vec, tp->first.num);
 
-            p->flow->reload_id = sc->reload_id;
+            p->flow->reload_id = reload_id;
             if ( p->disable_inspect )
                 return;
         }
 
-        if ( !p->flow->service )
+        if ( !p->flow->has_service() )
             ::execute<T>(p, fp->network.vec, fp->network.num);
 
         if ( p->disable_inspect )
@@ -1802,6 +2175,8 @@ void InspectorManager::internal_execute(Packet* p)
         if ( p->flow->full_inspection() )
             full_inspection<T>(p);
 
+        if ( !p->disable_inspect and !p->flow->is_inspection_disabled() )
+            ::execute<T>(p, pp->control.vec, pp->control.num);
         if ( !p->disable_inspect and !p->flow->is_inspection_disabled() )
             ::execute<T>(p, tp->control.vec, tp->control.num);
     }
index 0d6a33cd131e0bfd0bde9a81a77420890b2b83bf..a9b4a4913755ca87cf5dfe70341007d4215ef269 100644 (file)
@@ -26,6 +26,7 @@
 #include <map>
 
 #include "framework/inspector.h"
+#include "framework/module.h"
 
 class Binder;
 class SingleInstanceInspectorPolicy;
@@ -50,6 +51,8 @@ public:
     static void dump_buffers();
     static void release_plugins();
 
+    static void global_init();
+
     static std::vector<const InspectApi*> get_apis();
     static const char* get_inspector_type(const char* name);
 
@@ -78,6 +81,8 @@ public:
     SO_PUBLIC static Inspector* get_file_inspector(const SnortConfig* = nullptr);
     SO_PUBLIC static Inspector* get_inspector(
         const char* key, bool dflt_only = false, const SnortConfig* = nullptr);
+    SO_PUBLIC static Inspector* get_inspector(const char* key, Module::Usage, InspectorType,
+        const SnortConfig* = nullptr);
 
     static Inspector* get_service_inspector_by_service(const char*);
     static Inspector* get_service_inspector_by_id(const SnortProtocolId);
@@ -85,10 +90,10 @@ public:
     SO_PUBLIC static Binder* get_binder();
 
     SO_PUBLIC static Inspector* acquire_file_inspector();
-    SO_PUBLIC static Inspector* acquire(const char* key, bool dflt_only = false);
     SO_PUBLIC static void release(Inspector*);
 
     static bool configure(SnortConfig*, bool cloned = false);
+    static void prepare_inspectors(SnortConfig*);
     static void prepare_controls(SnortConfig*);
     static std::string generate_inspector_label(const PHInstance*);
     static void print_config(SnortConfig*);
@@ -105,7 +110,7 @@ public:
 
     static void clear(Packet*);
     static void empty_trash();
-    static void tear_down_removed_inspectors(const SnortConfig*, SnortConfig*);
+    static void reconcile_inspectors(const SnortConfig*, SnortConfig*, bool cloned = false);
     static void clear_removed_inspectors(SnortConfig*);
 
 #ifdef PIGLET
index 88850726613d1ac59459e5757fd0f9f3d94ff5b8..357c95f246f351e80ed88265435ffd5b407f698d 100644 (file)
@@ -725,7 +725,8 @@ SO_PUBLIC bool open_table(const char* s, int idx)
 
     // FIXIT-M only basic modules, inspectors and ips actions can be reloaded at present
     if ( ( Snort::is_reloading() ) and h->api
-            and h->api->type != PT_INSPECTOR and h->api->type != PT_IPS_ACTION )
+            and h->api->type != PT_INSPECTOR and h->api->type != PT_IPS_ACTION
+            and h->api->type != PT_POLICY_SELECTOR )
     {
         return false;
     }
diff --git a/src/managers/test/CMakeLists.txt b/src/managers/test/CMakeLists.txt
new file mode 100644 (file)
index 0000000..cecd450
--- /dev/null
@@ -0,0 +1,5 @@
+add_cpputest(get_inspector_test
+    SOURCES
+        get_inspector_stubs.h
+        ../inspector_manager.cc
+)
diff --git a/src/managers/test/get_inspector_stubs.h b/src/managers/test/get_inspector_stubs.h
new file mode 100644 (file)
index 0000000..1c7f44a
--- /dev/null
@@ -0,0 +1,87 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2020-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// stubs.h author Ron Dempster <rdempste@cisco.com>
+
+#include "detection/detection_engine.h"
+#include "main/policy.h"
+#include "main/snort.h"
+#include "main/snort_config.h"
+#include "main/thread_config.h"
+#include "managers/inspector_manager.h"
+#include "managers/module_manager.h"
+#include "network_inspectors/binder/bind_module.h"
+#include "search_engines/search_tool.h"
+#include "trace/trace.h"
+#include "trace/trace_api.h"
+
+THREAD_LOCAL const snort::Trace* snort_trace = nullptr;
+
+std::shared_ptr<PolicyTuple> PolicyMap::get_policies(Shell*) { return nullptr; }
+NetworkPolicy* PolicyMap::get_user_network(unsigned) { return nullptr; }
+void InspectionPolicy::configure() { }
+void BinderModule::add(const char*, const char*) { }
+void BinderModule::add(unsigned, const char*) { }
+
+void set_default_policy(const snort::SnortConfig*) { }
+
+namespace snort
+{
+unsigned THREAD_LOCAL Inspector::slot = 0;
+const SnortConfig* SearchTool::conf = nullptr;
+[[noreturn]] void FatalError(const char*,...) { exit(-1); }
+void LogMessage(const char*, ...) { }
+void LogLabel(const char*, FILE*) { }
+void ParseError(const char*, ...) { }
+void WarningMessage(const char*, ...) { }
+void DataBus::publish(const char*, Packet*, Flow*) { }
+void DetectionEngine::disable_content(Packet*) { }
+unsigned SnortConfig::get_thread_reload_id() { return 1; }
+void SnortConfig::update_thread_reload_id() { }
+Inspector::Inspector() { }
+Inspector::~Inspector() { }
+bool Inspector::likes(Packet*) { return false; }
+bool Inspector::get_buf(const char*, Packet*, InspectionBuffer&) { return false; }
+StreamSplitter* Inspector::get_splitter(bool) { return nullptr; }
+void Inspector::add_global_ref() { }
+void Inspector::rem_ref() { }
+void Inspector::rem_global_ref() { }
+void Inspector::allocate_thread_storage() { }
+void Inspector::copy_thread_storage(snort::Inspector*) { }
+const char* InspectApi::get_type(InspectorType) { return ""; }
+unsigned ThreadConfig::get_instance_max() { return 1; }
+bool Snort::is_reloading() { return false; }
+SnortProtocolId ProtocolReference::find(const char*) const { return UNKNOWN_PROTOCOL_ID; }
+SnortProtocolId ProtocolReference::add(const char*) { return UNKNOWN_PROTOCOL_ID; }
+uint8_t TraceApi::get_constraints_generation() { return 0; }
+void TraceApi::filter(const Packet&) { }
+PegCount Module::get_global_count(const char*) const { return 0; }
+void Module::sum_stats(bool) { }
+void Module::show_interval_stats(IndexVec&, FILE*) { }
+void Module::show_stats() { }
+void Module::reset_stats() { }
+DataBus::DataBus() { }
+DataBus::~DataBus() { }
+Module* ModuleManager::get_module(const char*) { return nullptr; }
+
+NetworkPolicy* get_default_network_policy(const SnortConfig*) { return nullptr; }
+void set_network_policy(NetworkPolicy*) { }
+void set_inspection_policy(InspectionPolicy*) { }
+void set_ips_policy(IpsPolicy*) { }
+unsigned get_instance_id() { return 0; }
+void trace_vprintf(const char*, TraceLevel, const char*, const Packet*, const char*, va_list) { }
+}
diff --git a/src/managers/test/get_inspector_test.cc b/src/managers/test/get_inspector_test.cc
new file mode 100644 (file)
index 0000000..11470af
--- /dev/null
@@ -0,0 +1,327 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2020-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// get_inspector_test.cc author Ron Dempster <rdempste@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include <unordered_map>
+
+#include "get_inspector_stubs.h"
+
+#include <CppUTest/CommandLineTestRunner.h>
+#include <CppUTest/TestHarness.h>
+#include <CppUTestExt/MockSupport.h>
+
+using namespace snort;
+
+bool Inspector::is_inactive() { return true; }
+
+NetworkPolicy* snort::get_network_policy()
+{ return (NetworkPolicy*)mock().getData("network_policy").getObjectPointer(); }
+InspectionPolicy* snort::get_inspection_policy()
+{ return (InspectionPolicy*)mock().getData("inspection_policy").getObjectPointer(); }
+
+InspectionPolicy::InspectionPolicy(PolicyId)
+{ InspectorManager::new_policy(this, nullptr); }
+InspectionPolicy::~InspectionPolicy()
+{ InspectorManager::delete_policy(this, false); }
+NetworkPolicy::NetworkPolicy(PolicyId, PolicyId)
+{ InspectorManager::new_policy(this, nullptr); }
+NetworkPolicy::~NetworkPolicy()
+{
+    for ( auto p : inspection_policy )
+        delete p;
+
+    InspectorManager::delete_policy(this, false);
+    inspection_policy.clear();
+}
+PolicyMap::PolicyMap(PolicyMap*, const char*)
+{
+    file_id = InspectorManager::create_single_instance_inspector_policy();
+    flow_tracking = InspectorManager::create_single_instance_inspector_policy();
+    global_inspector_policy = InspectorManager::create_global_inspector_policy();
+    NetworkPolicy* np = new NetworkPolicy(network_policy.size(), 0);
+    network_policy.push_back(np);
+    InspectionPolicy* ip = new InspectionPolicy();
+    np->inspection_policy.push_back(ip);
+}
+PolicyMap::~PolicyMap()
+{
+    InspectorManager::destroy_single_instance_inspector(file_id);
+    InspectorManager::destroy_single_instance_inspector(flow_tracking);
+    InspectorManager::destroy_global_inspector_policy(global_inspector_policy, false);
+    for ( auto p : network_policy )
+        delete p;
+}
+SnortConfig::SnortConfig(const SnortConfig* const, const char*)
+{
+    policy_map = new PolicyMap();
+    InspectorManager::new_config(this);
+}
+SnortConfig::~SnortConfig()
+{
+    InspectorManager::delete_config(this);
+    delete policy_map;
+}
+const SnortConfig* SnortConfig::get_conf()
+{ return (const SnortConfig*)mock().getData("snort_config").getObjectPointer(); }
+
+Module::Module(const char* name, const char*) : name(name)
+{ }
+
+class TestInspector : public Inspector
+{
+public:
+    TestInspector() = default;
+    ~TestInspector() override = default;
+    void eval(Packet*) override { }
+};
+
+class TestModule : public Module
+{
+public:
+    TestModule(const char* name, Module::Usage usage) : Module(name, ""), usage(usage)
+    { }
+    ~TestModule() override = default;
+    Usage get_usage() const override
+    { return usage; }
+
+protected:
+    Usage usage;
+};
+
+
+static Inspector* test_ctor(Module* mod)
+{
+    std::unordered_map<Module*, Inspector*>* mod_to_ins =
+        (std::unordered_map<Module*, Inspector*>*)mock().getData("mod_to_ins").getObjectPointer();
+    auto it = mod_to_ins->find(mod);
+    return it == mod_to_ins->end() ? nullptr : it->second;
+}
+
+static void test_dtor(Inspector*)
+{ }
+
+#define DECLARE_ENTRY(NAME, USAGE) \
+    static TestModule NAME##_mod(#NAME, USAGE); \
+    static InspectApi NAME##_api; \
+    static TestInspector NAME##_ins
+
+DECLARE_ENTRY(binder, Module::Usage::INSPECT);
+
+DECLARE_ENTRY(file, Module::Usage::GLOBAL);
+DECLARE_ENTRY(stream, Module::Usage::GLOBAL);
+
+DECLARE_ENTRY(global_passive, Module::Usage::GLOBAL);
+DECLARE_ENTRY(global_probe, Module::Usage::GLOBAL);
+DECLARE_ENTRY(global_control, Module::Usage::GLOBAL);
+
+DECLARE_ENTRY(context_passive, Module::Usage::CONTEXT);
+DECLARE_ENTRY(context_packet, Module::Usage::CONTEXT);
+DECLARE_ENTRY(context_first, Module::Usage::CONTEXT);
+DECLARE_ENTRY(context_control, Module::Usage::CONTEXT);
+
+DECLARE_ENTRY(inspect_passive, Module::Usage::INSPECT);
+DECLARE_ENTRY(inspect_packet, Module::Usage::INSPECT);
+DECLARE_ENTRY(inspect_network, Module::Usage::INSPECT);
+DECLARE_ENTRY(inspect_service, Module::Usage::INSPECT);
+DECLARE_ENTRY(inspect_stream, Module::Usage::INSPECT);
+DECLARE_ENTRY(inspect_wizard, Module::Usage::INSPECT);
+
+#define ADD_ENTRY(NAME, TYPE) \
+    do { \
+        NAME##_api = {}; \
+        NAME##_api.base.name = NAME##_mod.get_name(); \
+        NAME##_api.type = TYPE; \
+        NAME##_api.ctor = test_ctor; \
+        NAME##_api.dtor = test_dtor; \
+        NAME##_ins.set_api(&NAME##_api); \
+        InspectorManager::add_plugin(&NAME##_api); \
+    } while (0)
+
+#define INSTANTIATE(NAME) \
+    do { \
+        mod_to_ins[&NAME##_mod] = &NAME##_ins; \
+        InspectorManager::instantiate(&NAME##_api, &NAME##_mod, sc, NAME##_mod.get_name()); \
+    } while (0)
+
+void setup_test_globals()
+{
+    ADD_ENTRY(binder, IT_PASSIVE);
+
+    ADD_ENTRY(file, IT_FILE);
+    ADD_ENTRY(stream, IT_STREAM);
+
+    ADD_ENTRY(global_passive, IT_PASSIVE);
+    ADD_ENTRY(global_probe, IT_PROBE);
+    ADD_ENTRY(global_control, IT_CONTROL);
+
+    ADD_ENTRY(context_passive, IT_PASSIVE);
+    ADD_ENTRY(context_packet, IT_PACKET);
+    ADD_ENTRY(context_first, IT_FIRST);
+    ADD_ENTRY(context_control, IT_CONTROL);
+
+    ADD_ENTRY(inspect_passive, IT_PASSIVE);
+    ADD_ENTRY(inspect_packet, IT_PACKET);
+    ADD_ENTRY(inspect_network, IT_NETWORK);
+    ADD_ENTRY(inspect_service, IT_SERVICE);
+    ADD_ENTRY(inspect_stream, IT_STREAM);
+    ADD_ENTRY(inspect_wizard, IT_WIZARD);
+}
+
+TEST_GROUP(get_inspector_tests)
+{
+    SnortConfig* sc;
+    std::unordered_map<Module*, Inspector*> mod_to_ins;
+
+    void setup() override
+    {
+        sc = new SnortConfig;
+        mock().setDataObject("snort_config", "const SnortConfig", sc);
+        mock().setDataObject("mod_to_ins", "std::unordered_map<Module*, Inspector*>", &mod_to_ins);
+        NetworkPolicy* np = sc->policy_map->get_network_policy();
+        mock().setDataObject("network_policy", "NetworkPolicy", np);
+        InspectionPolicy* ip = np->get_inspection_policy();
+        mock().setDataObject("inspection_policy", "InspectionPolicy", ip);
+
+        INSTANTIATE(binder);
+
+        INSTANTIATE(file);
+        INSTANTIATE(stream);
+
+        INSTANTIATE(global_passive);
+        INSTANTIATE(global_probe);
+        INSTANTIATE(global_control);
+
+        INSTANTIATE(context_passive);
+        INSTANTIATE(context_packet);
+        INSTANTIATE(context_first);
+        INSTANTIATE(context_control);
+
+        INSTANTIATE(inspect_passive);
+        INSTANTIATE(inspect_packet);
+        INSTANTIATE(inspect_network);
+        INSTANTIATE(inspect_service);
+        INSTANTIATE(inspect_stream);
+        INSTANTIATE(inspect_wizard);
+
+        InspectorManager::configure(sc, false);
+    }
+
+    void teardown() override
+    {
+        delete sc;
+        InspectorManager::empty_trash();
+        mod_to_ins.clear();
+        mock().clear();
+    }
+};
+
+#define THE_TEST(NAME, USAGE, TYPE) \
+    do { \
+        Inspector* ins = InspectorManager::get_inspector(NAME##_mod.get_name(), USAGE, TYPE); \
+        CHECK_TEXT(&NAME##_ins == ins, "Did not find the " #NAME " inspector"); \
+        STRCMP_EQUAL_TEXT(ins->get_name(), NAME##_mod.get_name(), "Inspector name is not " #NAME); \
+        ins = InspectorManager::get_inspector("not_" #NAME, USAGE, TYPE); \
+        CHECK_TEXT(nullptr == ins, "Found the not_" #NAME " inspector"); \
+    } while (0)
+
+TEST(get_inspector_tests, file)
+{
+    THE_TEST(file, Module::Usage::GLOBAL, IT_FILE);
+}
+
+TEST(get_inspector_tests, stream)
+{
+    THE_TEST(stream, Module::Usage::GLOBAL, IT_STREAM);
+}
+
+TEST(get_inspector_tests, global_passive)
+{
+    THE_TEST(global_passive, Module::Usage::GLOBAL, IT_PASSIVE);
+}
+
+TEST(get_inspector_tests, global_probe)
+{
+    THE_TEST(global_probe, Module::Usage::GLOBAL, IT_PROBE);
+}
+
+TEST(get_inspector_tests, global_control)
+{
+    THE_TEST(global_control, Module::Usage::GLOBAL, IT_CONTROL);
+}
+
+TEST(get_inspector_tests, context_passive)
+{
+    THE_TEST(context_passive, Module::Usage::CONTEXT, IT_PASSIVE);
+}
+
+TEST(get_inspector_tests, context_packet)
+{
+    THE_TEST(context_packet, Module::Usage::CONTEXT, IT_PACKET);
+}
+
+TEST(get_inspector_tests, context_first)
+{
+    THE_TEST(context_first, Module::Usage::CONTEXT, IT_FIRST);
+}
+
+TEST(get_inspector_tests, context_control)
+{
+    THE_TEST(context_control, Module::Usage::CONTEXT, IT_CONTROL);
+}
+
+TEST(get_inspector_tests, inspect_passive)
+{
+    THE_TEST(inspect_passive, Module::Usage::INSPECT, IT_PASSIVE);
+}
+
+TEST(get_inspector_tests, inspect_packet)
+{
+    THE_TEST(inspect_packet, Module::Usage::INSPECT, IT_PACKET);
+}
+
+TEST(get_inspector_tests, inspect_network)
+{
+    THE_TEST(inspect_network, Module::Usage::INSPECT, IT_NETWORK);
+}
+
+TEST(get_inspector_tests, inspect_service)
+{
+    THE_TEST(inspect_service, Module::Usage::INSPECT, IT_SERVICE);
+}
+
+TEST(get_inspector_tests, inspect_stream)
+{
+    THE_TEST(inspect_stream, Module::Usage::INSPECT, IT_STREAM);
+}
+
+TEST(get_inspector_tests, inspect_wizard)
+{
+    THE_TEST(inspect_wizard, Module::Usage::INSPECT, IT_WIZARD);
+}
+
+int main(int argc, char** argv)
+{
+    setup_test_globals();
+    int r = CommandLineTestRunner::RunAllTests(argc, argv);
+    InspectorManager::release_plugins();
+    return r;
+}
index 10eda764dfd631e60bf3d74733a7590cae438301..7fef03787b41f2b538e2326b195f737e24ae7f3b 100644 (file)
@@ -115,7 +115,7 @@ bool AppIdContext::init_appid(SnortConfig* sc, AppIdInspector& inspector)
     {
         odp_ctxt->get_client_disco_mgr().initialize(inspector);
         odp_ctxt->get_service_disco_mgr().initialize(inspector);
-        odp_thread_local_ctxt->initialize(*this, true);
+        odp_thread_local_ctxt->initialize(sc, *this, true);
         odp_ctxt->initialize(inspector);
 
         // do not reload third party on reload_config()
@@ -224,12 +224,13 @@ AppId OdpContext::get_protocol_service_id(IpProtocol proto)
     return ip_protocol[(uint16_t)proto];
 }
 
-void OdpThreadContext::initialize(AppIdContext& ctxt, bool is_control, bool reload_odp)
+void OdpThreadContext::initialize(const SnortConfig* sc, AppIdContext& ctxt, bool is_control,
+    bool reload_odp)
 {
     if (!is_control and reload_odp)
-        LuaDetectorManager::init_thread_manager(ctxt);
+        LuaDetectorManager::init_thread_manager(sc, ctxt);
     else
-        LuaDetectorManager::initialize(ctxt, is_control, reload_odp);
+        LuaDetectorManager::initialize(sc, ctxt, is_control, reload_odp);
 }
 
 OdpThreadContext::~OdpThreadContext()
index adae41630e2dc80320fa9a3e8e297029189b7965..4a23839475b73d93ee616c9c1a08a33b7df51c79 100644 (file)
@@ -153,10 +153,10 @@ public:
         return host_port_cache.find(ip, port, proto, *this);
     }
 
-    bool host_port_cache_add(const snort::SfIp* ip, uint16_t port, IpProtocol proto, unsigned type,
-        AppId appid)
+    bool host_port_cache_add(const snort::SnortConfig* sc, const snort::SfIp* ip, uint16_t port,
+        IpProtocol proto, unsigned type, AppId appid)
     {
-        return host_port_cache.add(ip, port, proto, type, appid);
+        return host_port_cache.add(sc, ip, port, proto, type, appid);
     }
 
     AppId length_cache_find(const LengthKey& key)
@@ -241,7 +241,8 @@ class OdpThreadContext
 {
 public:
     ~OdpThreadContext();
-    void initialize(AppIdContext& ctxt, bool is_control=false, bool reload_odp=false);
+    void initialize(const snort::SnortConfig*, AppIdContext& ctxt, bool is_control=false,
+        bool reload_odp=false);
 
     void set_lua_detector_mgr(LuaDetectorManager& mgr)
     {
index edcee9b10e768f60a83fdf437e52dbeee02ebe5a..0b1a8ddfac66005c27780ca61f74dc2ff4f1e17b 100644 (file)
@@ -43,13 +43,13 @@ THREAD_LOCAL AppIdHAAppsClient* AppIdHAManager::ha_apps_client = nullptr;
 THREAD_LOCAL AppIdHAHttpClient* AppIdHAManager::ha_http_client = nullptr;
 THREAD_LOCAL AppIdHATlsHostClient* AppIdHAManager::ha_tls_host_client = nullptr;
 
-static AppIdSession* create_appid_session(Flow& flow, const FlowKey* key)
+static AppIdSession* create_appid_session(Flow& flow, const FlowKey* key,
+    AppIdInspector& inspector)
 {
-    AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true);
     AppIdSession* asd = new AppIdSession(static_cast<IpProtocol>(key->ip_protocol),
         flow.flags.client_initiated ? &flow.client_ip : &flow.server_ip,
-        flow.flags.client_initiated ? flow.client_port : flow.server_port, *inspector,
-        inspector->get_ctxt().get_odp_ctxt(), key->addressSpaceId);
+        flow.flags.client_initiated ? flow.client_port : flow.server_port, inspector,
+        inspector.get_ctxt().get_odp_ctxt(), key->addressSpaceId);
     if (appidDebug->is_active())
         LogMessage("AppIdDbg %s high-avail - New AppId session created in consume\n",
             appidDebug->get_debug_session());
@@ -67,7 +67,9 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
     if (size != sizeof(AppIdSessionHAApps))
         return false;
 
-    AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true);
+    AppIdInspector* inspector =
+        static_cast<AppIdInspector*>(
+            InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type));
     if (!inspector)
         return false;
 
@@ -77,9 +79,9 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
     if (appidDebug->is_enabled())
     {
         appidDebug->activate(flow, asd, inspector->get_ctxt().config.log_all_sessions);
-        LogMessage("AppIdDbg %s high-avail - Consuming app data - flags 0x%x, service %d, client %d, "
-            "payload %d, misc %d, referred %d, client_inferred_service %d, port_service %d, "
-            "tp_app %d, tp_payload %d\n",
+        LogMessage("AppIdDbg %s high-avail - Consuming app data - flags 0x%x, service %d, "
+            "client %d, payload %d, misc %d, referred %d, client_inferred_service %d, "
+            "port_service %d, tp_app %d, tp_payload %d\n",
             appidDebug->get_debug_session(), appHA->flags, appHA->appId[APPID_HA_APP_SERVICE],
             appHA->appId[APPID_HA_APP_CLIENT], appHA->appId[APPID_HA_APP_PAYLOAD],
             appHA->appId[APPID_HA_APP_MISC], appHA->appId[APPID_HA_APP_REFERRED],
@@ -90,7 +92,7 @@ bool AppIdHAAppsClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
 
     if (!asd)
     {
-        asd = create_appid_session(*flow, key);
+        asd = create_appid_session(*flow, key, *inspector);
         asd->set_service_id(appHA->appId[APPID_HA_APP_SERVICE], asd->get_odp_ctxt());
         if (asd->get_service_id() == APP_ID_FTP_CONTROL)
         {
@@ -227,7 +229,9 @@ bool AppIdHAHttpClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
     if (size != sizeof(AppIdSessionHAHttp))
         return false;
 
-    AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true);
+    AppIdInspector* inspector =
+        static_cast<AppIdInspector*>(
+            InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type));
     if (!inspector)
         return false;
 
@@ -241,7 +245,7 @@ bool AppIdHAHttpClient::consume(Flow*& flow, const FlowKey* key, HAMessage& msg,
     }
 
     if (!asd)
-        asd = create_appid_session(*flow, key);
+        asd = create_appid_session(*flow, key, *inspector);
 
     AppidChangeBits change_bits;
     AppIdHttpSession* hsession = asd->get_http_session();
@@ -315,7 +319,9 @@ bool AppIdHATlsHostClient::consume(Flow*& flow, const FlowKey* key, HAMessage& m
     if (size != sizeof(AppIdSessionHATlsHost))
         return false;
 
-    AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME, true);
+    AppIdInspector* inspector =
+        static_cast<AppIdInspector*>(
+            InspectorManager::get_inspector(MOD_NAME, MOD_USAGE, appid_inspector_api.type));
     if (!inspector)
         return false;
 
@@ -329,7 +335,7 @@ bool AppIdHATlsHostClient::consume(Flow*& flow, const FlowKey* key, HAMessage& m
     }
 
     if (!asd)
-        asd = create_appid_session(*flow, key);
+        asd = create_appid_session(*flow, key, *inspector);
 
     asd->set_tls_host(appHA->tls_host);
 
index 435c5af85fae22770f2ee75ef565b754c29eec4e..ac6521849a40df6eac001de094f60e59601a3496 100644 (file)
@@ -122,27 +122,27 @@ bool AppIdInspector::configure(SnortConfig* sc)
 
     ctxt->init_appid(sc, *this);
 
-    DataBus::subscribe_network(SIP_EVENT_TYPE_SIP_DIALOG_KEY, new SipEventHandler(*this));
+    DataBus::subscribe_global(SIP_EVENT_TYPE_SIP_DIALOG_KEY, new SipEventHandler(*this), *sc);
 
-    DataBus::subscribe_network(HTTP_REQUEST_HEADER_EVENT_KEY, new HttpEventHandler(
-        HttpEventHandler::REQUEST_EVENT, *this));
+    DataBus::subscribe_global(HTTP_REQUEST_HEADER_EVENT_KEY, new HttpEventHandler(
+        HttpEventHandler::REQUEST_EVENT, *this), *sc);
 
-    DataBus::subscribe_network(HTTP_RESPONSE_HEADER_EVENT_KEY, new HttpEventHandler(
-        HttpEventHandler::RESPONSE_EVENT, *this));
+    DataBus::subscribe_global(HTTP_RESPONSE_HEADER_EVENT_KEY, new HttpEventHandler(
+        HttpEventHandler::RESPONSE_EVENT, *this), *sc);
 
-    DataBus::subscribe_network(HTTP2_REQUEST_BODY_EVENT_KEY, new AppIdHttp2ReqBodyEventHandler());
+    DataBus::subscribe_global(HTTP2_REQUEST_BODY_EVENT_KEY, new AppIdHttp2ReqBodyEventHandler(), *sc);
 
-    DataBus::subscribe_network(DATA_DECRYPT_EVENT, new DataDecryptEventHandler());
+    DataBus::subscribe_global(DATA_DECRYPT_EVENT, new DataDecryptEventHandler(), *sc);
 
-    DataBus::subscribe_network(DCERPC_EXP_SESSION_EVENT_KEY, new DceExpSsnEventHandler());
+    DataBus::subscribe_global(DCERPC_EXP_SESSION_EVENT_KEY, new DceExpSsnEventHandler(), *sc);
 
-    DataBus::subscribe_network(OPPORTUNISTIC_TLS_EVENT, new AppIdOpportunisticTlsEventHandler());
+    DataBus::subscribe_global(OPPORTUNISTIC_TLS_EVENT, new AppIdOpportunisticTlsEventHandler(), *sc);
 
-    DataBus::subscribe_network(EVE_PROCESS_EVENT, new AppIdEveProcessEventHandler());
+    DataBus::subscribe_global(EVE_PROCESS_EVENT, new AppIdEveProcessEventHandler(), *sc);
 
-    DataBus::subscribe_network(SSH_EVENT, new SshEventHandler());
+    DataBus::subscribe_global(SSH_EVENT, new SshEventHandler(), *sc);
 
-    DataBus::subscribe_network(FLOW_NO_SERVICE_EVENT, new AppIdServiceEventHandler(*this));
+    DataBus::subscribe_global(FLOW_NO_SERVICE_EVENT, new AppIdServiceEventHandler(*this), *sc);
 
     return true;
 }
@@ -163,7 +163,7 @@ void AppIdInspector::tinit()
 
     assert(!odp_thread_local_ctxt);
     odp_thread_local_ctxt = new OdpThreadContext();
-    odp_thread_local_ctxt->initialize(*ctxt);
+    odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), *ctxt);
 
     AppIdServiceState::initialize(config->memcap);
     assert(!pkt_thread_tp_appid_ctxt);
@@ -250,6 +250,7 @@ static void appid_inspector_tinit()
 static void appid_inspector_tterm()
 {
     TPLibHandler::tfini();
+    AppIdPegCounts::sum_stats();
     AppIdPegCounts::cleanup_pegs();
     AppIdServiceState::clean();
     delete appidDebug;
index a182e38b154fbd6ed28eb0e44ce31aa4d194d3ee..d27b4a4698c3f45864d52ae1edf9e43117dc7441 100644 (file)
@@ -54,6 +54,8 @@ private:
     AppIdContext* ctxt = nullptr;
 };
 
+extern const snort::InspectApi appid_inspector_api;
+
 extern THREAD_LOCAL OdpThreadContext* odp_thread_local_ctxt;
 extern THREAD_LOCAL OdpContext* pkt_thread_odp_ctxt;
 extern THREAD_LOCAL ThirdPartyAppIdContext* pkt_thread_tp_appid_ctxt;
index abe438adabc2fc00e3d6bd92de2de802c98ed51e..79c93716a0832c19b7cbab8757a772200ed16f91 100644 (file)
@@ -151,7 +151,7 @@ class ACThirdPartyAppIdContextSwap : public AnalyzerCommand
 public:
     bool execute(Analyzer&, void**) override;
     ACThirdPartyAppIdContextSwap(const AppIdInspector& inspector, ControlConn* conn)
-        : inspector(inspector), tracker_ref(conn)
+        : AnalyzerCommand(conn), inspector(inspector)
     {
         LogMessage("== swapping third-party configuration\n");
     }
@@ -160,7 +160,6 @@ public:
     const char* stringify() override { return "THIRD-PARTY_CONTEXT_SWAP"; }
 private:
     const AppIdInspector& inspector;
-    ControlConn* tracker_ref;
 };
 
 bool ACThirdPartyAppIdContextSwap::execute(Analyzer&, void**)
@@ -179,7 +178,7 @@ ACThirdPartyAppIdContextSwap::~ACThirdPartyAppIdContextSwap()
     std::string file_path = ctxt.get_tp_appid_ctxt()->get_user_config();
     ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path);
     LogMessage("== third-party configuration swap complete\n");
-    ReloadTracker::end(tracker_ref);
+    ReloadTracker::end(ctrlcon);
 }
 
 class ACThirdPartyAppIdContextUnload : public AnalyzerCommand
@@ -187,13 +186,13 @@ class ACThirdPartyAppIdContextUnload : public AnalyzerCommand
 public:
     bool execute(Analyzer&, void**) override;
     ACThirdPartyAppIdContextUnload(const AppIdInspector& inspector, ThirdPartyAppIdContext* tp_ctxt,
-        ControlConn* ctrlcon): inspector(inspector), tp_ctxt(tp_ctxt), ctrlcon(ctrlcon) { }
+        ControlConn* conn): AnalyzerCommand(conn), inspector(inspector), tp_ctxt(tp_ctxt)
+    { }
     ~ACThirdPartyAppIdContextUnload() override;
     const char* stringify() override { return "THIRD-PARTY_CONTEXT_UNLOAD"; }
 private:
     const AppIdInspector& inspector;
     ThirdPartyAppIdContext* tp_ctxt =  nullptr;
-    ControlConn* ctrlcon;
 };
 
 bool ACThirdPartyAppIdContextUnload::execute(Analyzer& ac, void**)
@@ -218,9 +217,7 @@ ACThirdPartyAppIdContextUnload::~ACThirdPartyAppIdContextUnload()
     AppIdContext& ctxt = inspector.get_ctxt();
     ctxt.create_tp_appid_ctxt();
     main_broadcast_command(new ACThirdPartyAppIdContextSwap(inspector, ctrlcon));
-    LogMessage("== reload third-party complete\n");
-    if (ctrlcon && !ctrlcon->is_local())
-        ctrlcon->respond("== reload third-party complete\n");
+    log_message("== reload third-party complete\n");
     ReloadTracker::update(ctrlcon, "unload old third-party complete, start swapping to new configuration.");
 }
 
@@ -228,14 +225,14 @@ class ACOdpContextSwap : public AnalyzerCommand
 {
 public:
     bool execute(Analyzer&, void**) override;
-    ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* ctrlcon) :
-        inspector(inspector), odp_ctxt(odp_ctxt), ctrlcon(ctrlcon) { }
+    ACOdpContextSwap(const AppIdInspector& inspector, OdpContext& odp_ctxt, ControlConn* conn) :
+        AnalyzerCommand(conn), inspector(inspector), odp_ctxt(odp_ctxt)
+    { }
     ~ACOdpContextSwap() override;
     const char* stringify() override { return "ODP_CONTEXT_SWAP"; }
 private:
     const AppIdInspector& inspector;
     OdpContext& odp_ctxt;
-    ControlConn* ctrlcon;
 };
 
 bool ACOdpContextSwap::execute(Analyzer&, void**)
@@ -254,7 +251,7 @@ bool ACOdpContextSwap::execute(Analyzer&, void**)
     assert(odp_thread_local_ctxt);
     delete odp_thread_local_ctxt;
     odp_thread_local_ctxt = new OdpThreadContext;
-    odp_thread_local_ctxt->initialize(ctxt, false, true);
+    odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), ctxt, false, true);
     return true;
 }
 
@@ -270,9 +267,8 @@ ACOdpContextSwap::~ACOdpContextSwap()
             file_path = std::string(ctxt.config.app_detector_dir) + "/../userappid.conf";
         ctxt.get_odp_ctxt().get_app_info_mgr().dump_appid_configurations(file_path);
     }
-    LogMessage("== reload detectors complete\n");
     ReloadTracker::end(ctrlcon);
-    ctrlcon->respond("== reload detectors complete\n");
+    log_message("== reload detectors complete\n");
 }
 
 static int enable_debug(lua_State* L)
@@ -394,7 +390,7 @@ static int reload_detectors(lua_State* L)
     OdpContext& odp_ctxt = ctxt.get_odp_ctxt();
     odp_ctxt.get_client_disco_mgr().initialize(*inspector);
     odp_ctxt.get_service_disco_mgr().initialize(*inspector);
-    odp_thread_local_ctxt->initialize(ctxt, true, true);
+    odp_thread_local_ctxt->initialize(SnortConfig::get_conf(), ctxt, true, true);
     odp_ctxt.initialize(*inspector);
 
     ctrlcon->respond("== swapping detectors configuration\n");
@@ -524,7 +520,7 @@ bool AppIdModule::end(const char* fqn, int, SnortConfig* sc)
     assert(config);
 
     if ( Snort::is_reloading() && strcmp(fqn, "appid") == 0 )
-        sc->register_reload_resource_tuner(new AppIdReloadTuner(config->memcap));
+        sc->register_reload_handler(new AppIdReloadTuner(config->memcap));
 
     if ( !config->app_detector_dir )
     {
index ce59a8c10147fc29d1e09d5ad40fce1d3d14d3d0..5779dd3e52512ba19d72671bccaad19e93471271 100644 (file)
 #include "framework/module.h"
 #include "main/analyzer.h"
 #include "main/analyzer_command.h"
-#include "main/snort_config.h"
+#include "main/reload_tuner.h"
 
 #include "appid_config.h"
 #include "appid_peg_counts.h"
 
 namespace snort
 {
+struct SnortConfig;
 class Trace;
 }
 
@@ -43,6 +44,7 @@ extern THREAD_LOCAL const snort::Trace* appid_trace;
 
 #define MOD_NAME "appid"
 #define MOD_HELP "application and service identification"
+#define MOD_USAGE snort::Module::GLOBAL
 
 
 class AppIdReloadTuner : public snort::ReloadResourceTuner
@@ -91,7 +93,7 @@ public:
     void reset_stats() override;
 
     Usage get_usage() const override
-    { return CONTEXT; }
+    { return MOD_USAGE; }
     void sum_stats(bool) override;
     void show_dynamic_stats() override;
 
index 2c87d2ff4bb1ec21bc620738e2c0d2687e94959d..8f52b4ada1e0528536bfb958079267549edc7a5e 100644 (file)
@@ -28,6 +28,8 @@
 #include <algorithm>
 #include <string>
 
+#include "framework/inspector.h"
+#include "main/thread_config.h"
 #include "utils/stats.h"
 
 using namespace snort;
@@ -36,10 +38,12 @@ std::unordered_map<AppId, uint32_t> AppIdPegCounts::appid_detector_pegs_idx;
 std::vector<std::string> AppIdPegCounts::appid_detectors_info;
 THREAD_LOCAL std::vector<AppIdPegCounts::AppIdDynamicPeg>* AppIdPegCounts::appid_peg_counts;
 AppIdPegCounts::AppIdDynamicPeg AppIdPegCounts::appid_dynamic_sum[SF_APPID_MAX + 1];
+AppIdPegCounts::AppIdDynamicPeg AppIdPegCounts::zeroed_peg;
+PegCount AppIdPegCounts::all_zeroed_peg[DetectorPegs::NUM_APPID_DETECTOR_PEGS] = {};
 
 void AppIdPegCounts::init_pegs()
 {
-    AppIdPegCounts::AppIdDynamicPeg zeroed_peg = AppIdPegCounts::AppIdDynamicPeg();
+    assert(!appid_peg_counts);
     appid_peg_counts = new std::vector<AppIdPegCounts::AppIdDynamicPeg>(
         appid_detectors_info.size() + 1, zeroed_peg);
 }
@@ -47,6 +51,7 @@ void AppIdPegCounts::init_pegs()
 void AppIdPegCounts::cleanup_pegs()
 {
     delete appid_peg_counts;
+    appid_peg_counts = nullptr;
 }
 
 void AppIdPegCounts::cleanup_peg_info()
@@ -57,22 +62,21 @@ void AppIdPegCounts::cleanup_peg_info()
 
 void AppIdPegCounts::cleanup_dynamic_sum()
 {
-    if ( !appid_peg_counts )
-        return;
-
     for ( unsigned app_num = 0; app_num < AppIdPegCounts::appid_detectors_info.size(); app_num++ )
     {
         memset(appid_dynamic_sum[app_num].stats, 0, sizeof(PegCount) *
             DetectorPegs::NUM_APPID_DETECTOR_PEGS);
-        memset((*appid_peg_counts)[app_num].stats, 0, sizeof(PegCount) *
-            DetectorPegs::NUM_APPID_DETECTOR_PEGS);
+        if ( appid_peg_counts )
+            memset((*appid_peg_counts)[app_num].stats, 0, sizeof(PegCount) *
+                DetectorPegs::NUM_APPID_DETECTOR_PEGS);
     }
 
     // reset unknown_app stats
     memset(appid_dynamic_sum[SF_APPID_MAX].stats, 0, sizeof(PegCount) *
         DetectorPegs::NUM_APPID_DETECTOR_PEGS);
-    memset((*appid_peg_counts)[appid_peg_counts->size() - 1].stats, 0, sizeof(PegCount) *
-        DetectorPegs::NUM_APPID_DETECTOR_PEGS);
+    if ( appid_peg_counts )
+        memset((*appid_peg_counts)[appid_peg_counts->size() - 1].stats, 0, sizeof(PegCount) *
+            DetectorPegs::NUM_APPID_DETECTOR_PEGS);
 }
 
 void AppIdPegCounts::add_app_peg_info(std::string app_name, AppId app_id)
index 4d80e017421d233aefeadddedabad9f3a5df45ad..9fb148284a697a8da25d588a29f2a063546b50c9 100644 (file)
@@ -76,14 +76,13 @@ public:
 
         bool all_zeros()
         {
-            PegCount zeroed_peg[DetectorPegs::NUM_APPID_DETECTOR_PEGS] = { };
-            return !memcmp(stats, &zeroed_peg, sizeof(stats));
+            return !memcmp(stats, &all_zeroed_peg, sizeof(stats));
         }
 
         void print(const char* app, char* buf, int buf_size)
         {
             snprintf(buf, buf_size, "%25.25s: " FMTu64("-10") " " FMTu64("-10") " " FMTu64("-10")
-                " " FMTu64("-10") " " FMTu64("-10") " " FMTu64("-10"), app, 
+                " " FMTu64("-10") " " FMTu64("-10") " " FMTu64("-10"), app,
                 stats[SERVICE_DETECTS], stats[CLIENT_DETECTS], stats[USER_DETECTS],
                 stats[PAYLOAD_DETECTS], stats[MISC_DETECTS], stats[REFERRED_DETECTS]);
         }
@@ -111,6 +110,8 @@ private:
     static AppIdDynamicPeg appid_dynamic_sum[SF_APPID_MAX+1];
     static THREAD_LOCAL std::vector<AppIdDynamicPeg>* appid_peg_counts;
     static uint32_t get_stats_index(AppId id);
+    static AppIdDynamicPeg zeroed_peg;
+    static PegCount all_zeroed_peg[DetectorPegs::NUM_APPID_DETECTOR_PEGS];
 };
 #endif
 
index 616f979e2dda28fe584cd8faa9793c43d05cfe55..6c69946b04c26ccd731a8c79e0c177aa38b4024a 100644 (file)
@@ -51,14 +51,15 @@ HostPortVal* HostPortCache::find(const SfIp* ip, uint16_t port, IpProtocol proto
         return nullptr;
 }
 
-bool HostPortCache::add(const SfIp* ip, uint16_t port, IpProtocol proto, unsigned type, AppId
-    appId)
+bool HostPortCache::add(const SnortConfig* sc, const SfIp* ip, uint16_t port, IpProtocol proto,
+    unsigned type, AppId appId)
 {
     HostPortKey hk;
     HostPortVal hv;
 
     hk.ip = *ip;
-    AppIdInspector* inspector = (AppIdInspector*) InspectorManager::get_inspector(MOD_NAME);
+    AppIdInspector* inspector =
+        (AppIdInspector*)InspectorManager::get_inspector(MOD_NAME, false, sc);
     assert(inspector);
     const AppIdContext& ctxt = inspector->get_ctxt();
     hk.port = (ctxt.get_odp_ctxt().allow_port_wildcard_host_cache)? 0 : port;
index 62ae8b578744d49b7675a891331297132a0f7ff8..4a469372dc53868769eeaf83c5240be485723fb9 100644 (file)
 #include "sfip/sf_ip.h"
 #include "utils/cpp_macros.h"
 
+namespace snort
+{
+    struct SnortConfig;
+}
+
 class OdpContext;
 
 PADDING_GUARD_BEGIN
@@ -64,7 +69,8 @@ class HostPortCache
 {
 public:
     HostPortVal* find(const snort::SfIp*, uint16_t port, IpProtocol, const OdpContext&);
-    bool add(const snort::SfIp*, uint16_t port, IpProtocol, unsigned type, AppId);
+    bool add(const snort::SnortConfig*, const snort::SfIp*, uint16_t port, IpProtocol,
+        unsigned type, AppId);
     void dump();
 
     ~HostPortCache()
index c546b4d2d2d19c3b4f93b96d3ee000b6b987eca3..6cbc2593ebf3691f37932c7ae925019cca6902f7 100644 (file)
@@ -1196,7 +1196,11 @@ static int detector_add_host_port_application(lua_State* L)
     if (toipprotocol(L, ++index, proto))
         return 0;
 
-    if (!ud->get_odp_ctxt().host_port_cache_add(&ip_address, (uint16_t)port, proto, type, app_id))
+    lua_getglobal(L, LUA_STATE_GLOBAL_SC_ID);
+    const SnortConfig* sc = *static_cast<const SnortConfig**>(lua_touserdata(L, -1));
+    lua_pop(L, 1);
+    if (!ud->get_odp_ctxt().host_port_cache_add(
+        sc, &ip_address, (uint16_t)port, proto, type, app_id))
         ErrorMessage("%s:Failed to backend call\n",__func__);
 
     return 0;
index 8d849689c89b2266a62923411eef71a50eaf3da9..8aa83efb8ec42b4b79e361184c7deeb7443f8e8f 100644 (file)
@@ -42,6 +42,8 @@ class AppInfoTableEntry;
 #define DETECTOR "Detector"
 #define DETECTORFLOW "DetectorFlow"
 
+#define LUA_STATE_GLOBAL_SC_ID  "snort_config"
+
 struct DetectorPackageInfo
 {
     std::string initFunctionName;
index d9414813ebf196ff1dacc8c766911050985c16f4..ebdca953245f8d77fb7d5b12f6695156775c3234 100644 (file)
@@ -202,7 +202,8 @@ LuaDetectorManager::~LuaDetectorManager()
     cb_detectors.clear(); // do not free Lua objects in cb_detectors
 }
 
-void LuaDetectorManager::initialize(AppIdContext& ctxt, bool is_control, bool reload)
+void LuaDetectorManager::initialize(const SnortConfig* sc, AppIdContext& ctxt, bool is_control,
+    bool reload)
 {
     LuaDetectorManager* lua_detector_mgr = new LuaDetectorManager(ctxt, is_control);
     odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr);
@@ -226,17 +227,17 @@ void LuaDetectorManager::initialize(AppIdContext& ctxt, bool is_control, bool re
     }
 
     lua_detector_mgr->initialize_lua_detectors(is_control, reload);
-    lua_detector_mgr->activate_lua_detectors();
+    lua_detector_mgr->activate_lua_detectors(sc);
 
     if (ctxt.config.list_odp_detectors)
         lua_detector_mgr->list_lua_detectors();
 }
 
-void LuaDetectorManager::init_thread_manager(const AppIdContext& ctxt)
+void LuaDetectorManager::init_thread_manager(const SnortConfig* sc, const AppIdContext& ctxt)
 {
     LuaDetectorManager* lua_detector_mgr = lua_detector_mgr_list[get_instance_id()];
     odp_thread_local_ctxt->set_lua_detector_mgr(*lua_detector_mgr);
-    lua_detector_mgr->activate_lua_detectors();
+    lua_detector_mgr->activate_lua_detectors(sc);
     if (ctxt.config.list_odp_detectors)
         lua_detector_mgr->list_lua_detectors();
 }
@@ -582,7 +583,7 @@ void LuaDetectorManager::initialize_lua_detectors(bool is_control, bool reload)
     load_lua_detectors(path, true, is_control, reload);
 }
 
-void LuaDetectorManager::activate_lua_detectors()
+void LuaDetectorManager::activate_lua_detectors(const SnortConfig* sc)
 {
     uint32_t lua_tracker_size = compute_lua_tracker_size(MAX_MEMORY_FOR_LUA_DETECTORS,
         allocated_objects.size());
@@ -616,6 +617,9 @@ void LuaDetectorManager::activate_lua_detectors()
 
         /*second parameter is a table containing configuration stuff. */
         lua_newtable(L);
+        const SnortConfig** sc_ud = static_cast<const SnortConfig**>(lua_newuserdata(L, sizeof(const SnortConfig*)));
+        *(sc_ud) = sc;
+        lua_setglobal(L, LUA_STATE_GLOBAL_SC_ID);
         if (lua_pcall(L, 2, 1, 0))
         {
             if (init(L))
@@ -627,6 +631,7 @@ void LuaDetectorManager::activate_lua_detectors()
             lo = allocated_objects.erase(lo);
             continue;
         }
+        *(sc_ud) = nullptr;
 
         lua_getfield(L, LUA_REGISTRYINDEX, lsd->package_info.name.c_str());
         set_lua_tracker_size(L, lua_tracker_size);
index d380f8bee4f981aa29db44aee5ea30d5afb85f75..a8d57faf5f2141fa31876d43c5e007d6cfdf4146 100644 (file)
 
 #include "application_ids.h"
 
+namespace snort
+{
+    struct SnortConfig;
+}
+
 class AppIdContext;
 class AppIdDetector;
 struct DetectorFlow;
@@ -50,8 +55,9 @@ class LuaDetectorManager
 public:
     LuaDetectorManager(AppIdContext&, bool);
     ~LuaDetectorManager();
-    static void initialize(AppIdContext&, bool is_control=false, bool reload=false);
-    static void init_thread_manager(const AppIdContext&);
+    static void initialize(const snort::SnortConfig*, AppIdContext&, bool is_control=false,
+        bool reload=false);
+    static void init_thread_manager(const snort::SnortConfig*, const AppIdContext&);
     static void clear_lua_detector_mgrs();
 
     void set_detector_flow(DetectorFlow* df)
@@ -70,7 +76,7 @@ public:
 
 private:
     void initialize_lua_detectors(bool is_control, bool reload = false);
-    void activate_lua_detectors();
+    void activate_lua_detectors(const snort::SnortConfig*);
     void list_lua_detectors();
     bool load_detector(char* detector_name, bool is_custom, bool is_control, bool reload, std::string& buf);
     void load_lua_detectors(const char* path, bool is_custom, bool is_control, bool reload = false);
index 0a444b1ab77972ed99d580064901d5d8d5ea7f08..b7455990b45c6cb87fd007420e162ef9f8abe62b 100644 (file)
@@ -46,6 +46,7 @@ static const PegInfo bind_pegs[] =
 {
     { CountType::SUM, "raw_packets", "raw packets evaluated" },
     { CountType::SUM, "new_flows", "new flows evaluated" },
+    { CountType::SUM, "rebinds", "flows rebound" },
     { CountType::SUM, "service_changes", "flow service changes evaluated" },
     { CountType::SUM, "assistant_inspectors", "flow assistant inspector requests handled" },
     { CountType::SUM, "new_standby_flows", "new HA flows evaluated" },
@@ -431,7 +432,7 @@ bool BinderModule::end(const char* fqn, int idx, SnortConfig* sc)
             if ( policy_type == FILE_KEY )
             {
                 Shell* sh = new Shell(policy_filename.c_str());
-                auto policies = sc->policy_map->add_shell(sh, false);
+                auto policies = sc->policy_map->add_shell(sh, get_network_parse_policy());
                 binding.use.inspection_index = policies->inspection->policy_id;
                 binding.use.ips_index = policies->ips->policy_id;
             }
index 3b561c72a9916c47454a6bb3cdc883efb910830e..da15cb44d881750e9098d715232b9d2df1827128 100644 (file)
@@ -33,6 +33,7 @@ struct BindStats
 {
     PegCount raw_packets;
     PegCount new_flows;
+    PegCount rebinds;
     PegCount service_changes;
     PegCount assistant_inspectors;
     PegCount new_standby_flows;
index 1fa26e9db2b4ffb29205028abccf186574a0fbc8..f4ba3d3c9aedecf5eb147f4ea9d4c0e34b35abcc 100644 (file)
@@ -420,15 +420,8 @@ void Stuff::apply_action(Flow& flow)
 
 void Stuff::apply_session(Flow& flow)
 {
-    if (client)
-        flow.set_client(client);
-    else if (flow.ssn_client)
-        flow.clear_client();
-
-    if (server)
-        flow.set_server(server);
-    else if (flow.ssn_server)
-        flow.clear_server();
+    flow.set_client(client);
+    flow.set_server(server);
 }
 
 void Stuff::apply_service(Flow& flow)
@@ -490,6 +483,7 @@ public:
     void handle_flow_setup(Flow&, bool standby = false);
     void handle_flow_service_change(Flow&);
     void handle_assistant_gadget(const char* service, Flow&);
+    void handle_flow_after_reload(Flow&);
 
 private:
     void get_policy_bindings(Flow&, const char* service);
@@ -577,6 +571,19 @@ public:
     }
 };
 
+class RebindFlow : public DataHandler
+{
+public:
+    RebindFlow() : DataHandler(BIND_NAME) { }
+
+    void handle(DataEvent&, Flow* flow) override
+    {
+        Binder* binder = InspectorManager::get_binder();
+        if (binder && flow)
+            binder->handle_flow_after_reload(*flow);
+    }
+};
+
 Binder::Binder(std::vector<Binding>& bv, std::vector<Binding>& pbv)
 {
     bindings = std::move(bv);
@@ -623,6 +630,7 @@ bool Binder::configure(SnortConfig* sc)
     DataBus::subscribe(FLOW_SERVICE_CHANGE_EVENT, new FlowServiceChangeHandler());
     DataBus::subscribe(STREAM_HA_NEW_FLOW_EVENT, new StreamHANewFlowHandler());
     DataBus::subscribe(FLOW_ASSISTANT_GADGET_EVENT, new AssistantGadgetHandler());
+    DataBus::subscribe(FLOW_STATE_RELOADED_EVENT, new RebindFlow());
 
     return true;
 }
@@ -721,7 +729,7 @@ void Binder::handle_flow_setup(Flow& flow, bool standby)
         if (flow.ssn_state.snort_protocol_id != UNKNOWN_PROTOCOL_ID)
         {
             const SnortConfig* sc = SnortConfig::get_conf();
-            flow.service = sc->proto_ref->get_name(flow.ssn_state.snort_protocol_id);
+            flow.service = sc->proto_ref->get_shared_name(flow.ssn_state.snort_protocol_id);
         }
     }
 
@@ -742,7 +750,7 @@ void Binder::handle_flow_service_change(Flow& flow)
 
     Inspector* ins = nullptr;
     Inspector* data = nullptr;
-    if (flow.service)
+    if (flow.has_service())
     {
         ins = find_gadget(flow, data);
         if (flow.gadget != ins)
@@ -789,8 +797,8 @@ void Binder::handle_flow_service_change(Flow& flow)
 
     // If there is no inspector bound to this flow after the service change, see if there's at least
     // an associated protocol ID.
-    if (!ins && flow.service)
-        flow.ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow.service);
+    if (!ins && flow.has_service())
+        flow.ssn_state.snort_protocol_id = SnortConfig::get_conf()->proto_ref->find(flow.service->c_str());
 
     if (flow.is_stream())
     {
@@ -820,9 +828,18 @@ void Binder::handle_assistant_gadget(const char* service, Flow& flow)
     bstats.assistant_inspectors++;
 }
 
+void Binder::handle_flow_after_reload(Flow& flow)
+{
+    Stuff stuff;
+    get_bindings(flow, stuff);
+    stuff.apply_action(flow);
+
+    bstats.rebinds++;
+    bstats.verdicts[stuff.action]++;
+}
+
 void Binder::get_policy_bindings(Flow& flow, const char* service)
 {
-    const SnortConfig* sc = SnortConfig::get_conf();
     unsigned inspection_index = 0;
     unsigned ips_index = 0;
 
@@ -847,13 +864,14 @@ void Binder::get_policy_bindings(Flow& flow, const char* service)
 
     if (inspection_index)
     {
-        set_inspection_policy(sc, inspection_index);
+        set_inspection_policy(inspection_index);
         if (!service)
             flow.inspection_policy_id = inspection_index;
     }
 
     if (ips_index)
     {
+        const SnortConfig* sc = SnortConfig::get_conf();
         set_ips_policy(sc, ips_index);
         if (!service)
             flow.ips_policy_id = ips_index;
@@ -862,7 +880,6 @@ void Binder::get_policy_bindings(Flow& flow, const char* service)
 
 void Binder::get_policy_bindings(Packet* p)
 {
-    const SnortConfig* sc = SnortConfig::get_conf();
     unsigned inspection_index = 0;
     unsigned ips_index = 0;
 
@@ -887,12 +904,13 @@ void Binder::get_policy_bindings(Packet* p)
 
     if (inspection_index)
     {
-        set_inspection_policy(sc, inspection_index);
+        set_inspection_policy(inspection_index);
         p->user_inspection_policy_id = get_inspection_policy()->user_policy_id;
     }
 
     if (ips_index)
     {
+        const SnortConfig* sc = SnortConfig::get_conf();
         set_ips_policy(sc, ips_index);
         p->user_ips_policy_id = get_ips_policy()->user_policy_id;
     }
@@ -965,7 +983,7 @@ void Binder::get_bindings(Packet* p, Stuff& stuff)
 Inspector* Binder::find_gadget(Flow& flow, Inspector*& data)
 {
     Stuff stuff;
-    get_bindings(flow, stuff, flow.service);
+    get_bindings(flow, stuff, flow.has_service() ? flow.service->c_str() : nullptr);
     data = stuff.data;
     return stuff.gadget;
 }
index 282bde1714563bb28c378fadca0dd772e73acde9..4b66064f1a5e6800a7b68b3b4a74b7b6b5341315 100644 (file)
@@ -579,10 +579,10 @@ inline bool Binding::check_service(const Flow& flow) const
     if (!when.has_criteria(BindWhen::Criteria::BWC_SVC))
         return true;
 
-    if (!flow.service)
+    if (!flow.has_service())
         return false;
 
-    return when.svc == flow.service;
+    return when.svc == flow.service->c_str();
 }
 
 inline bool Binding::check_service(const char* service) const
index 6bb3aa0b94b7cae439f8e321801de93d9e5eb065..725668ba2d6411104b196214aba0b63697db27ed 100644 (file)
@@ -331,7 +331,7 @@ bool PerfMonModule::end(const char* fqn, int idx, SnortConfig* sc)
 {
 
     if ( Snort::is_reloading() && strcmp(fqn, "perf_monitor") == 0 )
-        sc->register_reload_resource_tuner(new PerfMonReloadTuner(config->flowip_memcap));
+        sc->register_reload_handler(new PerfMonReloadTuner(config->flowip_memcap));
 
     if ( idx != 0 && strcmp(fqn, "perf_monitor.modules") == 0 )
         return config->modules.back().confirm_parse();
index 6f3ba11aad94abee5351a43c5289bd2245510b11..83933017204ac454e895cc5618c11d0081948ce2 100644 (file)
@@ -21,7 +21,7 @@
 #ifndef PERF_RELOAD_TUNER_H
 #define PERF_RELOAD_TUNER_H
 
-#include "main/snort_config.h"
+#include "main/reload_tuner.h"
 
 class PerfMonReloadTuner : public snort::ReloadResourceTuner
 {
index 63c29c45999e9e7685185cd71babd82cdc7e2539..bfb0e1077f38afbcad67d6f2e4aff53454baff7b 100644 (file)
@@ -25,6 +25,7 @@
 #include "ps_module.h"
 #include "log/messages.h"
 #include "main/snort.h"
+#include "main/snort_config.h"
 
 #include <cassert>
 
@@ -326,7 +327,7 @@ bool PortScanModule::set(const char* fqn, Value& v, SnortConfig*)
 bool PortScanModule::end(const char* fqn, int, SnortConfig* sc)
 {
     if ( Snort::is_reloading() && strcmp(fqn, "port_scan") == 0 )
-        sc->register_reload_resource_tuner(new PortScanReloadTuner(config->memcap));
+        sc->register_reload_handler(new PortScanReloadTuner(config->memcap));
     return true;
 }
 
index 798b436ee3d625ee39bfdf1a62a2233d860fc411..4d17b3e3004b977f1f7dfe73ae84a1b3277703dd 100644 (file)
 #define PS_MODULE_H
 
 #include "framework/module.h"
-#include "main/snort_config.h"
+#include "main/reload_tuner.h"
 #include "ps_detect.h"
 #include "ps_pegs.h"
 
+namespace snort
+{
+struct SnortConfig;
+}
+
 #define PS_NAME "port_scan"
 #define PS_HELP "detect various ip, icmp, tcp, and udp port or protocol scans"
 
index c4c225f15d2e428a6b22d2528198f410453df26d..e4d8680a8024749705a1044385a4e25ba90fd4ae 100644 (file)
@@ -3,6 +3,8 @@ set (REPUTATION_INCLUDES
 )
 
 add_library( reputation OBJECT
+    reputation_commands.cc
+    reputation_commands.h
     reputation_config.h
     reputation_inspect.h
     reputation_inspect.cc
diff --git a/src/network_inspectors/reputation/reputation_commands.cc b/src/network_inspectors/reputation/reputation_commands.cc
new file mode 100644 (file)
index 0000000..201e17c
--- /dev/null
@@ -0,0 +1,90 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2021-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// reputation_commands.cc author Ron Dempster <rdempste@cisco.com>
+
+#ifdef HAVE_CONFIG_H
+#include "config.h"
+#endif
+
+#include "reputation_commands.h"
+
+#include "control/control.h"
+#include "log/messages.h"
+#include "main/analyzer_command.h"
+#include "managers/inspector_manager.h"
+
+#include "reputation_common.h"
+#include "reputation_inspect.h"
+
+using namespace snort;
+
+class ReputationReload : public AnalyzerCommand
+{
+public:
+    ReputationReload(ControlConn*, Reputation&);
+    ~ReputationReload() override;
+
+    bool execute(Analyzer&, void**) override;
+    bool need_update_reload_id() const override
+    { return true; }
+
+    const char* stringify() override
+    { return "REPUTATION_RELOAD"; }
+
+protected:
+    Reputation& ins;
+    ReputationData* data;
+};
+
+ReputationReload::ReputationReload(ControlConn* conn, Reputation& ins)
+    : AnalyzerCommand(conn), ins(ins)
+{
+    ins.add_global_ref();
+    log_message(".. reputation reloading\n");
+    data = ins.load_data();
+}
+
+ReputationReload::~ReputationReload()
+{
+    ins.swap_data(data);
+    log_message("== Reputation reload complete\n");
+    ins.rem_global_ref();
+}
+
+bool ReputationReload::execute(Analyzer&, void**)
+{
+    ins.swap_thread_data(data);
+    return true;
+}
+
+static int reload(lua_State* L)
+{
+    ControlConn* ctrlcon = ControlConn::query_from_lua(L);
+    Reputation* ins = static_cast<Reputation*>(InspectorManager::get_inspector(REPUTATION_NAME));
+    if (ins)
+        main_broadcast_command(new ReputationReload(ctrlcon, *ins), ctrlcon);
+    else
+        AnalyzerCommand::log_message(ctrlcon, "No reputation instance configured to reload\n");
+    return 0;
+}
+
+const Command reputation_cmds[] =
+{
+    {"reload", reload, nullptr, "reload reputation data"},
+    {nullptr, nullptr, nullptr, nullptr}
+};
diff --git a/src/network_inspectors/reputation/reputation_commands.h b/src/network_inspectors/reputation/reputation_commands.h
new file mode 100644 (file)
index 0000000..15a1b4b
--- /dev/null
@@ -0,0 +1,28 @@
+//--------------------------------------------------------------------------
+// Copyright (C) 2021-2022 Cisco and/or its affiliates. All rights reserved.
+//
+// This program is free software; you can redistribute it and/or modify it
+// under the terms of the GNU General Public License Version 2 as published
+// by the Free Software Foundation.  You may not use, modify or distribute
+// this program under any other version of the GNU General Public License.
+//
+// This program is distributed in the hope that it will be useful, but
+// WITHOUT ANY WARRANTY; without even the implied warranty of
+// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+// General Public License for more details.
+//
+// You should have received a copy of the GNU General Public License along
+// with this program; if not, write to the Free Software Foundation, Inc.,
+// 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
+//--------------------------------------------------------------------------
+// reputation_commands.h author Ron Dempster <rdempste@cisco.com>
+
+#ifndef REPUTATION_COMMANDS_H
+#define REPUTATION_COMMANDS_H
+
+#include "framework/module.h"
+
+extern const snort::Command reputation_cmds[];
+
+#endif
+
index e92addbf5f1bf36cab17b01efa9ce17dbcbfa3cd..99a8fcfc7e784e022f37806dd7a7950431388ced 100644 (file)
@@ -80,20 +80,13 @@ typedef std::vector<ListFile*> ListFiles;
 struct ReputationConfig
 {
     uint32_t memcap = 500;
-    int num_entries = 0;
     bool scanlocal = false;
     IPdecision priority = TRUSTED;
     NestedIP nested_ip = INNER;
     AllowAction allow_action = DO_NOT_BLOCK;
     std::string blocklist_path;
     std::string allowlist_path;
-    bool memcap_reached = false;
-    uint8_t* reputation_segment = nullptr;
-    table_flat_t* ip_list = nullptr;
-    ListFiles list_files;
     std::string list_dir;
-
-    ~ReputationConfig();
 };
 
 struct IPrepInfo
index f2f3f54e7d1d489c5bd664a76f3e98490bdbf8f6..d41f28b6e9b509e52ce733f7dfc5a345fcd76d3d 100644 (file)
 #include "detection/detection_engine.h"
 #include "events/event_queue.h"
 #include "log/messages.h"
+#include "main/snort.h"
+#include "main/snort_config.h"
+#include "managers/inspector_manager.h"
 #include "network_inspectors/packet_tracer/packet_tracer.h"
 #include "packet_io/active.h"
 #include "profiler/profiler.h"
 #include "protocols/packet.h"
 #include "pub_sub/auxiliary_ip_event.h"
+#include "utils/util.h"
 
 #include "reputation_parse.h"
 
@@ -55,52 +59,28 @@ const PegInfo reputation_peg_names[] =
 { CountType::END, nullptr, nullptr }
 };
 
-const char* NestedIPKeyword[] =
-{
-    "inner",
-    "outer",
-    "all",
-    nullptr
-};
+#define MANIFEST_FILENAME "interface.info"
 
-const char* AllowActionOption[] =
+static inline IPrepInfo* reputation_lookup(const ReputationConfig& config,
+    ReputationData& data, const SfIp* ip)
 {
-    "do_not_block",
-    "trust",
-    nullptr
-};
-
-/*
- * Function prototype(s)
- */
-static void snort_reputation(ReputationConfig* GlobalConf, Packet* p);
-static void populate_trace_data(IPdecision& decision, Packet* p);
-
-static inline IPrepInfo* reputation_lookup(ReputationConfig* config, const SfIp* ip)
-{
-    IPrepInfo* result;
-
-    if (!config->scanlocal)
+    if (!config.scanlocal)
     {
         if (ip->is_private() )
-        {
             return nullptr;
-        }
     }
 
-    result = (IPrepInfo*)sfrt_flat_dir8x_lookup(ip, config->ip_list);
-
-    return (result);
+    return (IPrepInfo*)sfrt_flat_dir8x_lookup(ip, data.ip_list);
 }
 
-static inline IPdecision get_reputation(ReputationConfig* config, IPrepInfo* rep_info,
-    uint32_t* listid, uint32_t ingress_intf, uint32_t egress_intf)
+static inline IPdecision get_reputation(const ReputationConfig& config, ReputationData& data,
+    IPrepInfo* rep_info, uint32_t* listid, uint32_t ingress_intf, uint32_t egress_intf)
 {
     IPdecision decision = DECISION_NULL;
 
     /*Walk through the IPrepInfo lists*/
-    uint8_t* base = (uint8_t*)config->ip_list;
-    ListFiles& list_info =  config->list_files;
+    uint8_t* base = (uint8_t*)data.ip_list;
+    ListFiles& list_info = data.list_files;
 
     while (rep_info)
     {
@@ -117,7 +97,7 @@ static inline IPdecision get_reputation(ReputationConfig* config, IPrepInfo* rep
             {
                 if (TRUSTED_DO_NOT_BLOCK == (IPdecision)list_info[list_index]->list_type)
                     return DECISION_NULL;
-                if (config->priority == (IPdecision)list_info[list_index]->list_type )
+                if (config.priority == (IPdecision)list_info[list_index]->list_type )
                 {
                     *listid = list_info[list_index]->list_id;
                     return  ((IPdecision)list_info[list_index]->list_type);
@@ -138,18 +118,16 @@ static inline IPdecision get_reputation(ReputationConfig* config, IPrepInfo* rep
     return decision;
 }
 
-static bool decision_per_layer(ReputationConfig* config, Packet* p,
-    uint32_t ingress_intf, uint32_t egress_intf, const ip::IpApi& ip_api, IPdecision* decision_final)
+static bool decision_per_layer(const ReputationConfig& config, ReputationData& data,
+    Packet* p, uint32_t ingress_intf, uint32_t egress_intf, const ip::IpApi& ip_api,
+    IPdecision* decision_final)
 {
-    const SfIp* ip;
-    IPdecision decision;
-    IPrepInfo* result;
-
-    ip = ip_api.get_src();
-    result = reputation_lookup(config, ip);
+    const SfIp* ip = ip_api.get_src();
+    IPrepInfo* result = reputation_lookup(config, data, ip);
     if (result)
     {
-        decision = get_reputation(config, result, &p->iplist_id, ingress_intf, egress_intf);
+        IPdecision decision = get_reputation(config, data, result, &p->iplist_id, ingress_intf,
+            egress_intf);
 
         if (decision == BLOCKED)
             *decision_final = BLOCKED_SRC;
@@ -160,15 +138,16 @@ static bool decision_per_layer(ReputationConfig* config, Packet* p,
         else
             *decision_final = decision;
 
-        if ( config->priority == decision)
+        if ( config.priority == decision)
             return true;
     }
 
     ip = ip_api.get_dst();
-    result = reputation_lookup(config, ip);
+    result = reputation_lookup(config, data, ip);
     if (result)
     {
-        decision = get_reputation(config, result, &p->iplist_id, ingress_intf, egress_intf);
+        IPdecision decision = get_reputation(config, data, result, &p->iplist_id, ingress_intf,
+            egress_intf);
 
         if (decision == BLOCKED)
             *decision_final = BLOCKED_DST;
@@ -179,14 +158,15 @@ static bool decision_per_layer(ReputationConfig* config, Packet* p,
         else
             *decision_final = decision;
 
-        if ( config->priority == decision)
+        if ( config.priority == decision)
             return true;
     }
 
     return false;
 }
 
-static IPdecision reputation_decision(ReputationConfig* config, Packet* p)
+static IPdecision reputation_decision(const ReputationConfig& config, ReputationData& data,
+    Packet* p)
 {
     IPdecision decision_final = DECISION_NULL;
     uint32_t ingress_intf = 0;
@@ -201,9 +181,9 @@ static IPdecision reputation_decision(ReputationConfig* config, Packet* p)
             egress_intf = p->pkth->egress_index;
     }
 
-    if (config->nested_ip == INNER)
+    if (config.nested_ip == INNER)
     {
-        decision_per_layer(config, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final);
+        decision_per_layer(config, data, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final);
         return decision_final;
     }
 
@@ -213,19 +193,19 @@ static IPdecision reputation_decision(ReputationConfig* config, Packet* p)
     int8_t num_layer = 0;
     IpProtocol tmp_next = p->get_ip_proto_next();
 
-    if (config->nested_ip == OUTER)
+    if (config.nested_ip == OUTER)
     {
         layer::set_outer_ip_api(p, p->ptrs.ip_api, p->ip_proto_next, num_layer);
-        decision_per_layer(config, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final);
+        decision_per_layer(config, data, p, ingress_intf, egress_intf, p->ptrs.ip_api, &decision_final);
     }
-    else if (config->nested_ip == ALL)
+    else if (config.nested_ip == ALL)
     {
         bool done = false;
         IPdecision decision_current = DECISION_NULL;
 
         while (!done and layer::set_outer_ip_api(p, p->ptrs.ip_api, p->ip_proto_next, num_layer))
         {
-            done = decision_per_layer(config, p, ingress_intf, egress_intf, p->ptrs.ip_api,
+            done = decision_per_layer(config, data, p, ingress_intf, egress_intf, p->ptrs.ip_api,
                 &decision_current);
             if (decision_current != DECISION_NULL)
             {
@@ -241,18 +221,19 @@ static IPdecision reputation_decision(ReputationConfig* config, Packet* p)
 
     if (decision_final != BLOCKED_SRC and decision_final != BLOCKED_DST)
         p->ptrs.ip_api = tmp_api;
-    else if (config->nested_ip == ALL and p->ptrs.ip_api != blocked_api)
+    else if (config.nested_ip == ALL and p->ptrs.ip_api != blocked_api)
         p->ptrs.ip_api = blocked_api;
 
     p->ip_proto_next = tmp_next;
     return decision_final;
 }
 
-static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, const SfIp* ip)
+static IPdecision snort_reputation_aux_ip(const ReputationConfig& config, ReputationData& data,
+    Packet* p, const SfIp* ip)
 {
     IPdecision decision = DECISION_NULL;
 
-    if (!config->ip_list)
+    if (!data.ip_list)
         return decision;
 
     uint32_t ingress_intf = 0;
@@ -267,10 +248,10 @@ static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, c
             egress_intf = p->pkth->egress_index;
     }
 
-    IPrepInfo* result = reputation_lookup(config, ip);
+    IPrepInfo* result = reputation_lookup(config, data, ip);
     if (result)
     {
-        decision = get_reputation(config, result, &p->iplist_id, ingress_intf,
+        decision = get_reputation(config, data, result, &p->iplist_id, ingress_intf,
             egress_intf);
 
         if (decision == BLOCKED)
@@ -279,7 +260,7 @@ static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, c
                 p->flow->flags.reputation_blocklist = true;
 
             // Prior to IPRep logging, IPS policy must be set to the default policy,
-            set_ips_policy(SnortConfig::get_conf(), 0);
+            set_ips_policy(get_default_ips_policy(SnortConfig::get_conf()));
 
             DetectionEngine::queue_event(GID_REPUTATION, REPUTATION_EVENT_BLOCKLIST_DST);
             p->active->drop_packet(p, true);
@@ -321,14 +302,67 @@ static IPdecision snort_reputation_aux_ip(ReputationConfig* config, Packet* p, c
     return decision;
 }
 
-static void snort_reputation(ReputationConfig* config, Packet* p)
+static const char* to_string(IPdecision ipd)
+{
+    switch (ipd)
+    {
+    case BLOCKED:
+        return "blocked";
+    case TRUSTED:
+        return "trusted";
+    case MONITORED:
+        return "monitored";
+    case BLOCKED_SRC:
+        return "blocked_src";
+    case BLOCKED_DST:
+        return "blocked_dst";
+    case TRUSTED_SRC:
+        return "trusted_src";
+    case TRUSTED_DST:
+        return "trusted_dst";
+    case TRUSTED_DO_NOT_BLOCK:
+        return "trusted_do_not_block";
+    case MONITORED_SRC:
+        return "monitored_src";
+    case MONITORED_DST:
+        return "monitored_dst";
+    case DECISION_NULL:
+    case DECISION_MAX:
+    default:
+        return "";
+    }
+}
+
+static void populate_trace_data(IPdecision& decision, Packet* p)
+{
+    char addr[INET6_ADDRSTRLEN];
+    const SfIp* ip = nullptr;
+
+    if (BLOCKED_SRC == decision or MONITORED_SRC == decision or TRUSTED_SRC == decision)
+    {
+        ip = p->ptrs.ip_api.get_src();
+    }
+    else if (BLOCKED_DST == decision or MONITORED_DST == decision or TRUSTED_DST == decision)
+    {
+        ip = p->ptrs.ip_api.get_dst();
+    }
+
+    sfip_ntop(ip, addr, sizeof(addr));
+
+    PacketTracer::daq_log("SI-IP+%" PRId64"+%s list id %u+Matched ip %s, action %s$",
+        TO_NSECS(pt_timer->get()),
+        (TRUSTED_SRC == decision or TRUSTED_DST == decision)?"Do_not_block":"Block",
+        p->iplist_id, addr, to_string(decision));
+}
+
+static void snort_reputation(const ReputationConfig& config, ReputationData& data, Packet* p)
 {
     IPdecision decision;
 
-    if (!config->ip_list)
+    if (!data.ip_list)
         return;
 
-    decision = reputation_decision(config, p);
+    decision = reputation_decision(config, data, p);
     Active* act = p->active;
 
     if (BLOCKED_SRC == decision or BLOCKED_DST == decision)
@@ -364,7 +398,7 @@ static void snort_reputation(ReputationConfig* config, Packet* p)
         const auto& aux_ip_list =  p->flow->stash->get_aux_ip_list();
         for ( const auto& ip : aux_ip_list )
         {
-            if ( BLOCKED == snort_reputation_aux_ip(config, p, &ip) )
+            if ( BLOCKED == snort_reputation_aux_ip(config, data, p, &ip) )
                 return;
         }
     }
@@ -437,73 +471,23 @@ static const char* to_string(AllowAction aa)
     return "";
 }
 
-static const char* to_string(IPdecision ipd)
-{
-    switch (ipd)
-    {
-    case BLOCKED:
-        return "blocked";
-    case TRUSTED:
-        return "trusted";
-    case MONITORED:
-        return "monitored";
-    case BLOCKED_SRC:
-        return "blocked_src";
-    case BLOCKED_DST:
-        return "blocked_dst";
-    case TRUSTED_SRC:
-        return "trusted_src";
-    case TRUSTED_DST:
-        return "trusted_dst";
-    case TRUSTED_DO_NOT_BLOCK:
-        return "trusted_do_not_block";
-    case MONITORED_SRC:
-        return "monitored_src";
-    case MONITORED_DST:
-        return "monitored_dst";
-    case DECISION_NULL:
-    case DECISION_MAX:
-    default:
-        return "";
-    }
-}
-
-static void populate_trace_data(IPdecision& decision, Packet* p)
-{
-    char addr[INET6_ADDRSTRLEN];
-    const SfIp* ip = nullptr;
-
-    if (BLOCKED_SRC == decision or MONITORED_SRC == decision or TRUSTED_SRC == decision)
-    {
-        ip = p->ptrs.ip_api.get_src();
-    }
-    else if (BLOCKED_DST == decision or MONITORED_DST == decision or TRUSTED_DST == decision)
-    {
-        ip = p->ptrs.ip_api.get_dst();
-    }
-
-    sfip_ntop(ip, addr, sizeof(addr));
-
-    PacketTracer::daq_log("SI-IP+%" PRId64"+%s list id %u+Matched ip %s, action %s$",
-        TO_NSECS(pt_timer->get()),
-        (TRUSTED_SRC == decision or TRUSTED_DST == decision)?"Do_not_block":"Block",
-        p->iplist_id, addr, to_string(decision));
-}
-
 class AuxiliaryIpRepHandler : public DataHandler
 {
 public:
-    AuxiliaryIpRepHandler(ReputationConfig& rc) : DataHandler(REPUTATION_NAME), conf(rc) { }
+    explicit AuxiliaryIpRepHandler(Reputation& inspector)
+        : DataHandler(REPUTATION_NAME), inspector(inspector)
+    { }
     void handle(DataEvent&, Flow*) override;
 
 private:
-    ReputationConfig& conf;
+    Reputation& inspector;
 };
 
 void AuxiliaryIpRepHandler::handle(DataEvent& event, Flow*)
 {
     Profile profile(reputation_perf_stats);
-    snort_reputation_aux_ip(&conf, DetectionEngine::get_current_packet(),
+    snort_reputation_aux_ip(inspector.get_config(), inspector.get_data(),
+        DetectionEngine::get_current_packet(),
         static_cast<AuxiliaryIpEvent*>(&event)->get_ip());
 }
 
@@ -511,26 +495,58 @@ void AuxiliaryIpRepHandler::handle(DataEvent& event, Flow*)
 // class stuff
 //-------------------------------------------------------------------------
 
-Reputation::Reputation(ReputationConfig* pc)
+ReputationData::~ReputationData()
 {
-    config = *pc;
-    ReputationConfig* conf = &config;
+    if (reputation_segment)
+        snort_free(reputation_segment);
+
+    for (auto& file : list_files)
+        delete file;
+}
+
+Reputation::Reputation(ReputationConfig* pc) : config(*pc)
+{ rep_data = load_data(); }
+
+Reputation::~Reputation()
+{ delete rep_data; }
+
+ReputationData* Reputation::load_data()
+{
+    ReputationData* data = new ReputationData();
     if (!config.list_dir.empty())
-        read_manifest(MANIFEST_FILENAME, conf);
+        read_manifest(MANIFEST_FILENAME, config, *data);
 
-    add_block_allow_List(conf);
-    estimate_num_entries(conf);
-    if (conf->num_entries <= 0)
+    add_block_allow_List(config, *data);
+    estimate_num_entries(*data);
+    if (0 >= data->num_entries)
     {
         ParseWarning(WARN_CONF,
             "reputation: can't find any allowlist/blocklist entries; disabled.");
-        return;
     }
+    else
+    {
+        ip_list_init(data->num_entries + 1, config, *data);
+        reputationstats.memory_allocated = sfrt_flat_usage(data->ip_list);
+    }
+
+    return data;
+}
+
+void Reputation::swap_thread_data(ReputationData* data)
+{ set_thread_specific_data(data); }
 
-    ip_list_init(conf->num_entries + 1, conf);
-    reputationstats.memory_allocated = sfrt_flat_usage(conf->ip_list);
+void Reputation::swap_data(ReputationData* data)
+{
+    delete rep_data;
+    rep_data = data;
 }
 
+void Reputation::tinit()
+{ set_thread_specific_data(rep_data); }
+
+void Reputation::tterm()
+{ set_thread_specific_data(nullptr); }
+
 void Reputation::show(const SnortConfig*) const
 {
     ConfigLogger::log_value("blocklist", config.blocklist_path.c_str());
@@ -556,16 +572,21 @@ void Reputation::eval(Packet* p)
     if (PacketTracer::is_daq_activated())
         PacketTracer::pt_timer_start();
 
-    snort_reputation(&config, p);
+    ReputationData* data = static_cast<ReputationData*>(get_thread_specific_data());
+    assert(data);
+    snort_reputation(config, *data, p);
     ++reputationstats.packets;
 }
 
 bool Reputation::configure(SnortConfig*)
 {
-    DataBus::subscribe_network( AUXILIARY_IP_EVENT, new AuxiliaryIpRepHandler(config) );
+    DataBus::subscribe_network( AUXILIARY_IP_EVENT, new AuxiliaryIpRepHandler(*this) );
     return true;
 }
 
+void Reputation::install_reload_handler(SnortConfig* sc)
+{ sc->register_reload_handler(new ReputationReloadSwapper(*this)); }
+
 //-------------------------------------------------------------------------
 // api stuff
 //-------------------------------------------------------------------------
index 715523905cdd15d0e797ded78430bfa1fc988db5..9ba573abc9571e09272f956c18a892dd9e003e65 100644 (file)
 #ifndef REPUTATION_INSPECT_H
 #define REPUTATION_INSPECT_H
 
-#include "flow/flow.h"
+#include "framework/inspector.h"
 
 #include "reputation_module.h"
 
+class ReputationData
+{
+public:
+    ReputationData() = default;
+    ~ReputationData();
+
+    ListFiles list_files;
+    uint8_t* reputation_segment = nullptr;
+    table_flat_t* ip_list = nullptr;
+    int num_entries = 0;
+    bool memcap_reached = false;
+};
+
 class Reputation : public snort::Inspector
 {
 public:
-    Reputation(ReputationConfig*);
+    explicit Reputation(ReputationConfig*);
+    ~Reputation() override;
+
+    void tinit() override;
+    void tterm() override;
 
     void show(const snort::SnortConfig*) const override;
     void eval(snort::Packet*) override;
     bool configure(snort::SnortConfig*) override;
+    void install_reload_handler(snort::SnortConfig*) override;
+
+    ReputationData& get_data()
+    { return *rep_data; }
+    const ReputationConfig& get_config()
+    { return config; }
+    ReputationData* load_data();
+
+    void swap_thread_data(ReputationData*);
+    void swap_data(ReputationData*);
 
 private:
     ReputationConfig config;
+    ReputationData* rep_data;
 };
 
 #endif
index 5bcd6541c47f1aa4a00571674b208a9a380b1f1c..e130cc3af55f08d2034a0eb5a1328c0982df0d8e 100644 (file)
@@ -29,6 +29,8 @@
 #include "log/messages.h"
 #include "utils/util.h"
 
+#include "reputation_commands.h"
+#include "reputation_inspect.h"
 #include "reputation_parse.h"
 
 using namespace snort;
@@ -109,6 +111,9 @@ ReputationModule::~ReputationModule()
 const RuleMap* ReputationModule::get_rules() const
 { return reputation_rules; }
 
+const Command* ReputationModule::get_commands() const
+{ return reputation_cmds; }
+
 const PegInfo* ReputationModule::get_pegs() const
 { return reputation_peg_names; }
 
@@ -175,3 +180,7 @@ bool ReputationModule::end(const char*, int, SnortConfig*)
 
     return true;
 }
+
+void ReputationReloadSwapper::tswap()
+{ inspector.set_thread_specific_data(&inspector.get_data()); }
+
index 5b221d3a8a10d2e75093170542b904e13b4a5b19..6f8a53961e3a7275be8ac49077b48b38b5e055c4 100644 (file)
@@ -24,6 +24,8 @@
 // Interface to the REPUTATION network inspector
 
 #include "framework/module.h"
+#include "main/reload_tuner.h"
+
 #include "reputation_config.h"
 #include "reputation_common.h"
 
@@ -36,6 +38,21 @@ extern THREAD_LOCAL snort::ProfileStats reputation_perf_stats;
 extern unsigned long total_duplicates;
 extern unsigned long total_invalids;
 
+class Reputation;
+
+class ReputationReloadSwapper : public snort::ReloadSwapper
+{
+public:
+    explicit ReputationReloadSwapper(Reputation& ins) : inspector(ins)
+    { }
+    ~ReputationReloadSwapper() override = default;
+
+    void tswap() override;
+
+private:
+    Reputation& inspector;
+};
+
 class ReputationModule : public snort::Module
 {
 public:
@@ -50,6 +67,7 @@ public:
     { return GID_REPUTATION; }
 
     const snort::RuleMap* get_rules() const override;
+    const snort::Command* get_commands() const override;
     const PegInfo* get_pegs() const override;
     PegCount* get_counts() const override;
     snort::ProfileStats* get_profile() const override;
index d6e038942468c55c6ee0f575038ae33c58a735a3..d1df8868115e8a958daefa4dd28cce9e1ecba7fd 100644 (file)
@@ -36,6 +36,9 @@
 #include "utils/util.h"
 #include "utils/util_cstring.h"
 
+#include "reputation_config.h"
+#include "reputation_inspect.h"
+
 using namespace snort;
 using namespace std;
 
@@ -76,19 +79,6 @@ unsigned long total_invalids;
 
 int totalNumEntries = 0;
 
-static void load_list_file(ListFile*, ReputationConfig* config);
-
-ReputationConfig::~ReputationConfig()
-{
-    if (reputation_segment != nullptr)
-        snort_free(reputation_segment);
-
-    for (auto& file : list_files)
-    {
-        delete file;
-    }
-}
-
 static uint32_t estimate_size(uint32_t num_entries, uint32_t memcap)
 {
     uint64_t size;
@@ -114,49 +104,6 @@ static uint32_t estimate_size(uint32_t num_entries, uint32_t memcap)
     return (uint32_t)size;
 }
 
-void ip_list_init(uint32_t max_entries, ReputationConfig* config)
-{
-    if ( !config->ip_list )
-    {
-        uint32_t mem_size;
-        mem_size = estimate_size(max_entries, config->memcap);
-        config->reputation_segment = (uint8_t*)snort_alloc(mem_size);
-
-        segment_meminit(config->reputation_segment, mem_size);
-
-        /*DIR_16x7_4x4 for performance, but memory usage is high
-         *Use  DIR_8x16 worst case IPV4 5K, IPV6 15K (bytes)
-         *Use  DIR_16x7_4x4 worst case IPV4 500, IPV6 2.5M
-         */
-        config->ip_list = sfrt_flat_new(DIR_8x16, IPv6, max_entries, config->memcap);
-
-        if ( !config->ip_list )
-        {
-            ErrorMessage("Failed to create IP list.\n");
-            return;
-        }
-
-        total_duplicates = 0;
-        for (size_t i = 0; i < config->list_files.size(); i++)
-        {
-            config->list_files[i]->list_index = (uint8_t)i + 1;
-            if (config->list_files[i]->file_type == ALLOW_LIST)
-            {
-                if (config->allow_action == DO_NOT_BLOCK)
-                    config->list_files[i]->list_type = TRUSTED_DO_NOT_BLOCK;
-                else
-                    config->list_files[i]->list_type = TRUSTED;
-            }
-            else if (config->list_files[i]->file_type == BLOCK_LIST)
-                config->list_files[i]->list_type = BLOCKED;
-            else if (config->list_files[i]->file_type == MONITOR_LIST)
-                config->list_files[i]->list_type = MONITORED;
-
-            load_list_file(config->list_files[i], config);
-        }
-    }
-}
-
 static inline IPrepInfo* get_last_index(IPrepInfo* rep_info, uint8_t* base, int* last_index)
 {
     int i;
@@ -311,51 +258,39 @@ static int64_t update_entry_info(INFO* current, INFO new_entry, SaveDest save_de
     return bytes_allocated;
 }
 
-static int add_ip(SfCidr* ip_addr,INFO info_ptr, ReputationConfig* config)
+static int add_ip(SfCidr* ip_addr,INFO info_ptr, const ReputationConfig& config,
+    ReputationData& data)
 {
-    int ret;
-    int final_ret = IP_INSERT_SUCCESS;
     /*This variable is used to check whether a more generic address
      * overrides specific address
      */
     uint32_t usage_before;
     uint32_t usage_after;
 
-    usage_before =  sfrt_flat_usage(config->ip_list);
+    usage_before =  sfrt_flat_usage(data.ip_list);
 
+    int final_ret = IP_INSERT_SUCCESS;
     /*Check whether the same or more generic address is already in the table*/
-    if (nullptr != sfrt_flat_lookup(ip_addr->get_addr(), config->ip_list))
-    {
+    if (nullptr != sfrt_flat_lookup(ip_addr->get_addr(), data.ip_list))
         final_ret = IP_INSERT_DUPLICATE;
-    }
 
-    ret = sfrt_flat_insert(ip_addr, (unsigned char)ip_addr->get_bits(), info_ptr, RT_FAVOR_ALL,
-        config->ip_list, &update_entry_info);
+    int ret = sfrt_flat_insert(ip_addr, (unsigned char)ip_addr->get_bits(), info_ptr, RT_FAVOR_ALL,
+        data.ip_list, &update_entry_info);
 
     if (RT_SUCCESS == ret)
-    {
         totalNumEntries++;
-    }
     else if (MEM_ALLOC_FAILURE == ret)
-    {
         final_ret = IP_MEM_ALLOC_FAILURE;
-    }
     else
-    {
         final_ret = IP_INSERT_FAILURE;
-    }
 
-    usage_after = sfrt_flat_usage(config->ip_list);
+    usage_after = sfrt_flat_usage(data.ip_list);
     /*Compare in the same scale*/
-    if (usage_after  > (config->memcap << 20))
-    {
+    if (usage_after > (config.memcap << 20))
         final_ret = IP_MEM_ALLOC_FAILURE;
-    }
     /*Check whether there a more specific address will be overridden*/
     if (usage_before > usage_after )
-    {
         final_ret = IP_INSERT_DUPLICATE;
-    }
 
     return final_ret;
 }
@@ -499,7 +434,8 @@ static int snort_pton(char const* src, SfCidr* dest)
     return 1;
 }
 
-static int process_line(char* line, INFO info, ReputationConfig* config)
+static int process_line(char* line, INFO info, const ReputationConfig& config,
+    ReputationData& data)
 {
     SfCidr address;
 
@@ -509,7 +445,7 @@ static int process_line(char* line, INFO info, ReputationConfig* config)
     if ( snort_pton(line, &address) < 1 )
         return IP_INVALID;
 
-    return add_ip(&address, info, config);
+    return add_ip(&address, info, config, data);
 }
 
 static int update_path_to_file(char* full_filename, unsigned int max_size, const char* filename)
@@ -571,7 +507,8 @@ static char* get_list_type_name(ListFile* list_info)
     }
 }
 
-static void load_list_file(ListFile* list_info, ReputationConfig* config)
+static void load_list_file(ListFile* list_info, const ReputationConfig& config,
+    ReputationData& data)
 {
     char linebuf[MAX_ADDR_LINE_LENGTH];
     char full_path_filename[PATH_MAX+1];
@@ -589,7 +526,7 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config)
     unsigned int fail_count = 0;   /*number of invalid entries in this file*/
     unsigned int num_loaded_before = 0;     /*number of valid entries loaded */
 
-    if (config->memcap_reached)
+    if (data.memcap_reached)
         return;
 
     update_path_to_file(full_path_filename, PATH_MAX, list_info->file_name.c_str());
@@ -602,10 +539,8 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config)
     /*convert list info to ip entry info*/
     ip_info_ptr = segment_snort_calloc(1,sizeof(IPrepInfo));
     if (!(ip_info_ptr))
-    {
         return;
-    }
-    base = (uint8_t*)config->ip_list;
+    base = (uint8_t*)data.ip_list;
     ip_info = ((IPrepInfo*)&base[ip_info_ptr]);
     ip_info->list_indexes[0] = list_info->list_index;
 
@@ -618,7 +553,7 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config)
         return;
     }
 
-    num_loaded_before = sfrt_flat_num_entries(config->ip_list);
+    num_loaded_before = sfrt_flat_num_entries(data.ip_list);
     while ( fgets(linebuf, MAX_ADDR_LINE_LENGTH, fp) )
     {
         int ret;
@@ -633,7 +568,7 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config)
             *cmt = '\0';
 
         /* process the line */
-        ret = process_line(linebuf, ip_info_ptr, config);
+        ret = process_line(linebuf, ip_info_ptr, config, data);
 
         if (IP_INSERT_SUCCESS == ret)
         {
@@ -655,9 +590,9 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config)
         {
             ErrorMessage(
                 "WARNING: %s(%d) => Memcap %u Mbytes reached when inserting IP Address: %s\n",
-                full_path_filename, addrline, config->memcap,linebuf);
+                full_path_filename, addrline, config.memcap,linebuf);
 
-            config->memcap_reached = true;
+            data.memcap_reached = true;
             break;
         }
     }
@@ -673,12 +608,55 @@ static void load_list_file(ListFile* list_info, ReputationConfig* config)
         ErrorMessage("    Additional duplicate addresses were not listed.\n");
 
     LogMessage("    Reputation entries loaded: %u, invalid: %u, re-defined: %u (from file %s)\n",
-        sfrt_flat_num_entries(config->ip_list) - num_loaded_before,
+        sfrt_flat_num_entries(data.ip_list) - num_loaded_before,
         invalid_count, duplicate_count, full_path_filename);
 
     fclose(fp);
 }
 
+void ip_list_init(uint32_t max_entries, const ReputationConfig& config, ReputationData& data)
+{
+    if ( !data.ip_list )
+    {
+        uint32_t mem_size;
+        mem_size = estimate_size(max_entries, config.memcap);
+        data.reputation_segment = (uint8_t*)snort_alloc(mem_size);
+
+        segment_meminit(data.reputation_segment, mem_size);
+
+        /*DIR_16x7_4x4 for performance, but memory usage is high
+         *Use  DIR_8x16 worst case IPV4 5K, IPV6 15K (bytes)
+         *Use  DIR_16x7_4x4 worst case IPV4 500, IPV6 2.5M
+         */
+        data.ip_list = sfrt_flat_new(DIR_8x16, IPv6, max_entries, config.memcap);
+
+        if ( !data.ip_list )
+        {
+            ErrorMessage("Failed to create IP list.\n");
+            return;
+        }
+
+        total_duplicates = 0;
+        for (size_t i = 0; i < data.list_files.size(); i++)
+        {
+            data.list_files[i]->list_index = (uint8_t)i + 1;
+            if (data.list_files[i]->file_type == ALLOW_LIST)
+            {
+                if (config.allow_action == DO_NOT_BLOCK)
+                    data.list_files[i]->list_type = TRUSTED_DO_NOT_BLOCK;
+                else
+                    data.list_files[i]->list_type = TRUSTED;
+            }
+            else if (data.list_files[i]->file_type == BLOCK_LIST)
+                data.list_files[i]->list_type = BLOCKED;
+            else if (data.list_files[i]->file_type == MONITOR_LIST)
+                data.list_files[i]->list_type = MONITORED;
+
+            load_list_file(data.list_files[i], config, data);
+        }
+    }
+}
+
 static int num_lines_in_file(char* fname)
 {
     FILE* fp;
@@ -735,37 +713,33 @@ static int load_file(int total_lines, const char* path)
     return num_lines;
 }
 
-void estimate_num_entries(ReputationConfig* config)
+void estimate_num_entries(ReputationData& data)
 {
-    int total_lines = 0;
+    data.num_entries = 0;
 
-    for (auto& file : config->list_files)
-    {
-        total_lines += load_file(total_lines, file->file_name.c_str());
-    }
-
-    config->num_entries = total_lines;
+    for (auto& file : data.list_files)
+        data.num_entries += load_file(data.num_entries, file->file_name.c_str());
 }
 
-void add_block_allow_List(ReputationConfig* config)
+void add_block_allow_List(const ReputationConfig& config, ReputationData& data)
 {
-    if (config->blocklist_path.size())
+    if (config.blocklist_path.size())
     {
         ListFile* listItem = new ListFile;
         listItem->all_intfs_enabled = true;
-        listItem->file_name = config->blocklist_path;
+        listItem->file_name = config.blocklist_path;
         listItem->file_type = BLOCK_LIST;
         listItem->list_id = 0;
-        config->list_files.emplace_back(listItem);
+        data.list_files.emplace_back(listItem);
     }
-    if (config->allowlist_path.size())
+    if (config.allowlist_path.size())
     {
         ListFile* listItem = new ListFile;
         listItem->all_intfs_enabled = true;
-        listItem->file_name = config->allowlist_path;
+        listItem->file_name = config.allowlist_path;
         listItem->file_type = ALLOW_LIST;
         listItem->list_id = 0;
-        config->list_files.emplace_back(listItem);
+        data.list_files.emplace_back(listItem);
     }
 }
 
@@ -828,7 +802,7 @@ static int get_file_type(char* type_name)
 //If no interface information provided, this means all interfaces are applied.
 
 static bool process_line_in_manifest(ListFile* list_item, const char* manifest, const char* line,
-    int line_number, ReputationConfig* config)
+    int line_number, const ReputationConfig& config, ReputationData& data)
 {
     char* token;
     int token_index = 0;
@@ -846,7 +820,7 @@ static bool process_line_in_manifest(ListFile* list_item, const char* manifest,
         switch (token_index)
         {
         case 0:    // File name
-            list_item->file_name = config->list_dir + '/' + token;
+            list_item->file_name = config.list_dir + '/' + token;
             break;
 
         case 1:    // List ID
@@ -925,17 +899,14 @@ static bool process_line_in_manifest(ListFile* list_item, const char* manifest,
         list_item->all_intfs_enabled = true;
     }
 
-    config->list_files.emplace_back(list_item);
+    data.list_files.emplace_back(list_item);
     return true;
 }
 
-int read_manifest(const char* manifest_file, ReputationConfig* config)
+void read_manifest(const char* manifest_file, const ReputationConfig& config, ReputationData& data)
 {
-    int line_number = 0;
-    std::string line;
     char full_path_dir[PATH_MAX+1];
-
-    update_path_to_file(full_path_dir, PATH_MAX, config->list_dir.c_str());
+    update_path_to_file(full_path_dir, PATH_MAX, config.list_dir.c_str());
     std::string manifest_full_path = std::string(full_path_dir) + '/' + manifest_file;
 
     std::fstream fs;
@@ -944,9 +915,11 @@ int read_manifest(const char* manifest_file, ReputationConfig* config)
     if (!fs.good())
     {
         ErrorMessage("Can't open file: %s\n", manifest_full_path.c_str());
-        return -1;
+        return;
     }
 
+    int line_number = 0;
+    std::string line;
     while (std::getline(fs, line))
     {
         line_number++;
@@ -958,12 +931,11 @@ int read_manifest(const char* manifest_file, ReputationConfig* config)
 
         //Processing the line
         ListFile* list_item = new ListFile;
-        if (!process_line_in_manifest(list_item, manifest_file, line.c_str(), line_number, config))
+        if (!process_line_in_manifest(
+                list_item, manifest_file, line.c_str(), line_number, config, data))
             delete list_item;
     }
 
     fs.close();
-
-    return 0;
 }
 
index 6c354a4168e92d285189abb08bf4ff4add804c3e..33ba33443c87aa6504312e029e34ade1885a83fb 100644 (file)
 #ifndef REPUTATION_PARSE_H
 #define REPUTATION_PARSE_H
 
-#include "reputation_config.h"
+#include <cstdint>
 
-#define MANIFEST_FILENAME "interface.info"
+struct ReputationConfig;
+class ReputationData;
 
-void ip_list_init(uint32_t,ReputationConfig *config);
-void estimate_num_entries(ReputationConfig* config);
-int read_manifest(const char* filename, ReputationConfig* config);
-void add_block_allow_List(ReputationConfig* config);
+void ip_list_init(uint32_t max_entries, const ReputationConfig&, ReputationData&);
+void estimate_num_entries(ReputationData&);
+void read_manifest(const char* filename, const ReputationConfig&, ReputationData&);
+void add_block_allow_List(const ReputationConfig&, ReputationData&);
 
 #endif
index ce8c154bdbfbe43bcecd68d0f3d7bfd86294e541..08ea9977e7cda42256fd8e21c4caa7ece0ae3d85 100644 (file)
@@ -86,7 +86,7 @@ RnaInspector::~RnaInspector()
     }
 }
 
-bool RnaInspector::configure(SnortConfig* sc)
+bool RnaInspector::configure(SnortConfig*)
 {
     DataBus::subscribe_network( APPID_EVENT_ANY_CHANGE, new RnaAppidEventHandler(*pnd) );
     DataBus::subscribe_network( DHCP_INFO_EVENT, new RnaDHCPInfoEventHandler(*pnd) );
@@ -110,13 +110,12 @@ bool RnaInspector::configure(SnortConfig* sc)
     if (rna_conf && rna_conf->log_when_idle)
         DataBus::subscribe_network( THREAD_IDLE_EVENT, new RnaIdleEventHandler(*pnd) );
 
-    // tinit is not called during reload, so pass processor pointers to threads via reload tuner
-    if ( Snort::is_reloading() && InspectorManager::get_inspector(RNA_NAME, true) )
-        sc->register_reload_resource_tuner(new FpProcReloadTuner(*mod_conf));
-
     return true;
 }
 
+void RnaInspector::install_reload_handler(SnortConfig* sc)
+{ sc->register_reload_handler(new FpProcReloadTuner(*mod_conf)); }
+
 void RnaInspector::eval(Packet* p)
 {
     Profile profile(rna_perf_stats);
index a3a8e45c39e64a136463fe6131ab5b0f1ac65046..c5774a796727555ccd94786379cea13361720d12 100644 (file)
@@ -46,6 +46,7 @@ public:
     ~RnaInspector() override;
 
     bool configure(snort::SnortConfig*) override;
+    void install_reload_handler(snort::SnortConfig*) override;
     void eval(snort::Packet*) override;
     void show(const snort::SnortConfig*) const override;
     void tinit() override;
index 21d8bb2c36349a722c79ec5f2b230bbc8d2d82cd..d5c98d80e0a1181e1f11abce0249b28a0e0e9838 100644 (file)
@@ -22,7 +22,7 @@
 #define RNA_MODULE_H
 
 #include "framework/module.h"
-#include "main/snort_config.h"
+#include "main/reload_tuner.h"
 #include "main/snort_debug.h"
 #include "profiler/profiler.h"
 
 #include "rna_mac_cache.h"
 #include "rna_name.h"
 
+namespace snort
+{
+struct SnortConfig;
+}
+
 struct RnaStats
 {
     PegCount appid_change;
index a032bd2e7a6f8b3978d2fad3507f227c14b569cd..f8dd9f8fa7ab71fe1a3c85911382591b65251cce 100644 (file)
@@ -38,6 +38,8 @@ void Module::show_interval_stats(std::vector<unsigned int, std::allocator<unsign
 {}
 void LogMessage(const char*,...) {}
 void WarningMessage(const char*,...) {}
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
 SnortConfig::SnortConfig(const SnortConfig* const, const char*) {}
 SnortConfig::~SnortConfig() = default;
 time_t packet_time() { return 0; }
index 20832560507d8af563b8f1ac2cc9cc34a6c4d688..b0309593d7f78b37a63b49f0d7993cc95c4fba4d 100644 (file)
@@ -44,7 +44,7 @@ void set_smb_fp_processor(SmbFpProcessor*) { }
 
 namespace snort
 {
-    void SnortConfig::register_reload_resource_tuner(ReloadResourceTuner* rrt) { delete rrt; }
+    void SnortConfig::register_reload_handler(ReloadResourceTuner* rrt) { delete rrt; }
 }
 
 TEST_GROUP(rna_module_test)
index ffaa9ddb11d6b9b5612be5806586464b60f386d6..6ba1df06aad747ed549d73ffe362c2fb0068e1b3 100644 (file)
@@ -63,6 +63,8 @@ Flow::Flow()
 Flow::~Flow() = default;
 IpsContext::IpsContext(unsigned int) { }
 IpsContext::~IpsContext() = default;
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
 SnortConfig::SnortConfig(snort::SnortConfig const*, const char*) { }
 SnortConfig::~SnortConfig() = default;
 
index b79db86bf26fd78cfdda6bf094bd479b43fe1200..c3c654d73870e3ccbf8aa3bd7e606051d1d76531 100644 (file)
@@ -67,9 +67,11 @@ public:
 
     void show() const override;
 
-    bool select_default_policies(const _daq_pkt_hdr*, const SnortConfig*) override;
+    bool select_default_policies(const _daq_pkt_hdr&, const SnortConfig*) override;
+    bool select_default_policies(const _daq_flow_stats&, const SnortConfig*) override;
 
 protected:
+    bool select_default_policies(uint32_t key, const SnortConfig*);
     std::vector<AddressSpaceSelection> policy_selections;
     std::unordered_map<uint32_t, snort::PolicySelectUse*> policy_map;
 };
@@ -113,18 +115,18 @@ void AddressSpaceSelector::show() const
     }
 }
 
-bool AddressSpaceSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, const SnortConfig* sc)
+bool AddressSpaceSelector::select_default_policies(uint32_t key, const SnortConfig* sc)
 {
     Profile profile(address_space_selectPerfStats);
 
     address_space_select_stats.packets++;
 
-    auto i = policy_map.find(static_cast<uint32_t>(pkthdr->address_space_id));
+    auto i = policy_map.find(key);
     if (i != policy_map.end())
     {
         auto use = (*i).second;
-        set_network_policy(sc, use->network_index);
-        set_inspection_policy(sc, use->inspection_index);
+        set_network_policy(use->network_index);
+        set_inspection_policy(use->inspection_index);
         set_ips_policy(sc, use->ips_index);
         return true;
     }
@@ -132,6 +134,18 @@ bool AddressSpaceSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, c
     return false;
 }
 
+bool AddressSpaceSelector::select_default_policies(const _daq_pkt_hdr& pkthdr,
+    const SnortConfig* sc)
+{
+    return select_default_policies(static_cast<uint32_t>(pkthdr.address_space_id), sc);
+}
+
+bool AddressSpaceSelector::select_default_policies(const _daq_flow_stats& stats,
+    const SnortConfig* sc)
+{
+    return select_default_policies(static_cast<uint32_t>(stats.address_space_id), sc);
+}
+
 //-------------------------------------------------------------------------
 // api stuff
 //-------------------------------------------------------------------------
index 577f5ba0c619a28784b18c3e257700bc3e3c4b25..6ea960f7ef02958d2df4d551bd82710ffa40e229 100644 (file)
@@ -125,7 +125,7 @@ bool AddressSpaceSelectorModule::end(const char* fqn, int idx, SnortConfig* sc)
         }
 
         Shell* sh = new Shell(policy_filename.c_str());
-        auto policies = sc->policy_map->add_shell(sh, true);
+        auto policies = sc->policy_map->add_shell(sh, nullptr);
         selection.use.network_index = policies->network->policy_id;
         selection.use.inspection_index = policies->inspection->policy_id;
         selection.use.ips_index = policies->ips->policy_id;
index 2dba08c4baa1ee893a1382e4af9167d57d0abefb..ce3646a718203f674b7eaecf69880960241e8351 100644 (file)
@@ -67,9 +67,11 @@ public:
 
     void show() const override;
 
-    bool select_default_policies(const _daq_pkt_hdr*, const SnortConfig*) override;
+    bool select_default_policies(const _daq_pkt_hdr&, const SnortConfig*) override;
+    bool select_default_policies(const _daq_flow_stats&, const SnortConfig*) override;
 
 protected:
+    bool select_default_policies(uint32_t key, const SnortConfig*);
     std::vector<TenantSelection> policy_selections;
     std::unordered_map<uint32_t, snort::PolicySelectUse*> policy_map;
 };
@@ -113,19 +115,18 @@ void TenantSelector::show() const
     }
 }
 
-bool TenantSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, const SnortConfig* sc)
+bool TenantSelector::select_default_policies(uint32_t key, const SnortConfig* sc)
 {
     Profile profile(tenant_select_perf_stats);
 
     tenant_select_stats.packets++;
 
-    // FIXIT-H replace address_space_id with tenant_id when it is added to the pkthdr
-    auto i = policy_map.find(static_cast<uint32_t>(pkthdr->address_space_id));
+    auto i = policy_map.find(key);
     if (i != policy_map.end())
     {
         auto use = (*i).second;
-        set_network_policy(sc, use->network_index);
-        set_inspection_policy(sc, use->inspection_index);
+        set_network_policy(use->network_index);
+        set_inspection_policy(use->inspection_index);
         set_ips_policy(sc, use->ips_index);
         return true;
     }
@@ -133,6 +134,18 @@ bool TenantSelector::select_default_policies(const _daq_pkt_hdr* pkthdr, const S
     return false;
 }
 
+bool TenantSelector::select_default_policies(const _daq_pkt_hdr& pkthdr, const SnortConfig* sc)
+{
+    // FIXIT-H replace address_space_id with tenant_id when it is added to the pkthdr
+    return select_default_policies(static_cast<uint32_t>(pkthdr.address_space_id), sc);
+}
+
+bool TenantSelector::select_default_policies(const _daq_flow_stats& stats, const SnortConfig* sc)
+{
+    // FIXIT-H replace address_space_id with tenant_id when it is added to the pkthdr
+    return select_default_policies(static_cast<uint32_t>(stats.address_space_id), sc);
+}
+
 //-------------------------------------------------------------------------
 // api stuff
 //-------------------------------------------------------------------------
index 2f6ea877c78929166f815c51f36c582cb0c6f8fa..bf73623fad29681c85796d21c911edf5f3868fde 100644 (file)
@@ -110,7 +110,7 @@ bool TenantSelectorModule::end(const char* fqn, int idx, SnortConfig* sc)
         }
 
         Shell* sh = new Shell(policy_filename.c_str());
-        auto policies = sc->policy_map->add_shell(sh, true);
+        auto policies = sc->policy_map->add_shell(sh, nullptr);
         selection.use.network_index = policies->network->policy_id;
         selection.use.inspection_index = policies->inspection->policy_id;
         selection.use.ips_index = policies->ips->policy_id;
index ecaa39eedc9184ea3edf051f5ed1330e49dad56c..9aafeebd3413ecfa3499a80789601fda30158c00 100644 (file)
@@ -20,6 +20,9 @@
 #ifndef OPPORTUNISTIC_TLS_EVENT_H
 #define OPPORTUNISTIC_TLS_EVENT_H
 
+#include <memory>
+#include <string>
+
 #include "framework/data_bus.h"
 
 // An opportunistic SSL/TLS session will start from next packet
@@ -31,18 +34,18 @@ namespace snort
 class SO_PUBLIC OpportunisticTlsEvent : public snort::DataEvent
 {
 public:
-    OpportunisticTlsEvent(const snort::Packet* p, const char* service) :
+    OpportunisticTlsEvent(const snort::Packet* p, std::shared_ptr<std::string> service) :
         pkt(p), next_service(service) { }
 
     const snort::Packet* get_packet() override
     { return pkt; }
 
-    const char* get_next_service()
+    std::shared_ptr<std::string> get_next_service()
     { return next_service; }
 
 private:
     const snort::Packet* pkt;
-    const char* next_service = nullptr;
+    std::shared_ptr<std::string> next_service;
 };
 
 }
index fdee055b27d196ccc108df990cf470c6c22f3110..6ef6332cf7a0e28383dfc72651ed533e561e4303 100644 (file)
@@ -101,6 +101,9 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
 static std::vector<void *> s_state;
 static ScratchAllocator* scratcher = nullptr;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 {
     state = &s_state;
index 2932653d19be4c84b01db8ad332f4e7edf067da5..c19bbece0592452946dd17f39cdd226cb677356a 100644 (file)
@@ -54,6 +54,9 @@ THREAD_LOCAL SnortConfig* snort_conf = &s_conf;
 
 static std::vector<void *> s_state;
 
+DataBus::DataBus() = default;
+DataBus::~DataBus() = default;
+
 SnortConfig::SnortConfig(const SnortConfig* const, const char*)
 {
     state = &s_state;
index cee6e3be0ad1f99a831e371c45b8820cfa8b2ffa..1f520fab2c926946f5f7ec3c8dddc33f83298b1a 100644 (file)
@@ -43,6 +43,9 @@ using namespace snort;
 THREAD_LOCAL int dce2_detected = 0;
 static THREAD_LOCAL bool using_rpkt = false;
 
+std::shared_ptr<std::string> dce_rpc_service_name =
+    std::make_shared<std::string>(DCE_RPC_SERVICE_NAME);
+
 static const char* dce2_get_policy_name(DCE2_Policy policy)
 {
     const char* policyStr = nullptr;
index 84d60833013dd939dd9450ab8e4a8d7db695208c..18396850cc43daadc713a49302070392f2bfed8c 100644 (file)
@@ -22,6 +22,8 @@
 #define DCE_COMMON_H
 
 #include <cassert>
+#include <memory>
+#include <string>
 
 #include "detection/detection_engine.h"
 #include "framework/counts.h"
@@ -41,6 +43,7 @@ extern THREAD_LOCAL int dce2_detected;
 
 #define GID_DCE2 133
 #define DCE_RPC_SERVICE_NAME "dcerpc"
+extern std::shared_ptr<std::string> dce_rpc_service_name;
 
 enum DCE2_Policy
 {
index 2cdaad9e47ca5e074a045fcbb3aa81171da41577..7da11138e9495625d21a2056f384aec02cf6aae8 100644 (file)
@@ -68,7 +68,7 @@ void DceHttpProxy::clear(Packet* p)
             if ( c2s_splitter->cutover_inspector() && s2c_splitter->cutover_inspector() )
             {
                 dce_http_proxy_stats.http_proxy_sessions++;
-                flow->set_service(p, DCE_RPC_SERVICE_NAME);
+                flow->set_service(p, dce_rpc_service_name);
             }
             else
                 dce_http_proxy_stats.http_proxy_session_failures++;
index acf0126e78d39504bb3b8ad4867b68cb936a638c..557b035da0f49bacaa719b57c614d69dbdc010fc 100644 (file)
@@ -64,7 +64,7 @@ void DceHttpServer::clear(Packet* p)
             if ( splitter->cutover_inspector())
             {
                 dce_http_server_stats.http_server_sessions++;
-                flow->set_service(p, DCE_RPC_SERVICE_NAME);
+                flow->set_service(p, dce_rpc_service_name);
             }
             else
                 dce_http_server_stats.http_server_session_failures++;
index 10f194d4ac4c19003e161978968e8519901f73d1..435d79507d0d7723e952c01c550507e6c2419c81 100644 (file)
@@ -23,6 +23,9 @@
 
 #include "ftp_data.h"
 
+#include <memory>
+#include <string>
+
 #include "detection/detection_engine.h"
 #include "file_api/file_flows.h"
 #include "file_api/file_service.h"
@@ -46,6 +49,8 @@ using namespace snort;
     "FTP data channel handler"
 
 static const char* const fd_svc_name = "ftp-data";
+static std::shared_ptr<std::string> shared_fd_svc_name =
+    std::make_shared<std::string>(fd_svc_name);
 
 static THREAD_LOCAL ProfileStats ftpdataPerfStats;
 static THREAD_LOCAL SimpleStats fdstats;
@@ -223,15 +228,15 @@ FtpDataFlowData::~FtpDataFlowData()
 
 void FtpDataFlowData::handle_expected(Packet* p)
 {
-    if (!p->flow->service)
+    if (!p->flow->has_service())
     {
-        p->flow->set_service(p, fd_svc_name);
+        p->flow->set_service(p, shared_fd_svc_name);
 
         FtpDataFlowData* fd =
             (FtpDataFlowData*)p->flow->get_flow_data(FtpDataFlowData::inspector_id);
         if (fd and fd->in_tls)
         {
-            OpportunisticTlsEvent evt(p, fd_svc_name);
+            OpportunisticTlsEvent evt(p, shared_fd_svc_name);
             DataBus::publish(OPPORTUNISTIC_TLS_EVENT, evt, p->flow);
         }
         else
index 51c78e091956a6075c79ae7ce34b7027f32fc93b..b408e271980bceda8e4ffa4216661627c13d0685 100755 (executable)
@@ -688,7 +688,7 @@ void HttpInspect::clear(Packet* p)
     if (session_data->cutover_on_clear)
     {
         Flow* flow = p->flow;
-        flow->set_service(p, nullptr);
+        flow->clear_service(p);
         flow->free_flow_data(HttpFlowData::inspector_id);
     }
 }
index efea0bf912f1f9f1f9c42f504d4b527e0c89f514..0f4ff54bb1e6480037cfd19b8e03bcde54c8f232 100644 (file)
@@ -725,7 +725,7 @@ bool Pop::get_buf(InspectionBuffer::Type ibt, Packet* p, InspectionBuffer& b)
                 return false;
 
             const BufferData& vba_buf = pop_ssn->mime_ssn->get_vba_inspect_buf();
-            
+
             if (vba_buf.data_ptr() && vba_buf.length())
             {
                 b.data = vba_buf.data_ptr();
index 1772a8dd18e9a961826c8b2718d8b9d0157e2887..e4701db51b4d196b9b35396f747fa60d8e784e4a 100644 (file)
@@ -25,6 +25,9 @@
 
 #include "ssl_inspector.h"
 
+#include <memory>
+#include <string>
+
 #include "detection/detect.h"
 #include "detection/detection_engine.h"
 #include "events/event_queue.h"
@@ -400,6 +403,7 @@ static void snort_ssl(SSL_PROTO_CONF* config, Packet* p)
 // class stuff
 //-------------------------------------------------------------------------
 static const char* s_name = "ssl";
+static std::shared_ptr<std::string> shared_s_name = std::make_shared<std::string>(s_name);
 
 class Ssl : public Inspector
 {
@@ -448,7 +452,7 @@ public:
             pkt->flow->flags.trigger_finalize_event = fd->finalize_info.orig_flag;
             fd->finalize_info.switch_in = false;
             pkt->flow->set_proxied();
-            pkt->flow->set_service(const_cast<Packet*>(pkt), s_name);
+            pkt->flow->set_service(const_cast<Packet*>(pkt), shared_s_name);
         }
     }
 };
index c75696623878b167ae996a079838d57cfbb00a13..96893475958047d548081eda81e4ebf5b5cedbef 100644 (file)
@@ -393,11 +393,11 @@ static bool ssl_v2_curse(const uint8_t* data, unsigned len, CurseTracker* tracke
 // map between service and curse details
 static vector<CurseDetails> curse_map
 {
-    // name      service        alg            is_tcp
-    { "dce_udp", "dcerpc",      dce_udp_curse, false },
-    { "dce_tcp", "dcerpc",      dce_tcp_curse, true  },
-    { "dce_smb", "netbios-ssn", dce_smb_curse, true  },
-    { "sslv2"  , "ssl",         ssl_v2_curse , true  }
+    // name      service                             alg            is_tcp
+    { "dce_udp", make_shared<string>("dcerpc")     , dce_udp_curse, false },
+    { "dce_tcp", make_shared<string>("dcerpc")     , dce_tcp_curse, true  },
+    { "dce_smb", make_shared<string>("netbios-ssn"), dce_smb_curse, true  },
+    { "sslv2"  , make_shared<string>("ssl")        , ssl_v2_curse , true  }
 };
 
 bool CurseBook::add_curse(const char* key)
index 6a23903bd13e7822fec7f48a4e766f1d81893a63..8b2d9b682e7e4b39512664563b880e340162bfa9 100644 (file)
@@ -21,6 +21,7 @@
 #define CURSES_H
 
 #include <cstdint>
+#include <memory>
 #include <string>
 #include <vector>
 
@@ -86,7 +87,7 @@ typedef bool (* curse_alg)(const uint8_t* data, unsigned len, CurseTracker*);
 struct CurseDetails
 {
     std::string name;
-    std::string service;
+    std::shared_ptr<std::string> service;
     curse_alg alg;
     bool is_tcp;
 };
index 53e2ca1bfa6b03f83f677637766b3d2161b62047..9b2880fb7f12634ddd1faea14f44b9d7c8f31ca0 100644 (file)
@@ -91,7 +91,7 @@ void HexBook::add_spell(
         ++i;
     }
     p->key = key;
-    p->value = val;
+    p->value = make_shared<string>(val);
 }
 
 bool HexBook::add_spell(const char* key, const char*& val)
@@ -124,7 +124,7 @@ bool HexBook::add_spell(const char* key, const char*& val)
     }
     if ( p->key == key )
     {
-        val = p->value.c_str();
+        val = p->value->c_str();
         return false;
     }
 
@@ -158,7 +158,7 @@ const MagicPage* HexBook::find_spell(
             if ( const MagicPage* q = find_spell(s, n, p->any, i+1) )
                 return q;
         }
-        return p->value.empty() ? nullptr : p;
+        return p->value.use_count() ? p : nullptr;
     }
     return p;
 }
index 9e6c4fd3f2b17dbd86ec3fcdddeff4bd0765959b..a6f2a3f84cda5b4a61c61ee7bd3fb207269228f1 100644 (file)
@@ -42,14 +42,14 @@ MagicPage::~MagicPage()
     delete any;
 }
 
-const char* MagicBook::find_spell(const uint8_t* data, unsigned len, const MagicPage*& p) const
+std::shared_ptr<std::string> MagicBook::find_spell(const uint8_t* data, unsigned len,
+    const MagicPage*& p) const
 {
     assert(p);
 
     p = find_spell(data, len, p, 0);
-
-    if ( p && !p->value.empty() )
-        return p->value.c_str();
+    if ( p && p->value.use_count() )
+        return p->value;
 
     return nullptr;
 }
index 14238cca5a532217ae07763932e3f0934eebe09d..f35380a138d39c63da045eb105e890a6a84406fa 100644 (file)
@@ -20,6 +20,7 @@
 #ifndef MAGIC_H
 #define MAGIC_H
 
+#include <memory>
 #include <string>
 #include <vector>
 
@@ -28,7 +29,7 @@ class MagicBook;
 struct MagicPage
 {
     std::string key;
-    std::string value;
+    std::shared_ptr<std::string> value;
 
     MagicPage* next[256];
     MagicPage* any;
@@ -52,7 +53,8 @@ public:
     MagicBook& operator=(const MagicBook&) = delete;
 
     virtual bool add_spell(const char* key, const char*& val) = 0;
-    virtual const char* find_spell(const uint8_t*, unsigned len, const MagicPage*&) const;
+    virtual std::shared_ptr<std::string> find_spell(const uint8_t*, unsigned len,
+        const MagicPage*&) const;
 
     const MagicPage* page1() const
     { return root; }
index b58a23e0b1d62b32ad2c76bb798903b3f5fde3d7..ce6ba6f6d23bfa65d10f9b6ea974d9ff17e8edda 100644 (file)
@@ -84,7 +84,7 @@ void SpellBook::add_spell(
         ++i;
     }
     p->key = key;
-    p->value = val;
+    p->value = make_shared<string>(val);
 }
 
 bool SpellBook::add_spell(const char* key, const char*& val)
@@ -118,7 +118,7 @@ bool SpellBook::add_spell(const char* key, const char*& val)
     }
     if ( p->key == key )
     {
-        val = p->value.c_str();
+        val = p->value->c_str();
         return false;
     }
 
@@ -162,15 +162,15 @@ const MagicPage* SpellBook::find_spell(
         }
 
         // If no match but has glob, continue lookup from glob
-        if ( p->value.empty() && glob )
-        {   
+        if ( !p->value.use_count() && glob )
+        {
             p = glob;
             glob = nullptr;
-            
+
             return find_spell(s, n, p, i);
         }
 
-        return p->value.empty() ? nullptr : p;
+        return p->value.use_count() ? p : nullptr;
     }
     return p;
 }
index c7f191799f785f1ac269c87cd6d49f15ee6ba210..7d20b9c9d69440bc56e797dd424a99f59b10ab53 100644 (file)
@@ -195,7 +195,7 @@ StreamSplitter::Status MagicSplitter::scan(
     if ( wizard->cast_spell(wand, pkt->flow, data, len, wizard_processed_bytes) )
     {
         trace_logf(wizard_trace, pkt, "%s streaming search found service %s\n",
-            to_server() ? "c2s" : "s2c", pkt->flow->service);
+            to_server() ? "c2s" : "s2c", pkt->flow->service->c_str());
         count_hit(pkt->flow);
         wizard_processed_bytes = 0;
         return STOP;
@@ -225,7 +225,7 @@ StreamSplitter::Status MagicSplitter::scan(
     // delayed. Because AppId depends on wizard only for SSH detection and SSH inspector can be
     // attached very early, event is raised here after first scan. In the future, wizard should be
     // enhanced to abort sooner if it can't detect service.
-    if (!pkt->flow->service && !pkt->flow->flags.svc_event_generated)
+    if (!pkt->flow->has_service() && !pkt->flow->flags.svc_event_generated)
     {
         DataBus::publish(FLOW_NO_SERVICE_EVENT, pkt);
         pkt->flow->flags.svc_event_generated = true;
@@ -309,7 +309,7 @@ void Wizard::eval(Packet* p)
     if ( cast_spell(wand, p->flow, p->data, p->dsize, udp_processed_bytes) )
     {
         trace_logf(wizard_trace, p, "%s datagram search found service %s\n",
-            c2s ? "c2s" : "s2c", p->flow->service);
+            c2s ? "c2s" : "s2c", p->flow->service->c_str());
         ++tstats.udp_hits;
     }
     else
@@ -328,8 +328,12 @@ StreamSplitter* Wizard::get_splitter(bool c2s)
 bool Wizard::spellbind(
     const MagicPage*& m, Flow* f, const uint8_t* data, unsigned len)
 {
-    f->service = m->book.find_spell(data, len, m);
-    return ( f->service != nullptr );
+    std::shared_ptr<std::string> p_shared = m->book.find_spell(data, len, m);
+    if (p_shared.use_count())
+        f->service = p_shared;
+    else
+        f->service.reset();
+    return f->has_service();
 }
 
 bool Wizard::cursebind(const vector<CurseServiceTracker>& curse_tracker, Flow* f,
@@ -339,8 +343,11 @@ bool Wizard::cursebind(const vector<CurseServiceTracker>& curse_tracker, Flow* f
     {
         if (cst.curse->alg(data, len, cst.tracker))
         {
-            f->service = cst.curse->service.c_str();
-            if ( f->service != nullptr )
+            if (cst.curse->service.use_count())
+                f->service = cst.curse->service;
+            else
+                f->service.reset();
+            if ( f->has_service() )
                 return true;
         }
     }
@@ -366,9 +373,9 @@ bool Wizard::cast_spell(
     if (cursebind(w.curse_tracker, f, data, curse_len))
         return true;
 
-    // If we reach max value of wizard_processed_bytes, 
+    // If we reach max value of wizard_processed_bytes,
     // but not assign any inspector - raise tcp_miss and stop
-    if ( !f->service && wizard_processed_bytes >= max_search_depth )
+    if ( !f->has_service() && wizard_processed_bytes >= max_search_depth )
     {
         w.spell = nullptr;
         w.hex = nullptr;
index b6e7f5018d58d8d18df2e0ad872ab7f9097d2a98..963ac2ea3ec472aa0db6852caf4acf9d69426e02 100644 (file)
@@ -180,7 +180,7 @@ StreamBase::StreamBase(const StreamModuleConfig* c)
 { config = *c; }
 
 void StreamBase::tear_down(SnortConfig* sc)
-{ sc->register_reload_resource_tuner(new StreamUnloadReloadResourceManager); }
+{ sc->register_reload_handler(new StreamUnloadReloadResourceManager); }
 
 void StreamBase::tinit()
 {
index adb7c52eacaa1ad178124e4fcca3ef5964980d59..fd343616cda8ff9950856c3690b805bb26bff77c 100644 (file)
@@ -221,11 +221,11 @@ bool StreamModule::end(const char* fqn, int, SnortConfig* sc)
     {
         StreamReloadResourceManager* reload_resource_manager = new StreamReloadResourceManager;
         if (reload_resource_manager->initialize(config))
-            sc->register_reload_resource_tuner(reload_resource_manager);
+            sc->register_reload_handler(reload_resource_manager);
         else
             delete reload_resource_manager;
 
-        sc->register_reload_resource_tuner(new HPQReloadTuner(config.held_packet_timeout));
+        sc->register_reload_handler(new HPQReloadTuner(config.held_packet_timeout));
     }
 
     return true;
index 7b827e2e58f0c7e1e7f85875b8e68de7e63702b1..2023dd11fa102dd2b1d6986f7898a91685c683e9 100644 (file)
 #ifndef STREAM_MODULE_H
 #define STREAM_MODULE_H
 
-#include "main/analyzer.h"
-#include "main/snort_config.h"
 #include "flow/flow_config.h"
 #include "flow/flow_control.h"
 #include "framework/module.h"
+#include "main/analyzer.h"
+#include "main/reload_tuner.h"
 
 namespace snort
 {
index e0bb5caba75156e5e9d52a9a7c50ac6109f1a33e..651f8919013f378fc2eddc63c3e23ff8e13f29ab 100644 (file)
@@ -73,16 +73,10 @@ bool StreamTcp::configure(SnortConfig* sc)
 }
 
 void StreamTcp::tinit()
-{
-    TcpHAManager::tinit();
-    TcpSession::sinit();
-}
+{ TcpHAManager::tinit(); }
 
 void StreamTcp::tterm()
-{
-    TcpHAManager::tterm();
-    TcpSession::sterm();
-}
+{ TcpHAManager::tterm(); }
 
 NORETURN_ASSERT void StreamTcp::eval(Packet*)
 {
@@ -129,6 +123,12 @@ static void stream_tcp_pterm()
     TcpNormalizerFactory::term();
 }
 
+static void stream_tcp_tinit()
+{ TcpSession::sinit(); }
+
+static void stream_tcp_tterm()
+{ TcpSession::sterm(); }
+
 static Session* tcp_ssn(Flow* lws)
 { return new TcpSession(lws); }
 
@@ -152,8 +152,8 @@ static const InspectApi tcp_api =
     nullptr,  // service
     stream_tcp_pinit,  // pinit
     stream_tcp_pterm,  // pterm
-    nullptr,  // tinit,
-    nullptr,  // tterm,
+    stream_tcp_tinit,  // tinit,
+    stream_tcp_tterm,  // tterm,
     tcp_ctor,
     tcp_dtor,
     tcp_ssn,
index 92723ee50b3ac9dc830e24671a2444c9364a7ab3..5edef617506d8fe5976aa6961f0be04d13d5aec6 100644 (file)
@@ -783,8 +783,8 @@ static Packet* get_packet(Flow* flow, uint32_t flags, bool c2s)
 
     p->ip_proto_next = (IpProtocol)flow->ip_proto;
 
+    set_inspection_policy(flow->inspection_policy_id);
     const SnortConfig* sc = SnortConfig::get_conf();
-    set_inspection_policy(sc, flow->inspection_policy_id);
     set_ips_policy(sc, flow->ips_policy_id);
 
     return p;
index 763da5430e8bc2deb7ecb51ac1709332a65be215..4b30c69f6c358d801a31d8212a27754b1f724423 100644 (file)
@@ -26,6 +26,7 @@
 #include "host_attributes.h"
 
 #include "hash/lru_cache_shared.h"
+#include "main/reload_tuner.h"
 #include "main/shell.h"
 #include "main/snort.h"
 #include "main/snort_config.h"
@@ -171,7 +172,7 @@ void HostAttributesManager::activate(SnortConfig* sc)
     next_cache = nullptr;
 
     if( active_cache != old_cache and Snort::is_reloading() )
-        sc->register_reload_resource_tuner(new HostAttributesReloadTuner);
+        sc->register_reload_handler(new HostAttributesReloadTuner);
 }
 
 void HostAttributesManager::initialize()
index e453d003f2217521df7b09bf7bddd82ec18a08e1..b37d0a2fa8f8b5833b4772ddb2097da3b437a670 100644 (file)
@@ -39,19 +39,25 @@ SnortProtocolId ProtocolReference::get_count() const
 { return protocol_number; }
 
 const char* ProtocolReference::get_name(SnortProtocolId id) const
+{
+    std::shared_ptr<std::string> shared_name = get_shared_name(id);
+    return shared_name->c_str();
+}
+
+std::shared_ptr<std::string> ProtocolReference::get_shared_name(SnortProtocolId id) const
 {
     if ( id >= id_map.size() )
         id = 0;
 
-    return id_map[id].c_str();
+    return id_map[id];
 }
 
 struct Compare
 {
     bool operator()(SnortProtocolId a, SnortProtocolId b)
-    { return map[a] < map[b]; }
+    { return map[a]->c_str() < map[b]->c_str(); }
 
-    vector<string>& map;
+    vector<shared_ptr<string>>& map;
 };
 
 const char* ProtocolReference::get_name_sorted(SnortProtocolId id)
@@ -67,7 +73,7 @@ const char* ProtocolReference::get_name_sorted(SnortProtocolId id)
     if ( id >= ind_map.size() )
         return nullptr;
 
-    return id_map[ind_map[id]].c_str();
+    return id_map[ind_map[id]]->c_str();
 }
 
 SnortProtocolId ProtocolReference::add(const char* protocol)
@@ -80,7 +86,7 @@ SnortProtocolId ProtocolReference::add(const char* protocol)
         return protocol_ref->second;
 
     SnortProtocolId snort_protocol_id = protocol_number++;
-    id_map.emplace_back(protocol);
+    id_map.emplace_back(make_shared<string>(protocol));
     ref_table[protocol] = snort_protocol_id;
 
     return snort_protocol_id;
index 0fbca27d973cfa4423d25a6970ea5c4291d973ec..6281cd1b55c7fcad7d14a39723fea834adaf6ae5 100644 (file)
@@ -22,6 +22,7 @@
 #ifndef SNORT_PROTOCOLS_H
 #define SNORT_PROTOCOLS_H
 
+#include <memory>
 #include <string>
 #include <vector>
 #include <unordered_map>
@@ -73,6 +74,7 @@ public:
     SnortProtocolId get_count() const;
 
     const char* get_name(SnortProtocolId id) const;
+    std::shared_ptr<std::string> get_shared_name(SnortProtocolId id) const;
     const char* get_name_sorted(SnortProtocolId id);
 
     SnortProtocolId add(const char* protocol);
@@ -81,7 +83,7 @@ public:
     bool operator()(SnortProtocolId a, SnortProtocolId b);
 
 private:
-    std::vector<std::string> id_map;
+    std::vector<std::shared_ptr<std::string>> id_map;
     std::vector<SnortProtocolId> ind_map;
     std::unordered_map<std::string, SnortProtocolId> ref_table;